From 244c033fbaeae15faf8b0351365bdb7607b2e2bb Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 十一月 2023 15:19:53 +0800
Subject: [PATCH] python cli
---
funasr/version.txt | 2
funasr/bin/inference_cli.py | 139 ++++++++
funasr/bin/argument.py | 262 ++++++++++++++++
funasr/utils/download_and_prepare_model.py | 93 +++++
setup.py | 3
README_zh.md | 9
funasr/__init__.py | 127 -------
funasr/bin/asr_inference_launch.py | 258 ---------------
README.md | 9
9 files changed, 522 insertions(+), 380 deletions(-)
diff --git a/README.md b/README.md
index f73c0ca..001ce3f 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,15 @@
FunASR supports inference and fine-tuning of models trained on industrial data for tens of thousands of hours. For more details, please refer to [modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html). It also supports training and fine-tuning of models on academic standard datasets. For more information, please refer to [egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html).
Below is a quick start tutorial. Test audio files ([Mandarin](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav), [English]()).
+
+### Command-line usage
+
+```shell
+funasr --model paraformer-zh asr_example_zh.wav
+```
+
+Note: a single audio file is supported, as well as a file list in Kaldi-style wav.scp format: `wav_id wav_path`
+
### Speech Recognition (Non-streaming)
```python
from funasr import infer
diff --git a/README_zh.md b/README_zh.md
index 554c0b6..504c715 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -70,6 +70,15 @@
FunASR鏀寔鏁颁竾灏忔椂宸ヤ笟鏁版嵁璁粌鐨勬ā鍨嬬殑鎺ㄧ悊鍜屽井璋冿紝璇︾粏淇℃伅鍙互鍙傞槄锛圼modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)锛夛紱涔熸敮鎸佸鏈爣鍑嗘暟鎹泦妯″瀷鐨勮缁冨拰寰皟锛岃缁嗕俊鎭彲浠ュ弬闃咃紙[egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html)锛夈��
涓嬮潰涓哄揩閫熶笂鎵嬫暀绋嬶紝娴嬭瘯闊抽锛圼涓枃](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav)锛孾鑻辨枃]()锛�
+
+### 可执行命令行
+
+```shell
+funasr --model paraformer-zh asr_example_zh.wav
+```
+
+注：支持单条音频文件识别，也支持文件列表，列表为kaldi风格wav.scp：`wav_id wav_path`
+
### 闈炲疄鏃惰闊宠瘑鍒�
```python
from funasr import infer
diff --git a/funasr/__init__.py b/funasr/__init__.py
index aab4289..d0b7aa5 100644
--- a/funasr/__init__.py
+++ b/funasr/__init__.py
@@ -1,135 +1,10 @@
"""Initialize funasr package."""
import os
-from pathlib import Path
-import torch
-import numpy as np
dirname = os.path.dirname(__file__)
version_file = os.path.join(dirname, "version.txt")
with open(version_file, "r") as f:
__version__ = f.read().strip()
-
-def prepare_model(
- model: str = None,
- # mode: str = None,
- vad_model: str = None,
- punc_model: str = None,
- model_hub: str = "ms",
- cache_dir: str = None,
- **kwargs,
-):
- if not Path(model).exists():
- if model_hub == "ms" or model_hub == "modelscope":
- try:
- from modelscope.hub.snapshot_download import snapshot_download as download_tool
- model = name_maps_ms[model] if model is not None else None
- vad_model = name_maps_ms[vad_model] if vad_model is not None else None
- punc_model = name_maps_ms[punc_model] if punc_model is not None else None
- except:
- raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
- "\npip3 install -U modelscope\n" \
- "For the users in China, you could install with the command:\n" \
- "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
- elif model_hub == "hf" or model_hub == "huggingface":
- download_tool = 0
- else:
- raise "model_hub must be on of ms or hf, but get {}".format(model_hub)
- try:
- model = download_tool(model, cache_dir=cache_dir, revision=kwargs.get("revision", None))
- print("model have been downloaded to: {}".format(model))
- except:
- raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
- model)
-
- if vad_model is not None and not Path(vad_model).exists():
- vad_model = download_tool(vad_model, cache_dir=cache_dir)
- print("model have been downloaded to: {}".format(vad_model))
- if punc_model is not None and not Path(punc_model).exists():
- punc_model = download_tool(punc_model, cache_dir=cache_dir)
- print("model have been downloaded to: {}".format(punc_model))
-
- # asr
- kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"),
- "asr_model_file": None if model is None else os.path.join(model, "model.pb"),
- "asr_train_config": None if model is None else os.path.join(model, "config.yaml"),
- })
- mode = kwargs.get("mode", None)
- if mode is None:
- import json
- json_file = os.path.join(model, 'configuration.json')
- with open(json_file, 'r') as f:
- config_data = json.load(f)
- if config_data['task'] == "punctuation":
- mode = config_data['model']['punc_model_config']['mode']
- else:
- mode = config_data['model']['model_config']['mode']
- if vad_model is not None and "vad" not in mode:
- mode = "paraformer_vad"
- kwargs["mode"] = mode
- # vad
- kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"),
- "vad_model_file": None if vad_model is None else os.path.join(vad_model, "vad.pb"),
- "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"),
- })
- # punc
- kwargs.update({
- "punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"),
- "punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"),
- })
-
-
- return model, vad_model, punc_model, kwargs
-
-name_maps_ms = {
- "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
- "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
- "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
- "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
- "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
- "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
- "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
-}
-
-def infer(task_name: str = "asr",
- model: str = None,
- # mode: str = None,
- vad_model: str = None,
- punc_model: str = None,
- model_hub: str = "ms",
- cache_dir: str = None,
- **kwargs,
- ):
-
- model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
- if task_name == "asr":
- from funasr.bin.asr_inference_launch import inference_launch
-
- inference_pipeline = inference_launch(**kwargs)
- elif task_name == "":
- pipeline = 1
- elif task_name == "":
- pipeline = 2
- elif task_name == "":
- pipeline = 2
-
- def _infer_fn(input, **kwargs):
- data_type = kwargs.get('data_type', 'sound')
- data_path_and_name_and_type = [input, 'speech', data_type]
- raw_inputs = None
- if isinstance(input, torch.Tensor):
- input = input.numpy()
- if isinstance(input, np.ndarray):
- data_path_and_name_and_type = None
- raw_inputs = input
-
-
-
- return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
-
- return _infer_fn
-
-if __name__ == '__main__':
- pass
\ No newline at end of file
+from funasr.bin.inference_cli import infer
diff --git a/funasr/bin/argument.py b/funasr/bin/argument.py
new file mode 100644
index 0000000..0ea4ac9
--- /dev/null
+++ b/funasr/bin/argument.py
@@ -0,0 +1,262 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import sys
+
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import config_argparse
+import argparse
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="ASR Decoding",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ # Note(kamo): Use '_' instead of '-' as separator.
+ # '-' is confusing if written in yaml.
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument("--output_dir", type=str, default=None)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=1,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument(
+ "--njob",
+ type=int,
+ default=1,
+ help="The number of jobs for each gpu",
+ )
+ parser.add_argument(
+ "--gpuid_list",
+ type=str,
+ default="",
+ help="The visible gpus",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ required=False,
+ action="append",
+ )
+ group.add_argument("--key_file", type=str_or_none)
+    group.add_argument(
+        "--hotword",
+        type=str_or_none,
+        default=None,
+        help="hotword file path or hotwords separated by space"
+    )
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+    group.add_argument(
+        "--mc",
+        type=str2bool,
+        default=False,
+        help="MultiChannel input",
+    )
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument(
+ "--vad_infer_config",
+ type=str,
+ help="VAD infer configuration",
+ )
+ group.add_argument(
+ "--vad_model_file",
+ type=str,
+ help="VAD model parameter file",
+ )
+ group.add_argument(
+ "--punc_infer_config",
+ type=str,
+ help="PUNC infer configuration",
+ )
+ group.add_argument(
+ "--punc_model_file",
+ type=str,
+ help="PUNC model parameter file",
+ )
+ group.add_argument(
+ "--cmvn_file",
+ type=str,
+ help="Global CMVN file",
+ )
+ group.add_argument(
+ "--asr_train_config",
+ type=str,
+ help="ASR training configuration",
+ )
+ group.add_argument(
+ "--asr_model_file",
+ type=str,
+ help="ASR model parameter file",
+ )
+ group.add_argument(
+ "--sv_model_file",
+ type=str,
+ help="SV model parameter file",
+ )
+ group.add_argument(
+ "--lm_train_config",
+ type=str,
+ help="LM training configuration",
+ )
+ group.add_argument(
+ "--lm_file",
+ type=str,
+ help="LM parameter file",
+ )
+ group.add_argument(
+ "--word_lm_train_config",
+ type=str,
+ help="Word LM training configuration",
+ )
+ group.add_argument(
+ "--word_lm_file",
+ type=str,
+ help="Word LM parameter file",
+ )
+ group.add_argument(
+ "--ngram_file",
+ type=str,
+ help="N-gram parameter file",
+ )
+ group.add_argument(
+ "--model_tag",
+ type=str,
+ help="Pretrained model tag. If specify this option, *_train_config and "
+ "*_file will be overwritten",
+ )
+ group.add_argument(
+ "--beam_search_config",
+ default={},
+ help="The keyword arguments for transducer beam search.",
+ )
+
+ group = parser.add_argument_group("Beam-search related")
+ group.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
+ group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+ group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+ group.add_argument(
+ "--maxlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain max output length. "
+ "If maxlenratio=0.0 (default), it uses a end-detect "
+ "function "
+ "to automatically find maximum hypothesis lengths."
+ "If maxlenratio<0.0, its absolute value is interpreted"
+ "as a constant max output length",
+ )
+ group.add_argument(
+ "--minlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain min output length",
+ )
+ group.add_argument(
+ "--ctc_weight",
+ type=float,
+ default=0.0,
+ help="CTC weight in joint decoding",
+ )
+ group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+ group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+ group.add_argument("--streaming", type=str2bool, default=False)
+ group.add_argument("--fake_streaming", type=str2bool, default=False)
+ group.add_argument("--full_utt", type=str2bool, default=False)
+ group.add_argument("--chunk_size", type=int, default=16)
+ group.add_argument("--left_context", type=int, default=16)
+ group.add_argument("--right_context", type=int, default=0)
+    group.add_argument(
+        "--display_partial_hypotheses",
+        type=str2bool,
+        default=False,
+        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
+    )
+
+ group = parser.add_argument_group("Dynamic quantization related")
+    group.add_argument(
+        "--quantize_asr_model",
+        type=str2bool,
+        default=False,
+        help="Apply dynamic quantization to ASR model.",
+    )
+ group.add_argument(
+ "--quantize_modules",
+ nargs="*",
+ default=None,
+ help="""Module names to apply dynamic quantization on.
+ The module names are provided as a list, where each name is separated
+ by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
+ Each specified name should be an attribute of 'torch.nn', e.g.:
+ torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
+ )
+ group.add_argument(
+ "--quantize_dtype",
+ type=str,
+ default="qint8",
+ choices=["float16", "qint8"],
+ help="Dtype for dynamic quantization.",
+ )
+
+ group = parser.add_argument_group("Text converter related")
+ group.add_argument(
+ "--token_type",
+ type=str_or_none,
+ default=None,
+ choices=["char", "bpe", None],
+ help="The token type for ASR model. "
+ "If not given, refers from the training args",
+ )
+ group.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The model path of sentencepiece. "
+ "If not given, refers from the training args",
+ )
+ group.add_argument("--token_num_relax", type=int, default=1, help="")
+ group.add_argument("--decoding_ind", type=int, default=0, help="")
+ group.add_argument("--decoding_mode", type=str, default="model1", help="")
+ group.add_argument(
+ "--ctc_weight2",
+ type=float,
+ default=0.0,
+ help="CTC weight in joint decoding",
+ )
+ return parser
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index e93d740..e1a32c5 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -675,11 +675,13 @@
beg_idx = end_idx
batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
batch = to_device(batch, device=device)
- # print("batch: ", speech_j.shape[0])
+
beg_asr = time.time()
results = speech2text(**batch)
end_asr = time.time()
- # print("time cost asr: ", end_asr - beg_asr)
+ if speech2text.device != "cpu":
+ print("batch: ", speech_j.shape[0])
+ print("time cost asr: ", end_asr - beg_asr)
if len(results) < 1:
results = [["", [], [], [], [], [], []]]
@@ -2218,259 +2220,9 @@
logging.info("Unknown decoding mode: {}".format(mode))
return None
-
-def get_parser():
- parser = config_argparse.ArgumentParser(
- description="ASR Decoding",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
-
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
- parser.add_argument(
- "--log_level",
- type=lambda x: x.upper(),
- default="INFO",
- choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
- help="The verbose level of logging",
- )
-
- parser.add_argument("--output_dir", type=str, required=True)
- parser.add_argument(
- "--ngpu",
- type=int,
- default=0,
- help="The number of gpus. 0 indicates CPU mode",
- )
- parser.add_argument(
- "--njob",
- type=int,
- default=1,
- help="The number of jobs for each gpu",
- )
- parser.add_argument(
- "--gpuid_list",
- type=str,
- default="",
- help="The visible gpus",
- )
- parser.add_argument("--seed", type=int, default=0, help="Random seed")
- parser.add_argument(
- "--dtype",
- default="float32",
- choices=["float16", "float32", "float64"],
- help="Data type",
- )
- parser.add_argument(
- "--num_workers",
- type=int,
- default=1,
- help="The number of workers used for DataLoader",
- )
-
- group = parser.add_argument_group("Input data related")
- group.add_argument(
- "--data_path_and_name_and_type",
- type=str2triple_str,
- required=True,
- action="append",
- )
- group.add_argument("--key_file", type=str_or_none)
- parser.add_argument(
- "--hotword",
- type=str_or_none,
- default=None,
- help="hotword file path or hotwords seperated by space"
- )
- group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
- group.add_argument(
- "--mc",
- type=bool,
- default=False,
- help="MultiChannel input",
- )
-
- group = parser.add_argument_group("The model configuration related")
- group.add_argument(
- "--vad_infer_config",
- type=str,
- help="VAD infer configuration",
- )
- group.add_argument(
- "--vad_model_file",
- type=str,
- help="VAD model parameter file",
- )
- group.add_argument(
- "--punc_infer_config",
- type=str,
- help="PUNC infer configuration",
- )
- group.add_argument(
- "--punc_model_file",
- type=str,
- help="PUNC model parameter file",
- )
- group.add_argument(
- "--cmvn_file",
- type=str,
- help="Global CMVN file",
- )
- group.add_argument(
- "--asr_train_config",
- type=str,
- help="ASR training configuration",
- )
- group.add_argument(
- "--asr_model_file",
- type=str,
- help="ASR model parameter file",
- )
- group.add_argument(
- "--sv_model_file",
- type=str,
- help="SV model parameter file",
- )
- group.add_argument(
- "--lm_train_config",
- type=str,
- help="LM training configuration",
- )
- group.add_argument(
- "--lm_file",
- type=str,
- help="LM parameter file",
- )
- group.add_argument(
- "--word_lm_train_config",
- type=str,
- help="Word LM training configuration",
- )
- group.add_argument(
- "--word_lm_file",
- type=str,
- help="Word LM parameter file",
- )
- group.add_argument(
- "--ngram_file",
- type=str,
- help="N-gram parameter file",
- )
- group.add_argument(
- "--model_tag",
- type=str,
- help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
- )
- group.add_argument(
- "--beam_search_config",
- default={},
- help="The keyword arguments for transducer beam search.",
- )
-
- group = parser.add_argument_group("Beam-search related")
- group.add_argument(
- "--batch_size",
- type=int,
- default=1,
- help="The batch size for inference",
- )
- group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=20, help="Beam size")
- group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
- group.add_argument(
- "--maxlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain max output length. "
- "If maxlenratio=0.0 (default), it uses a end-detect "
- "function "
- "to automatically find maximum hypothesis lengths."
- "If maxlenratio<0.0, its absolute value is interpreted"
- "as a constant max output length",
- )
- group.add_argument(
- "--minlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain min output length",
- )
- group.add_argument(
- "--ctc_weight",
- type=float,
- default=0.0,
- help="CTC weight in joint decoding",
- )
- group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
- group.add_argument("--streaming", type=str2bool, default=False)
- group.add_argument("--fake_streaming", type=str2bool, default=False)
- group.add_argument("--full_utt", type=str2bool, default=False)
- group.add_argument("--chunk_size", type=int, default=16)
- group.add_argument("--left_context", type=int, default=16)
- group.add_argument("--right_context", type=int, default=0)
- group.add_argument(
- "--display_partial_hypotheses",
- type=bool,
- default=False,
- help="Whether to display partial hypotheses during chunk-by-chunk inference.",
- )
-
- group = parser.add_argument_group("Dynamic quantization related")
- group.add_argument(
- "--quantize_asr_model",
- type=bool,
- default=False,
- help="Apply dynamic quantization to ASR model.",
- )
- group.add_argument(
- "--quantize_modules",
- nargs="*",
- default=None,
- help="""Module names to apply dynamic quantization on.
- The module names are provided as a list, where each name is separated
- by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
- Each specified name should be an attribute of 'torch.nn', e.g.:
- torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
- )
- group.add_argument(
- "--quantize_dtype",
- type=str,
- default="qint8",
- choices=["float16", "qint8"],
- help="Dtype for dynamic quantization.",
- )
-
- group = parser.add_argument_group("Text converter related")
- group.add_argument(
- "--token_type",
- type=str_or_none,
- default=None,
- choices=["char", "bpe", None],
- help="The token type for ASR model. "
- "If not given, refers from the training args",
- )
- group.add_argument(
- "--bpemodel",
- type=str_or_none,
- default=None,
- help="The model path of sentencepiece. "
- "If not given, refers from the training args",
- )
- group.add_argument("--token_num_relax", type=int, default=1, help="")
- group.add_argument("--decoding_ind", type=int, default=0, help="")
- group.add_argument("--decoding_mode", type=str, default="model1", help="")
- group.add_argument(
- "--ctc_weight2",
- type=float,
- default=0.0,
- help="CTC weight in joint decoding",
- )
- return parser
-
-
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
+ from funasr.bin.argument import get_parser
parser = get_parser()
parser.add_argument(
"--mode",
diff --git a/funasr/bin/inference_cli.py b/funasr/bin/inference_cli.py
new file mode 100644
index 0000000..f4c66f1
--- /dev/null
+++ b/funasr/bin/inference_cli.py
@@ -0,0 +1,145 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import logging
+
+import torch
+import numpy as np
+
+from funasr.utils.download_and_prepare_model import prepare_model
+from funasr.utils.types import str2bool
+
+
+def infer(task_name: str = "asr",
+          model: str = None,
+          vad_model: str = None,
+          disable_vad: bool = False,
+          punc_model: str = None,
+          disable_punc: bool = False,
+          model_hub: str = "ms",
+          cache_dir: str = None,
+          **kwargs,
+          ):
+    """Build and return an inference function for the given task.
+
+    The returned callable accepts a wav file path, a Kaldi-style wav.scp
+    path, a numpy array or a torch tensor, and returns the pipeline result.
+    """
+    # set logging messages
+    logging.basicConfig(
+        level=logging.ERROR,
+    )
+
+    # honor the CLI switches: dropping a model disables that stage
+    if disable_vad:
+        vad_model = None
+    if disable_punc:
+        punc_model = None
+
+    model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
+    if task_name == "asr":
+        from funasr.bin.asr_inference_launch import inference_launch
+
+        inference_pipeline = inference_launch(**kwargs)
+    else:
+        # placeholder for future tasks (vad-only, punc-only, ...)
+        raise NotImplementedError("task_name {} is not supported yet".format(task_name))
+
+    def _infer_fn(input, **kwargs):
+        # input may be a file path (str), a torch.Tensor or a numpy array
+        data_type = kwargs.get('data_type', 'sound')
+        data_path_and_name_and_type = [input, 'speech', data_type]
+        raw_inputs = None
+        if isinstance(input, torch.Tensor):
+            input = input.numpy()
+        if isinstance(input, np.ndarray):
+            data_path_and_name_and_type = None
+            raw_inputs = input
+
+        return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
+
+    return _infer_fn
+
+
+def main(cmd=None):
+    """Entry point of the ``funasr`` console script."""
+    from funasr.bin.argument import get_parser
+
+    parser = get_parser()
+    parser.add_argument('input', help='input file to transcribe')
+    parser.add_argument(
+        "--task_name",
+        type=str,
+        default="asr",
+        help="The decoding mode",
+    )
+    parser.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        default="paraformer-zh",
+        help="The asr model name",
+    )
+    parser.add_argument(
+        "-v",
+        "--vad_model",
+        type=str,
+        default="fsmn-vad",
+        help="vad model name",
+    )
+    parser.add_argument(
+        "-dv",
+        "--disable_vad",
+        type=str2bool,
+        default=False,
+        help="whether to disable the vad stage",
+    )
+    parser.add_argument(
+        "-p",
+        "--punc_model",
+        type=str,
+        default="ct-punc",
+        help="punctuation model name",
+    )
+    parser.add_argument(
+        "-dp",
+        "--disable_punc",
+        type=str2bool,
+        default=False,
+        help="whether to disable the punctuation stage",
+    )
+    parser.add_argument(
+        "--batch_size_token",
+        type=int,
+        default=5000,
+        help="",
+    )
+    parser.add_argument(
+        "--batch_size_token_threshold_s",
+        type=int,
+        default=35,
+        help="",
+    )
+    parser.add_argument(
+        "--max_single_segment_time",
+        type=int,
+        default=5000,
+        help="",
+    )
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+
+    # set logging messages
+    logging.basicConfig(
+        level=logging.ERROR,
+    )
+    logging.info("Decoding args: {}".format(kwargs))
+
+    # the cli feeds a single positional input instead of this batch option
+    kwargs.pop("data_path_and_name_and_type")
+    p = infer(**kwargs)
+
+    res = p(**kwargs)
+    print(res)
diff --git a/funasr/utils/download_and_prepare_model.py b/funasr/utils/download_and_prepare_model.py
new file mode 100644
index 0000000..af8a3f3
--- /dev/null
+++ b/funasr/utils/download_and_prepare_model.py
@@ -0,0 +1,108 @@
+
+import os
+from pathlib import Path
+import logging
+
+# short model names understood by the cli, mapped to modelscope model ids
+name_maps_ms = {
+    "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+    "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
+    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
+    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
+    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
+}
+
+
+def _get_ms_download_tool():
+    """Import and return modelscope's snapshot_download, silencing its logger."""
+    try:
+        from modelscope.utils.logger import get_logger
+        from modelscope.hub.snapshot_download import snapshot_download
+    except ImportError as e:
+        raise ImportError(
+            "You are downloading a model from modelscope, please install modelscope and try it again:\n"
+            "\npip3 install -U modelscope\n"
+            "For the users in China, you could install with the command:\n"
+            "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+        ) from e
+    # silence modelscope's chatty logger during downloads
+    logger = get_logger(log_level=logging.CRITICAL)
+    logger.setLevel(logging.CRITICAL)
+    return snapshot_download
+
+
+def _resolve_model(name, download_tool, cache_dir, revision=None):
+    """Return a local model dir, downloading from the hub when name is not a path."""
+    if name is None or Path(name).exists():
+        return name
+    # map a short alias (e.g. paraformer-zh) to its hub id; pass full ids through
+    name = name_maps_ms.get(name, name)
+    try:
+        local_dir = download_tool(name, cache_dir=cache_dir, revision=revision)
+    except Exception as e:
+        raise ValueError(
+            "model must be a model name in modelscope or a local path downloaded "
+            "from modelscope, but is {}".format(name)) from e
+    print("model have been downloaded to: {}".format(local_dir))
+    return local_dir
+
+
+def prepare_model(
+    model: str = None,
+    vad_model: str = None,
+    punc_model: str = None,
+    model_hub: str = "ms",
+    cache_dir: str = None,
+    **kwargs,
+):
+    """Download asr/vad/punc models when needed and fill their file paths into kwargs."""
+    download_tool = None
+    needs_download = any(
+        m is not None and not Path(m).exists() for m in (model, vad_model, punc_model)
+    )
+    if needs_download:
+        if model_hub in ("ms", "modelscope"):
+            download_tool = _get_ms_download_tool()
+        elif model_hub in ("hf", "huggingface"):
+            raise NotImplementedError("model_hub 'hf' is not supported yet, please use 'ms'")
+        else:
+            raise ValueError("model_hub must be one of ms or hf, but got {}".format(model_hub))
+
+    model = _resolve_model(model, download_tool, cache_dir, revision=kwargs.get("revision", None))
+    vad_model = _resolve_model(vad_model, download_tool, cache_dir)
+    punc_model = _resolve_model(punc_model, download_tool, cache_dir)
+
+    # asr model files
+    kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"),
+                   "asr_model_file": None if model is None else os.path.join(model, "model.pb"),
+                   "asr_train_config": None if model is None else os.path.join(model, "config.yaml"),
+                   })
+    mode = kwargs.get("mode", None)
+    if mode is None:
+        # the decoding mode is read from the model's configuration.json
+        import json
+        json_file = os.path.join(model, 'configuration.json')
+        with open(json_file, 'r') as f:
+            config_data = json.load(f)
+        if config_data['task'] == "punctuation":
+            mode = config_data['model']['punc_model_config']['mode']
+        else:
+            mode = config_data['model']['model_config']['mode']
+    if vad_model is not None and "vad" not in mode:
+        mode = "paraformer_vad"
+    kwargs["mode"] = mode
+    # vad model files
+    kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"),
+                   "vad_model_file": None if vad_model is None else os.path.join(vad_model, "vad.pb"),
+                   "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"),
+                   })
+    # punc model files
+    kwargs.update({
+        "punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"),
+        "punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"),
+    })
+
+    return model, vad_model, punc_model, kwargs
diff --git a/funasr/version.txt b/funasr/version.txt
index b60d719..7ada0d3 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.8.4
+0.8.5
diff --git a/setup.py b/setup.py
index 069e394..dd485d3 100644
--- a/setup.py
+++ b/setup.py
@@ -129,4 +129,7 @@
"License :: OSI Approved :: Apache Software License",
"Topic :: Software Development :: Libraries :: Python Modules",
],
+ entry_points={"console_scripts": [
+ "funasr = funasr.bin.inference_cli:main",
+ ]},
)
--
Gitblit v1.9.1