From 997374b88fe6b2ae5cb4dcaf47d78cb3eff09fc2 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Tue, 11 Jun 2024 19:56:52 +0800
Subject: [PATCH] add ctc inference code (#1806)

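Register the stock OpenAI Whisper checkpoints (tiny through large-v3,
including the .en variants) as named model classes, build a WhisperFrontend
lazily inside inference() when the caller does not supply one, and pass
decoding parameters through a DecodingOptions dict in kwargs instead of
hard-coding fp16=False and reading only the language from kwargs. The
spoken-language detection call is commented out of the inference path.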
---
 funasr/models/whisper/model.py |   91 +++++++++++++++++++++++++++++++--------------
 1 file changed, 62 insertions(+), 29 deletions(-)

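Reviewer note (not part of the commit message): a minimal usage sketch of the
new inference path. AutoModel is FunASR's real entry point, but whether it
forwards the hub and DecodingOptions kwargs to this class exactly as written
below is an assumption, not something this patch guarantees.

    from funasr import AutoModel

    # "Whisper-large-v3" is one of the model class names registered by this
    # patch; hub="openai" makes __init__ call whisper.load_model("large-v3").
    model = AutoModel(model="Whisper-large-v3", hub="openai")

    # The DecodingOptions dict travels through **kwargs and is unpacked into
    # whisper.DecodingOptions(**...) inside inference(); language=None would
    # instead let whisper detect the language during decoding.
    results = model.generate(
        input="example.wav",
        DecodingOptions={"language": "en", "fp16": False},
    )
    print(results[0]["text"])
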
diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py
index f09405a..8e9245a 100644
--- a/funasr/models/whisper/model.py
+++ b/funasr/models/whisper/model.py
@@ -13,37 +13,66 @@
 from funasr.register import tables
 
 
-
+@tables.register("model_classes", "Whisper-tiny.en")
+@tables.register("model_classes", "Whisper-tiny")
+@tables.register("model_classes", "Whisper-base.en")
+@tables.register("model_classes", "Whisper-base")
+@tables.register("model_classes", "Whisper-small.en")
+@tables.register("model_classes", "Whisper-small")
+@tables.register("model_classes", "Whisper-medium.en")
+@tables.register("model_classes", "Whisper-medium")
+@tables.register("model_classes", "Whisper-large-v1")
+@tables.register("model_classes", "Whisper-large-v2")
+@tables.register("model_classes", "Whisper-large-v3")
 @tables.register("model_classes", "WhisperWarp")
 class WhisperWarp(nn.Module):
-    def __init__(self, whisper_dims: dict, **kwargs):
+    def __init__(self, *args, **kwargs):
         super().__init__()
         hub = kwargs.get("hub", "funasr")
         if hub == "openai":
-            init_param_path = kwargs.get("init_param_path", "large-v3")
-            model = whisper.load_model(init_param_path)
+            model_or_path = kwargs.get("model_path", "Whisper-large-v3")
+            if model_or_path.startswith("Whisper-"):
+                model_or_path = model_or_path.replace("Whisper-", "")
+            model = whisper.load_model(model_or_path)
         else:
-            dims = whisper.model.ModelDimensions(**whisper_dims)
+            dims = kwargs.get("dims", {})
+            dims = whisper.model.ModelDimensions(**dims)
             model = whisper.model.Whisper(dims=dims)
-        
+
         self.model = model
-        
-    def forward(self, ):
+
+        self.encoder_output_size = self.model.dims.n_audio_state
+
+    def forward(
+        self,
+    ):
         pass
-    
-    def inference(self,
-                  data_in,
-                  data_lengths=None,
-                  key: list = None,
-                  tokenizer=None,
-                  frontend=None,
-                  **kwargs,
-                  ):
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
 
+        if frontend is None and not hasattr(self, "frontend"):
+            frontend_class = tables.frontend_classes.get("WhisperFrontend")
+            frontend = frontend_class(
+                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
+            )
+            self.frontend = frontend
+        else:
+            frontend = frontend if frontend is not None else self.frontend
+
         meta_data = {}
-        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+        if (
+            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+        ):  # fbank
             speech, speech_lengths = data_in, data_lengths
             if len(speech.shape) < 3:
                 speech = speech[None, :, :]
@@ -52,13 +81,18 @@
         else:
             # extract fbank feats
             time1 = time.perf_counter()
-            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
-                                                            data_type=kwargs.get("data_type", "sound"),
-                                                            tokenizer=tokenizer)
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
             time2 = time.perf_counter()
             meta_data["load_data"] = f"{time2 - time1:0.3f}"
-            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend)
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
             time3 = time.perf_counter()
             meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
             frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
@@ -68,18 +102,17 @@
         speech = speech.to(device=kwargs["device"])[0, :, :]
         speech_lengths = speech_lengths.to(device=kwargs["device"])
 
-        # detect the spoken language
-        _, probs = self.model.detect_language(speech)
-        print(f"Detected language: {max(probs, key=probs.get)}")
+        # # detect the spoken language
+        # _, probs = self.model.detect_language(speech)
+        # print(f"Detected language: {max(probs, key=probs.get)}")
 
         # decode the audio
-        options = whisper.DecodingOptions(language=kwargs.get("language", None), fp16=False)
+        options = whisper.DecodingOptions(**kwargs.get("DecodingOptions", {}))
         result = whisper.decode(self.model, speech, options)
 
         results = []
         result_i = {"key": key[0], "text": result.text}
 
         results.append(result_i)
-    
+
         return results, meta_data
-    
\ No newline at end of file

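Reviewer note: for the non-openai hub path, __init__ now builds
whisper.model.ModelDimensions from kwargs["dims"]. A hypothetical direct
construction, assuming the ten fields of openai-whisper's ModelDimensions
dataclass (the values below are the tiny multilingual configuration, shown
for illustration only):

    from funasr.models.whisper.model import WhisperWarp

    dims = dict(
        n_mels=80, n_audio_ctx=1500, n_audio_state=384,
        n_audio_head=6, n_audio_layer=4,
        n_vocab=51865, n_text_ctx=448, n_text_state=384,
        n_text_head=6, n_text_layer=4,
    )
    # Takes the else branch of __init__ and yields randomly initialized weights.
    model = WhisperWarp(hub="funasr", dims=dims)
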
--
Gitblit v1.9.1