From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages
---
funasr/__init__.py | 127 ++++++++++++++++++++++++++++++++++++++++++
1 files changed, 127 insertions(+), 0 deletions(-)
diff --git a/funasr/__init__.py b/funasr/__init__.py
index f297bc3..aab4289 100644
--- a/funasr/__init__.py
+++ b/funasr/__init__.py
@@ -1,8 +1,135 @@
"""Initialize funasr package."""
import os
+from pathlib import Path
+import torch
+import numpy as np
dirname = os.path.dirname(__file__)
version_file = os.path.join(dirname, "version.txt")
with open(version_file, "r") as f:
__version__ = f.read().strip()
+
+
+def prepare_model(
+ model: str = None,
+ # mode: str = None,
+ vad_model: str = None,
+ punc_model: str = None,
+ model_hub: str = "ms",
+ cache_dir: str = None,
+ **kwargs,
+):
+ if not Path(model).exists():
+ if model_hub == "ms" or model_hub == "modelscope":
+ try:
+ from modelscope.hub.snapshot_download import snapshot_download as download_tool
+ model = name_maps_ms[model] if model is not None else None
+ vad_model = name_maps_ms[vad_model] if vad_model is not None else None
+ punc_model = name_maps_ms[punc_model] if punc_model is not None else None
+ except:
+ raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
+ "\npip3 install -U modelscope\n" \
+ "For the users in China, you could install with the command:\n" \
+ "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+ elif model_hub == "hf" or model_hub == "huggingface":
+ download_tool = 0
+ else:
+ raise "model_hub must be on of ms or hf, but get {}".format(model_hub)
+ try:
+ model = download_tool(model, cache_dir=cache_dir, revision=kwargs.get("revision", None))
+ print("model have been downloaded to: {}".format(model))
+ except:
+ raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
+ model)
+
+ if vad_model is not None and not Path(vad_model).exists():
+ vad_model = download_tool(vad_model, cache_dir=cache_dir)
+ print("model have been downloaded to: {}".format(vad_model))
+ if punc_model is not None and not Path(punc_model).exists():
+ punc_model = download_tool(punc_model, cache_dir=cache_dir)
+ print("model have been downloaded to: {}".format(punc_model))
+
+ # asr
+ kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"),
+ "asr_model_file": None if model is None else os.path.join(model, "model.pb"),
+ "asr_train_config": None if model is None else os.path.join(model, "config.yaml"),
+ })
+ mode = kwargs.get("mode", None)
+ if mode is None:
+ import json
+ json_file = os.path.join(model, 'configuration.json')
+ with open(json_file, 'r') as f:
+ config_data = json.load(f)
+ if config_data['task'] == "punctuation":
+ mode = config_data['model']['punc_model_config']['mode']
+ else:
+ mode = config_data['model']['model_config']['mode']
+ if vad_model is not None and "vad" not in mode:
+ mode = "paraformer_vad"
+ kwargs["mode"] = mode
+ # vad
+ kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"),
+ "vad_model_file": None if vad_model is None else os.path.join(vad_model, "vad.pb"),
+ "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"),
+ })
+ # punc
+ kwargs.update({
+ "punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"),
+ "punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"),
+ })
+
+
+ return model, vad_model, punc_model, kwargs
+
+name_maps_ms = {
+ "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+ "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
+ "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+ "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+ "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
+ "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+ "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
+ "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
+}
+
+def infer(task_name: str = "asr",
+ model: str = None,
+ # mode: str = None,
+ vad_model: str = None,
+ punc_model: str = None,
+ model_hub: str = "ms",
+ cache_dir: str = None,
+ **kwargs,
+ ):
+
+ model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
+ if task_name == "asr":
+ from funasr.bin.asr_inference_launch import inference_launch
+
+ inference_pipeline = inference_launch(**kwargs)
+ elif task_name == "":
+ pipeline = 1
+ elif task_name == "":
+ pipeline = 2
+ elif task_name == "":
+ pipeline = 2
+
+ def _infer_fn(input, **kwargs):
+ data_type = kwargs.get('data_type', 'sound')
+ data_path_and_name_and_type = [input, 'speech', data_type]
+ raw_inputs = None
+ if isinstance(input, torch.Tensor):
+ input = input.numpy()
+ if isinstance(input, np.ndarray):
+ data_path_and_name_and_type = None
+ raw_inputs = input
+
+
+
+ return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
+
+ return _infer_fn
+
+if __name__ == '__main__':
+ pass
\ No newline at end of file
--
Gitblit v1.9.1