From 9595a9432fadfbdacd4e6897f6b9a83957699558 Mon Sep 17 00:00:00 2001
From: seanzhang-zhichen <74812416+seanzhang-zhichen@users.noreply.github.com>
Date: Tue, 5 Mar 2024 17:42:14 +0800
Subject: [PATCH] modify paraformer train doc (#1427)
---
funasr/models/whisper/model.py | 32 ++++++++++++++++++++++++++------
 1 file changed, 26 insertions(+), 6 deletions(-)
diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py
index f09405a..73d70d7 100644
--- a/funasr/models/whisper/model.py
+++ b/funasr/models/whisper/model.py
@@ -13,17 +13,30 @@
from funasr.register import tables
-
+@tables.register("model_classes", "Whisper-tiny.en")
+@tables.register("model_classes", "Whisper-tiny")
+@tables.register("model_classes", "Whisper-base.en")
+@tables.register("model_classes", "Whisper-base")
+@tables.register("model_classes", "Whisper-small.en")
+@tables.register("model_classes", "Whisper-small")
+@tables.register("model_classes", "Whisper-medium.en")
+@tables.register("model_classes", "Whisper-medium")
+@tables.register("model_classes", "Whisper-large-v1")
+@tables.register("model_classes", "Whisper-large-v2")
+@tables.register("model_classes", "Whisper-large-v3")
@tables.register("model_classes", "WhisperWarp")
class WhisperWarp(nn.Module):
- def __init__(self, whisper_dims: dict, **kwargs):
+ def __init__(self, *args, **kwargs):
super().__init__()
hub = kwargs.get("hub", "funasr")
if hub == "openai":
- init_param_path = kwargs.get("init_param_path", "large-v3")
- model = whisper.load_model(init_param_path)
+ model_or_path = kwargs.get("model_path", "Whisper-large-v3")
+ if model_or_path.startswith("Whisper-"):
+ model_or_path = model_or_path.replace("Whisper-", "")
+ model = whisper.load_model(model_or_path)
else:
- dims = whisper.model.ModelDimensions(**whisper_dims)
+ dims = kwargs.get("dims", {})
+ dims = whisper.model.ModelDimensions(**dims)
model = whisper.model.Whisper(dims=dims)
self.model = model
@@ -42,6 +55,13 @@
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
+ if frontend is None and not hasattr(self, "frontend"):
+ frontend_class = tables.frontend_classes.get("WhisperFrontend")
+ frontend = frontend_class(n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True))
+ self.frontend = frontend
+ else:
+ frontend = frontend if frontend is not None else self.frontend
+
meta_data = {}
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
speech, speech_lengths = data_in, data_lengths
@@ -52,7 +72,7 @@
else:
# extract fbank feats
time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs if hasattr(frontend, "fs") else 16000, audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer)
time2 = time.perf_counter()
--
Gitblit v1.9.1