From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 funasr/models/whisper/model.py |   72 ++++++++++++++++++++++--------------
 1 file changed, 44 insertions(+), 28 deletions(-)

diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py
index 35de1c9..a332100 100644
--- a/funasr/models/whisper/model.py
+++ b/funasr/models/whisper/model.py
@@ -7,7 +7,10 @@
 import torch.nn.functional as F
 from torch import Tensor
 from torch import nn
+
 import whisper
+# import whisper_timestamped as whisper
+
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 from funasr.register import tables
@@ -38,34 +41,41 @@
             dims = kwargs.get("dims", {})
             dims = whisper.model.ModelDimensions(**dims)
             model = whisper.model.Whisper(dims=dims)
-        
+
         self.model = model
-        
+
         self.encoder_output_size = self.model.dims.n_audio_state
-        
-    def forward(self, ):
+
+    def forward(
+        self,
+    ):
         pass
-    
-    def inference(self,
-                  data_in,
-                  data_lengths=None,
-                  key: list = None,
-                  tokenizer=None,
-                  frontend=None,
-                  **kwargs,
-                  ):
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
         if kwargs.get("batch_size", 1) > 1:
             raise NotImplementedError("batch decoding is not implemented")
 
         if frontend is None and not hasattr(self, "frontend"):
             frontend_class = tables.frontend_classes.get("WhisperFrontend")
-            frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
+            frontend = frontend_class(
+                n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
+            )
             self.frontend = frontend
         else:
             frontend = frontend if frontend is not None else self.frontend
 
         meta_data = {}
-        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+        if (
+            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+        ):  # fbank
             speech, speech_lengths = data_in, data_lengths
             if len(speech.shape) < 3:
                 speech = speech[None, :, :]
@@ -74,13 +84,18 @@
         else:
             # extract fbank feats
             time1 = time.perf_counter()
-            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000),
-                                                            data_type=kwargs.get("data_type", "sound"),
-                                                            tokenizer=tokenizer)
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs if hasattr(frontend, "fs") else 16000,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
             time2 = time.perf_counter()
             meta_data["load_data"] = f"{time2 - time1:0.3f}"
-            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend)
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
             time3 = time.perf_counter()
             meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
             frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
@@ -90,18 +105,19 @@
         speech = speech.to(device=kwargs["device"])[0, :, :]
         speech_lengths = speech_lengths.to(device=kwargs["device"])
 
-        # detect the spoken language
-        _, probs = self.model.detect_language(speech)
-        print(f"Detected language: {max(probs, key=probs.get)}")
+        # # detect the spoken language
+        # _, probs = self.model.detect_language(speech)
+        # print(f"Detected language: {max(probs, key=probs.get)}")
 
         # decode the audio
-        options = whisper.DecodingOptions(language=kwargs.get("language", None), fp16=False)
-        result = whisper.decode(self.model, speech, options)
-
+        options = whisper.DecodingOptions(**kwargs.get("DecodingOptions", {}))
+        
+        result = whisper.decode(self.model, speech, language='english')
+        # result = whisper.transcribe(self.model, speech)
+        
         results = []
         result_i = {"key": key[0], "text": result.text}
 
         results.append(result_i)
-    
+
         return results, meta_data
-    
\ No newline at end of file

--
Gitblit v1.9.1