From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
funasr/models/whisper/model.py | 83 +++++++++++++++++++++++++++--------------
 1 file changed, 54 insertions(+), 29 deletions(-)
diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py
index 1eac2ff..a332100 100644
--- a/funasr/models/whisper/model.py
+++ b/funasr/models/whisper/model.py
@@ -7,7 +7,10 @@
import torch.nn.functional as F
from torch import Tensor
from torch import nn
+
import whisper
+# import whisper_timestamped as whisper
+
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.register import tables
@@ -24,7 +27,7 @@
@tables.register("model_classes", "Whisper-large-v1")
@tables.register("model_classes", "Whisper-large-v2")
@tables.register("model_classes", "Whisper-large-v3")
-@tables.register("model_classes", "Whisper-WhisperWarp")
+@tables.register("model_classes", "WhisperWarp")
class WhisperWarp(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
@@ -35,28 +38,44 @@
model_or_path = model_or_path.replace("Whisper-", "")
model = whisper.load_model(model_or_path)
else:
- whisper_dims = kwargs.get("whisper_dims", {})
- dims = whisper.model.ModelDimensions(**whisper_dims)
+ dims = kwargs.get("dims", {})
+ dims = whisper.model.ModelDimensions(**dims)
model = whisper.model.Whisper(dims=dims)
-
+
self.model = model
-
- def forward(self, ):
+
+ self.encoder_output_size = self.model.dims.n_audio_state
+
+ def forward(
+ self,
+ ):
pass
-
- def inference(self,
- data_in,
- data_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
+
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
+ if frontend is None and not hasattr(self, "frontend"):
+ frontend_class = tables.frontend_classes.get("WhisperFrontend")
+ frontend = frontend_class(
+ n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
+ )
+ self.frontend = frontend
+ else:
+ frontend = frontend if frontend is not None else self.frontend
+
meta_data = {}
- if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+ if (
+ isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+ ): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
@@ -65,13 +84,18 @@
else:
# extract fbank feats
time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer)
+ audio_sample_list = load_audio_text_image_video(
+ data_in,
+ fs=frontend.fs if hasattr(frontend, "fs") else 16000,
+ audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer,
+ )
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
- frontend=frontend)
+ speech, speech_lengths = extract_fbank(
+ audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+ )
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
@@ -81,18 +105,19 @@
speech = speech.to(device=kwargs["device"])[0, :, :]
speech_lengths = speech_lengths.to(device=kwargs["device"])
- # detect the spoken language
- _, probs = self.model.detect_language(speech)
- print(f"Detected language: {max(probs, key=probs.get)}")
+ # # detect the spoken language
+ # _, probs = self.model.detect_language(speech)
+ # print(f"Detected language: {max(probs, key=probs.get)}")
# decode the audio
- options = whisper.DecodingOptions(language=kwargs.get("language", None), fp16=False)
- result = whisper.decode(self.model, speech, options)
-
+ options = whisper.DecodingOptions(**kwargs.get("DecodingOptions", {}))
+
+ result = whisper.decode(self.model, speech, language='english')
+ # result = whisper.transcribe(self.model, speech)
+
results = []
result_i = {"key": key[0], "text": result.text}
results.append(result_i)
-
+
return results, meta_data
-
\ No newline at end of file
--
Gitblit v1.9.1