From 790bf549448c92f8a19ae1455ace15ff5d7a2e31 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 04 Mar 2024 20:35:06 +0800
Subject: [PATCH] Dev gzf (#1422)
---
funasr/version.txt | 2
funasr/download/download_from_hub.py | 25 ++++++-
examples/industrial_data_pretraining/whisper/infer_from_openai.sh | 24 ++++++++
funasr/auto/auto_model.py | 4
funasr/bin/compute_audio_cmvn.py | 2
README_zh.md | 2
examples/industrial_data_pretraining/whisper/demo_from_openai.py | 17 +++++
README.md | 2
funasr/models/qwen_audio/model.py | 2
funasr/bin/train.py | 2
funasr/models/whisper/model.py | 23 ++++++-
funasr/auto/auto_frontend.py | 2
funasr/download/name_maps_from_hub.py | 15 +++++
13 files changed, 103 insertions(+), 19 deletions(-)
diff --git a/README.md b/README.md
index 6d4116c..d436d5e 100644
--- a/README.md
+++ b/README.md
@@ -115,7 +115,7 @@
hotword='魔搭')
print(res)
```
-Note: `model_hub`: represents the model repository, `ms` stands for selecting ModelScope download, `hf` stands for selecting Huggingface download.
+Note: `hub`: represents the model repository, `ms` stands for selecting ModelScope download, `hf` stands for selecting Huggingface download.
### Speech Recognition (Streaming)
```python
diff --git a/README_zh.md b/README_zh.md
index f9e72ed..8a34c82 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -111,7 +111,7 @@
hotword='魔搭')
print(res)
```
-注：`model_hub`：表示模型仓库，`ms`为选择modelscope下载，`hf`为选择huggingface下载。
+注：`hub`：表示模型仓库，`ms`为选择modelscope下载，`hf`为选择huggingface下载。
### 实时语音识别
diff --git a/examples/industrial_data_pretraining/whisper/demo_from_openai.py b/examples/industrial_data_pretraining/whisper/demo_from_openai.py
new file mode 100644
index 0000000..0b88a95
--- /dev/null
+++ b/examples/industrial_data_pretraining/whisper/demo_from_openai.py
@@ -0,0 +1,17 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+# model = AutoModel(model="Whisper-small", hub="openai")
+# model = AutoModel(model="Whisper-medium", hub="openai")
+model = AutoModel(model="Whisper-large-v2", hub="openai")
+# model = AutoModel(model="Whisper-large-v3", hub="openai")
+
+res = model.generate(
+ language=None,
+ task="transcribe",
+ input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
+print(res)
diff --git a/examples/industrial_data_pretraining/whisper/infer_from_openai.sh b/examples/industrial_data_pretraining/whisper/infer_from_openai.sh
new file mode 100644
index 0000000..461d75e
--- /dev/null
+++ b/examples/industrial_data_pretraining/whisper/infer_from_openai.sh
@@ -0,0 +1,24 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
+
+output_dir="./outputs/debug"
+
+#model="Whisper-small"
+#model="Whisper-medium"
+model="Whisper-large-v2"
+#model="Whisper-large-v3"
+hub="openai"
+
+device="cuda:0" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
+
+python -m funasr.bin.inference \
+++model=${model} \
+++hub=${hub} \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++device="${device}" \
diff --git a/funasr/auto/auto_frontend.py b/funasr/auto/auto_frontend.py
index 35ea23f..b802b83 100644
--- a/funasr/auto/auto_frontend.py
+++ b/funasr/auto/auto_frontend.py
@@ -31,7 +31,7 @@
def __init__(self, **kwargs):
assert "model" in kwargs
if "model_conf" not in kwargs:
- logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+ logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
kwargs = download_model(**kwargs)
# build frontend
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index ec3c3f3..70d09df 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -143,7 +143,7 @@
def build_model(self, **kwargs):
assert "model" in kwargs
if "model_conf" not in kwargs:
- logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+ logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
kwargs = download_model(**kwargs)
set_all_random_seed(kwargs.get("seed", 0))
@@ -180,7 +180,7 @@
# build model
model_class = tables.model_classes.get(kwargs["model"])
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
+ model = model_class(**kwargs, **kwargs.get("model_conf", {}), vocab_size=vocab_size)
model.to(device)
# init_param
diff --git a/funasr/bin/compute_audio_cmvn.py b/funasr/bin/compute_audio_cmvn.py
index 4561bec..ffad652 100644
--- a/funasr/bin/compute_audio_cmvn.py
+++ b/funasr/bin/compute_audio_cmvn.py
@@ -18,7 +18,7 @@
assert "model" in kwargs
if "model_conf" not in kwargs:
- logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+ logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 569757a..3c93371 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -35,7 +35,7 @@
assert "model" in kwargs
if "model_conf" not in kwargs:
- logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+ logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms")))
kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index b4253cd..4f07daa 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -2,13 +2,20 @@
import json
from omegaconf import OmegaConf
-from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf
+from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf, name_maps_openai
def download_model(**kwargs):
- model_hub = kwargs.get("model_hub", "ms")
- if model_hub == "ms":
+ hub = kwargs.get("hub", "ms")
+ if hub == "ms":
kwargs = download_from_ms(**kwargs)
+ elif hub == "hf":
+ pass
+ elif hub == "openai":
+ model_or_path = kwargs.get("model")
+ if model_or_path in name_maps_openai:
+ model_or_path = name_maps_openai[model_or_path]
+ kwargs["model_path"] = model_or_path
return kwargs
@@ -18,7 +25,13 @@
model_or_path = name_maps_ms[model_or_path]
model_revision = kwargs.get("model_revision")
if not os.path.exists(model_or_path):
- model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"), check_latest=kwargs.get("check_latest", True))
+ try:
+ model_or_path = get_or_download_model_dir(model_or_path, model_revision,
+ is_training=kwargs.get("is_training"),
+ check_latest=kwargs.get("check_latest", True))
+ except Exception as e:
+ print(f"Download: {model_or_path} failed!: {e}")
+
kwargs["model_path"] = model_or_path
if os.path.exists(os.path.join(model_or_path, "configuration.json")):
@@ -50,7 +63,9 @@
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
- return OmegaConf.to_container(kwargs, resolve=True)
+ if isinstance(kwargs, OmegaConf):
+ kwargs = OmegaConf.to_container(kwargs, resolve=True)
+ return kwargs
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = {}):
diff --git a/funasr/download/name_maps_from_hub.py b/funasr/download/name_maps_from_hub.py
index fe493a7..e1bc295 100644
--- a/funasr/download/name_maps_from_hub.py
+++ b/funasr/download/name_maps_from_hub.py
@@ -13,4 +13,19 @@
name_maps_hf = {
+}
+
+name_maps_openai = {
+ "Whisper-tiny.en": "tiny.en",
+ "Whisper-tiny": "tiny",
+ "Whisper-base.en": "base.en",
+ "Whisper-base": "base",
+ "Whisper-small.en": "small.en",
+ "Whisper-small": "small",
+ "Whisper-medium.en": "medium.en",
+ "Whisper-medium": "medium",
+ "Whisper-large-v1": "large-v1",
+ "Whisper-large-v2": "large-v2",
+ "Whisper-large-v3": "large-v3",
+ "Whisper-large": "large",
}
\ No newline at end of file
diff --git a/funasr/models/qwen_audio/model.py b/funasr/models/qwen_audio/model.py
index f09405a..805234b 100644
--- a/funasr/models/qwen_audio/model.py
+++ b/funasr/models/qwen_audio/model.py
@@ -14,7 +14,7 @@
-@tables.register("model_classes", "WhisperWarp")
+@tables.register("model_classes", "QwenAudioWarp")
class WhisperWarp(nn.Module):
def __init__(self, whisper_dims: dict, **kwargs):
super().__init__()
diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py
index f09405a..1eac2ff 100644
--- a/funasr/models/whisper/model.py
+++ b/funasr/models/whisper/model.py
@@ -13,16 +13,29 @@
from funasr.register import tables
-
-@tables.register("model_classes", "WhisperWarp")
+@tables.register("model_classes", "Whisper-tiny.en")
+@tables.register("model_classes", "Whisper-tiny")
+@tables.register("model_classes", "Whisper-base.en")
+@tables.register("model_classes", "Whisper-base")
+@tables.register("model_classes", "Whisper-small.en")
+@tables.register("model_classes", "Whisper-small")
+@tables.register("model_classes", "Whisper-medium.en")
+@tables.register("model_classes", "Whisper-medium")
+@tables.register("model_classes", "Whisper-large-v1")
+@tables.register("model_classes", "Whisper-large-v2")
+@tables.register("model_classes", "Whisper-large-v3")
+@tables.register("model_classes", "Whisper-WhisperWarp")
class WhisperWarp(nn.Module):
- def __init__(self, whisper_dims: dict, **kwargs):
+ def __init__(self, *args, **kwargs):
super().__init__()
hub = kwargs.get("hub", "funasr")
if hub == "openai":
- init_param_path = kwargs.get("init_param_path", "large-v3")
- model = whisper.load_model(init_param_path)
+ model_or_path = kwargs.get("model_path", "Whisper-large-v3")
+ if model_or_path.startswith("Whisper-"):
+ model_or_path = model_or_path.replace("Whisper-", "")
+ model = whisper.load_model(model_or_path)
else:
+ whisper_dims = kwargs.get("whisper_dims", {})
dims = whisper.model.ModelDimensions(**whisper_dims)
model = whisper.model.Whisper(dims=dims)
diff --git a/funasr/version.txt b/funasr/version.txt
index 59e9e60..bb83058 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-1.0.11
+1.0.12
--
Gitblit v1.9.1