From 790bf549448c92f8a19ae1455ace15ff5d7a2e31 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 04 三月 2024 20:35:06 +0800
Subject: [PATCH] Dev gzf (#1422)

---
 funasr/models/whisper/model.py |   23 ++++++++++++++++++-----
 1 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py
index f09405a..1eac2ff 100644
--- a/funasr/models/whisper/model.py
+++ b/funasr/models/whisper/model.py
@@ -13,16 +13,29 @@
 from funasr.register import tables
 
 
-
-@tables.register("model_classes", "WhisperWarp")
+@tables.register("model_classes", "Whisper-tiny.en")
+@tables.register("model_classes", "Whisper-tiny")
+@tables.register("model_classes", "Whisper-base.en")
+@tables.register("model_classes", "Whisper-base")
+@tables.register("model_classes", "Whisper-small.en")
+@tables.register("model_classes", "Whisper-small")
+@tables.register("model_classes", "Whisper-medium.en")
+@tables.register("model_classes", "Whisper-medium")
+@tables.register("model_classes", "Whisper-large-v1")
+@tables.register("model_classes", "Whisper-large-v2")
+@tables.register("model_classes", "Whisper-large-v3")
+@tables.register("model_classes", "Whisper-WhisperWarp")
 class WhisperWarp(nn.Module):
-    def __init__(self, whisper_dims: dict, **kwargs):
+    def __init__(self, *args, **kwargs):
         super().__init__()
         hub = kwargs.get("hub", "funasr")
         if hub == "openai":
-            init_param_path = kwargs.get("init_param_path", "large-v3")
-            model = whisper.load_model(init_param_path)
+            model_or_path = kwargs.get("model_path", "Whisper-large-v3")
+            if model_or_path.startswith("Whisper-"):
+                model_or_path = model_or_path.replace("Whisper-", "")
+            model = whisper.load_model(model_or_path)
         else:
+            whisper_dims = kwargs.get("whisper_dims", {})
             dims = whisper.model.ModelDimensions(**whisper_dims)
             model = whisper.model.Whisper(dims=dims)
         

--
Gitblit v1.9.1