zhifu gao
2024-04-23 2ac38adbe5f4e1374a079e032ed4b504351a207c
funasr/models/whisper/model.py
@@ -24,7 +24,7 @@
@tables.register("model_classes", "Whisper-large-v1")
@tables.register("model_classes", "Whisper-large-v2")
@tables.register("model_classes", "Whisper-large-v3")
@tables.register("model_classes", "Whisper-WhisperWarp")
@tables.register("model_classes", "WhisperWarp")
class WhisperWarp(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
@@ -35,11 +35,13 @@
                model_or_path = model_or_path.replace("Whisper-", "")
            model = whisper.load_model(model_or_path)
        else:
            whisper_dims = kwargs.get("whisper_dims", {})
            dims = whisper.model.ModelDimensions(**whisper_dims)
            dims = kwargs.get("dims", {})
            dims = whisper.model.ModelDimensions(**dims)
            model = whisper.model.Whisper(dims=dims)
        
        self.model = model
        self.encoder_output_size = self.model.dims.n_audio_state
        
    def forward(self, ):
        pass
@@ -55,6 +57,13 @@
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        if frontend is None and not hasattr(self, "frontend"):
            frontend_class = tables.frontend_classes.get("WhisperFrontend")
            frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
            self.frontend = frontend
        else:
            frontend = frontend if frontend is not None else self.frontend
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
@@ -65,7 +74,7 @@
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=kwargs.get("data_type", "sound"),
                                                            tokenizer=tokenizer)
            time2 = time.perf_counter()
@@ -81,12 +90,12 @@
        speech = speech.to(device=kwargs["device"])[0, :, :]
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # detect the spoken language
        _, probs = self.model.detect_language(speech)
        print(f"Detected language: {max(probs, key=probs.get)}")
        # # detect the spoken language
        # _, probs = self.model.detect_language(speech)
        # print(f"Detected language: {max(probs, key=probs.get)}")
        # decode the audio
        options = whisper.DecodingOptions(language=kwargs.get("language", None), fp16=False)
        options = whisper.DecodingOptions(**kwargs.get("DecodingOptions", {}))
        result = whisper.decode(self.model, speech, options)
        results = []