| | |
| | | @tables.register("model_classes", "Whisper-large-v1") |
| | | @tables.register("model_classes", "Whisper-large-v2") |
| | | @tables.register("model_classes", "Whisper-large-v3") |
| | | @tables.register("model_classes", "Whisper-WhisperWarp") |
| | | @tables.register("model_classes", "WhisperWarp") |
| | | class WhisperWarp(nn.Module): |
| | | def __init__(self, *args, **kwargs): |
| | | super().__init__() |
| | |
| | | model_or_path = model_or_path.replace("Whisper-", "") |
| | | model = whisper.load_model(model_or_path) |
| | | else: |
| | | whisper_dims = kwargs.get("whisper_dims", {}) |
| | | dims = whisper.model.ModelDimensions(**whisper_dims) |
| | | dims = kwargs.get("dims", {}) |
| | | dims = whisper.model.ModelDimensions(**dims) |
| | | model = whisper.model.Whisper(dims=dims) |
| | | |
| | | self.model = model |
| | | |
| | | self.encoder_output_size = self.model.dims.n_audio_state |
| | | |
| | | def forward(self, ): |
| | | pass |
| | |
| | | if kwargs.get("batch_size", 1) > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | |
| | | if frontend is None and not hasattr(self, "frontend"): |
| | | frontend_class = tables.frontend_classes.get("WhisperFrontend") |
| | | frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)) |
| | | self.frontend = frontend |
| | | else: |
| | | frontend = frontend if frontend is not None else self.frontend |
| | | |
| | | meta_data = {} |
| | | if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank |
| | | speech, speech_lengths = data_in, data_lengths |
| | |
| | | else: |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000), |
| | | data_type=kwargs.get("data_type", "sound"), |
| | | tokenizer=tokenizer) |
| | | time2 = time.perf_counter() |
| | |
| | | speech = speech.to(device=kwargs["device"])[0, :, :] |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | |
| | | # detect the spoken language |
| | | _, probs = self.model.detect_language(speech) |
| | | print(f"Detected language: {max(probs, key=probs.get)}") |
| | | # # detect the spoken language |
| | | # _, probs = self.model.detect_language(speech) |
| | | # print(f"Detected language: {max(probs, key=probs.get)}") |
| | | |
| | | # decode the audio |
| | | options = whisper.DecodingOptions(language=kwargs.get("language", None), fp16=False) |
| | | options = whisper.DecodingOptions(**kwargs.get("DecodingOptions", {})) |
| | | result = whisper.decode(self.model, speech, options) |
| | | |
| | | results = [] |