From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 27 三月 2024 16:05:29 +0800
Subject: [PATCH] train update

---
 funasr/models/whisper/model.py |   17 +++++++++++++----
 1 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py
index 1eac2ff..35de1c9 100644
--- a/funasr/models/whisper/model.py
+++ b/funasr/models/whisper/model.py
@@ -24,7 +24,7 @@
 @tables.register("model_classes", "Whisper-large-v1")
 @tables.register("model_classes", "Whisper-large-v2")
 @tables.register("model_classes", "Whisper-large-v3")
-@tables.register("model_classes", "Whisper-WhisperWarp")
+@tables.register("model_classes", "WhisperWarp")
 class WhisperWarp(nn.Module):
     def __init__(self, *args, **kwargs):
         super().__init__()
@@ -35,11 +35,13 @@
                 model_or_path = model_or_path.replace("Whisper-", "")
             model = whisper.load_model(model_or_path)
         else:
-            whisper_dims = kwargs.get("whisper_dims", {})
-            dims = whisper.model.ModelDimensions(**whisper_dims)
+            dims = kwargs.get("dims", {})
+            dims = whisper.model.ModelDimensions(**dims)
             model = whisper.model.Whisper(dims=dims)
         
         self.model = model
+        
+        self.encoder_output_size = self.model.dims.n_audio_state
         
     def forward(self, ):
         pass
@@ -55,6 +57,13 @@
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
 
+        if frontend is None and not hasattr(self, "frontend"):
+            frontend_class = tables.frontend_classes.get("WhisperFrontend")
+            frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
+            self.frontend = frontend
+        else:
+            frontend = frontend if frontend is not None else self.frontend
+
         meta_data = {}
         if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
             speech, speech_lengths = data_in, data_lengths
@@ -65,7 +74,7 @@
         else:
             # extract fbank feats
             time1 = time.perf_counter()
-            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000),
                                                             data_type=kwargs.get("data_type", "sound"),
                                                             tokenizer=tokenizer)
             time2 = time.perf_counter()

--
Gitblit v1.9.1