import time

import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn

import whisper

# import whisper_timestamped as whisper

from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
| | |
@tables.register("model_classes", "Whisper-large-v1")
@tables.register("model_classes", "Whisper-large-v2")
@tables.register("model_classes", "Whisper-large-v3")
@tables.register("model_classes", "Whisper-large-v3-turbo")
@tables.register("model_classes", "WhisperWarp")
class WhisperWarp(nn.Module):
    """Thin FunASR wrapper around an OpenAI Whisper model.

    Builds a ``whisper.model.Whisper`` from the ``dims`` dict passed in
    kwargs and exposes a FunASR-style ``inference`` entry point that decodes
    one utterance at a time.
    """

    def __init__(self, *args, **kwargs):
        # nn.Module.__init__ must run before any submodule assignment;
        # without it, ``self.model = model`` raises AttributeError.
        super().__init__()

        dims = kwargs.get("dims", {})
        dims = whisper.model.ModelDimensions(**dims)
        model = whisper.model.Whisper(dims=dims)

        self.model = model

        # Width of the audio-encoder output, read by downstream components.
        self.encoder_output_size = self.model.dims.n_audio_state

    def forward(self):
        # Training through this wrapper is not supported; use ``inference``.
        pass

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Decode a single utterance with Whisper.

        Args:
            data_in: Audio input (path/bytes/tensor handled by
                ``load_audio_text_image_video``), or a precomputed fbank
                tensor when ``kwargs["data_type"] == "fbank"``.
            data_lengths: Feature lengths matching ``data_in`` when it is a
                precomputed fbank tensor.
            key: List of utterance ids; only ``key[0]`` is used since batch
                decoding is unsupported.
            tokenizer: Optional tokenizer forwarded to the data loader.
            frontend: Optional feature frontend; lazily built and cached on
                ``self`` when absent.
            **kwargs: ``batch_size``, ``data_type``, ``fs``, ``do_pad_trim``,
                ``DecodingOptions`` and other pass-through options.

        Returns:
            Tuple ``(results, meta_data)`` where ``results`` is
            ``[{"key": ..., "text": ...}]`` and ``meta_data`` carries timing
            strings for data loading / feature extraction.

        Raises:
            NotImplementedError: If ``kwargs["batch_size"] > 1``.
        """
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")

        # Lazily construct and cache a Whisper frontend when none is given.
        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(
                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
            )
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend

        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                # Promote (time, mels) to a batch of one: (1, time, mels).
                speech = speech[None, :, :]
        else:
            # Load raw audio, then extract fbank features, recording
            # wall-clock timings for each stage in meta_data.
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"

        # Decode the audio with Whisper's built-in decoder.
        options = whisper.DecodingOptions(**kwargs.get("DecodingOptions", {}))
        result = whisper.decode(self.model, speech, options=options)
        # result = whisper.transcribe(self.model, speech)

        results = []
        result_i = {"key": key[0], "text": result.text}
        results.append(result_i)

        return results, meta_data