From 244c033fbaeae15faf8b0351365bdb7607b2e2bb Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 十一月 2023 15:19:53 +0800
Subject: [PATCH] python cli

---
 funasr/version.txt                         |    2 
 funasr/bin/inference_cli.py                |  139 ++++++++
 funasr/bin/argument.py                     |  262 ++++++++++++++++
 funasr/utils/download_and_prepare_model.py |   93 +++++
 setup.py                                   |    3 
 README_zh.md                               |    9 
 funasr/__init__.py                         |  127 -------
 funasr/bin/asr_inference_launch.py         |  258 ---------------
 README.md                                  |    9 
 9 files changed, 522 insertions(+), 380 deletions(-)

diff --git a/README.md b/README.md
index f73c0ca..001ce3f 100644
--- a/README.md
+++ b/README.md
@@ -76,6 +76,15 @@
 FunASR supports inference and fine-tuning of models trained on industrial data for tens of thousands of hours. For more details, please refer to [modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html). It also supports training and fine-tuning of models on academic standard datasets. For more information, please refer to [egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html).
 
 Below is a quick start tutorial. Test audio files ([Mandarin](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav), [English]()).
+
+### Command-line usage
+
+```shell
+funasr --model paraformer-zh asr_example_zh.wav
+```
+
+Notes: Support recognition of single audio file, as well as file list in Kaldi-style wav.scp format: `wav_id wav_pat`
+
 ### Speech Recognition (Non-streaming)
 ```python
 from funasr import infer
diff --git a/README_zh.md b/README_zh.md
index 554c0b6..504c715 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -70,6 +70,15 @@
 FunASR鏀寔鏁颁竾灏忔椂宸ヤ笟鏁版嵁璁粌鐨勬ā鍨嬬殑鎺ㄧ悊鍜屽井璋冿紝璇︾粏淇℃伅鍙互鍙傞槄锛圼modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html)锛夛紱涔熸敮鎸佸鏈爣鍑嗘暟鎹泦妯″瀷鐨勮缁冨拰寰皟锛岃缁嗕俊鎭彲浠ュ弬闃咃紙[egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html)锛夈��
 
 涓嬮潰涓哄揩閫熶笂鎵嬫暀绋嬶紝娴嬭瘯闊抽锛圼涓枃](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav)锛孾鑻辨枃]()锛�
+
+### 鍙墽琛屽懡浠よ
+
+```shell
+funasr --model paraformer-zh asr_example_zh.wav
+```
+
+娉細鏀寔鍗曟潯闊抽鏂囦欢璇嗗埆锛屼篃鏀寔鏂囦欢鍒楄〃锛屽垪琛ㄤ负kaldi椋庢牸wav.scp锛歚wav_id   wav_path`
+
 ### 闈炲疄鏃惰闊宠瘑鍒�
 ```python
 from funasr import infer
diff --git a/funasr/__init__.py b/funasr/__init__.py
index aab4289..d0b7aa5 100644
--- a/funasr/__init__.py
+++ b/funasr/__init__.py
@@ -1,135 +1,10 @@
 """Initialize funasr package."""
 
 import os
-from pathlib import Path
-import torch
-import numpy as np
 
 dirname = os.path.dirname(__file__)
 version_file = os.path.join(dirname, "version.txt")
 with open(version_file, "r") as f:
     __version__ = f.read().strip()
 
-
-def prepare_model(
-    model: str = None,
-    # mode: str = None,
-    vad_model: str = None,
-    punc_model: str = None,
-    model_hub: str = "ms",
-    cache_dir: str = None,
-    **kwargs,
-):
-    if not Path(model).exists():
-        if model_hub == "ms" or model_hub == "modelscope":
-            try:
-                from modelscope.hub.snapshot_download import snapshot_download as download_tool
-                model = name_maps_ms[model] if model is not None else None
-                vad_model = name_maps_ms[vad_model] if vad_model is not None else None
-                punc_model = name_maps_ms[punc_model] if punc_model is not None else None
-            except:
-                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
-                      "\npip3 install -U modelscope\n" \
-                      "For the users in China, you could install with the command:\n" \
-                      "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
-        elif model_hub == "hf" or model_hub == "huggingface":
-            download_tool = 0
-        else:
-            raise "model_hub must be on of ms or hf, but get {}".format(model_hub)
-        try:
-            model = download_tool(model, cache_dir=cache_dir, revision=kwargs.get("revision", None))
-            print("model have been downloaded to: {}".format(model))
-        except:
-            raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
-                model)
-        
-        if vad_model is not None and not Path(vad_model).exists():
-            vad_model = download_tool(vad_model, cache_dir=cache_dir)
-            print("model have been downloaded to: {}".format(vad_model))
-        if punc_model is not None and not Path(punc_model).exists():
-            punc_model = download_tool(punc_model, cache_dir=cache_dir)
-            print("model have been downloaded to: {}".format(punc_model))
-        
-        # asr
-        kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"),
-                       "asr_model_file": None if model is None else os.path.join(model, "model.pb"),
-                       "asr_train_config": None if model is None else os.path.join(model, "config.yaml"),
-                       })
-        mode = kwargs.get("mode", None)
-        if mode is None:
-            import json
-            json_file = os.path.join(model, 'configuration.json')
-            with open(json_file, 'r') as f:
-                config_data = json.load(f)
-                if config_data['task'] == "punctuation":
-                    mode = config_data['model']['punc_model_config']['mode']
-                else:
-                    mode = config_data['model']['model_config']['mode']
-        if vad_model is not None and "vad" not in mode:
-            mode = "paraformer_vad"
-        kwargs["mode"] = mode
-        # vad
-        kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"),
-                       "vad_model_file": None if vad_model is None else os.path.join(vad_model, "vad.pb"),
-                       "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"),
-                       })
-        # punc
-        kwargs.update({
-            "punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"),
-            "punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"),
-        })
-        
-        
-        return model, vad_model, punc_model, kwargs
-
-name_maps_ms = {
-    "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-    "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
-    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
-    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
-    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
-    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
-    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
-}
-
-def infer(task_name: str = "asr",
-            model: str = None,
-            # mode: str = None,
-            vad_model: str = None,
-            punc_model: str = None,
-            model_hub: str = "ms",
-            cache_dir: str = None,
-            **kwargs,
-          ):
-
-    model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
-    if task_name == "asr":
-        from funasr.bin.asr_inference_launch import inference_launch
-
-        inference_pipeline = inference_launch(**kwargs)
-    elif task_name == "":
-        pipeline = 1
-    elif task_name == "":
-        pipeline = 2
-    elif task_name == "":
-        pipeline = 2
-    
-    def _infer_fn(input, **kwargs):
-        data_type = kwargs.get('data_type', 'sound')
-        data_path_and_name_and_type = [input, 'speech', data_type]
-        raw_inputs = None
-        if isinstance(input, torch.Tensor):
-            input = input.numpy()
-        if isinstance(input, np.ndarray):
-            data_path_and_name_and_type = None
-            raw_inputs = input
-            
-
-        
-        return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
-    
-    return _infer_fn
-
-if __name__ == '__main__':
-    pass
\ No newline at end of file
+from funasr.bin.inference_cli import infer
\ No newline at end of file
diff --git a/funasr/bin/argument.py b/funasr/bin/argument.py
new file mode 100644
index 0000000..0ea4ac9
--- /dev/null
+++ b/funasr/bin/argument.py
@@ -0,0 +1,262 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import sys
+
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils import config_argparse
+import argparse
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="ASR Decoding",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, default=None)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=1,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument(
+        "--njob",
+        type=int,
+        default=1,
+        help="The number of jobs for each gpu",
+    )
+    parser.add_argument(
+        "--gpuid_list",
+        type=str,
+        default="",
+        help="The visible gpus",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        required=False,
+        action="append",
+    )
+    group.add_argument("--key_file", type=str_or_none)
+    parser.add_argument(
+        "--hotword",
+        type=str_or_none,
+        default=None,
+        help="hotword file path or hotwords seperated by space"
+    )
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+    group.add_argument(
+        "--mc",
+        type=bool,
+        default=False,
+        help="MultiChannel input",
+    )
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument(
+        "--vad_infer_config",
+        type=str,
+        help="VAD infer configuration",
+    )
+    group.add_argument(
+        "--vad_model_file",
+        type=str,
+        help="VAD model parameter file",
+    )
+    group.add_argument(
+        "--punc_infer_config",
+        type=str,
+        help="PUNC infer configuration",
+    )
+    group.add_argument(
+        "--punc_model_file",
+        type=str,
+        help="PUNC model parameter file",
+    )
+    group.add_argument(
+        "--cmvn_file",
+        type=str,
+        help="Global CMVN file",
+    )
+    group.add_argument(
+        "--asr_train_config",
+        type=str,
+        help="ASR training configuration",
+    )
+    group.add_argument(
+        "--asr_model_file",
+        type=str,
+        help="ASR model parameter file",
+    )
+    group.add_argument(
+        "--sv_model_file",
+        type=str,
+        help="SV model parameter file",
+    )
+    group.add_argument(
+        "--lm_train_config",
+        type=str,
+        help="LM training configuration",
+    )
+    group.add_argument(
+        "--lm_file",
+        type=str,
+        help="LM parameter file",
+    )
+    group.add_argument(
+        "--word_lm_train_config",
+        type=str,
+        help="Word LM training configuration",
+    )
+    group.add_argument(
+        "--word_lm_file",
+        type=str,
+        help="Word LM parameter file",
+    )
+    group.add_argument(
+        "--ngram_file",
+        type=str,
+        help="N-gram parameter file",
+    )
+    group.add_argument(
+        "--model_tag",
+        type=str,
+        help="Pretrained model tag. If specify this option, *_train_config and "
+             "*_file will be overwritten",
+    )
+    group.add_argument(
+        "--beam_search_config",
+        default={},
+        help="The keyword arguments for transducer beam search.",
+    )
+
+    group = parser.add_argument_group("Beam-search related")
+    group.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
+    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+    group.add_argument(
+        "--maxlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain max output length. "
+             "If maxlenratio=0.0 (default), it uses a end-detect "
+             "function "
+             "to automatically find maximum hypothesis lengths."
+             "If maxlenratio<0.0, its absolute value is interpreted"
+             "as a constant max output length",
+    )
+    group.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length",
+    )
+    group.add_argument(
+        "--ctc_weight",
+        type=float,
+        default=0.0,
+        help="CTC weight in joint decoding",
+    )
+    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+    group.add_argument("--streaming", type=str2bool, default=False)
+    group.add_argument("--fake_streaming", type=str2bool, default=False)
+    group.add_argument("--full_utt", type=str2bool, default=False)
+    group.add_argument("--chunk_size", type=int, default=16)
+    group.add_argument("--left_context", type=int, default=16)
+    group.add_argument("--right_context", type=int, default=0)
+    group.add_argument(
+        "--display_partial_hypotheses",
+        type=bool,
+        default=False,
+        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
+    )
+
+    group = parser.add_argument_group("Dynamic quantization related")
+    group.add_argument(
+        "--quantize_asr_model",
+        type=bool,
+        default=False,
+        help="Apply dynamic quantization to ASR model.",
+    )
+    group.add_argument(
+        "--quantize_modules",
+        nargs="*",
+        default=None,
+        help="""Module names to apply dynamic quantization on.
+        The module names are provided as a list, where each name is separated
+        by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
+        Each specified name should be an attribute of 'torch.nn', e.g.:
+        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
+    )
+    group.add_argument(
+        "--quantize_dtype",
+        type=str,
+        default="qint8",
+        choices=["float16", "qint8"],
+        help="Dtype for dynamic quantization.",
+    )
+
+    group = parser.add_argument_group("Text converter related")
+    group.add_argument(
+        "--token_type",
+        type=str_or_none,
+        default=None,
+        choices=["char", "bpe", None],
+        help="The token type for ASR model. "
+             "If not given, refers from the training args",
+    )
+    group.add_argument(
+        "--bpemodel",
+        type=str_or_none,
+        default=None,
+        help="The model path of sentencepiece. "
+             "If not given, refers from the training args",
+    )
+    group.add_argument("--token_num_relax", type=int, default=1, help="")
+    group.add_argument("--decoding_ind", type=int, default=0, help="")
+    group.add_argument("--decoding_mode", type=str, default="model1", help="")
+    group.add_argument(
+        "--ctc_weight2",
+        type=float,
+        default=0.0,
+        help="CTC weight in joint decoding",
+    )
+    return parser
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index e93d740..e1a32c5 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -675,11 +675,13 @@
                 beg_idx = end_idx
                 batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
                 batch = to_device(batch, device=device)
-                # print("batch: ", speech_j.shape[0])
+
                 beg_asr = time.time()
                 results = speech2text(**batch)
                 end_asr = time.time()
-                # print("time cost asr: ", end_asr - beg_asr)
+                if speech2text.device != "cpu":
+                    print("batch: ", speech_j.shape[0])
+                    print("time cost asr: ", end_asr - beg_asr)
 
                 if len(results) < 1:
                     results = [["", [], [], [], [], [], []]]
@@ -2218,259 +2220,9 @@
         logging.info("Unknown decoding mode: {}".format(mode))
         return None
 
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="ASR Decoding",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    # Note(kamo): Use '_' instead of '-' as separator.
-    # '-' is confusing if written in yaml.
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-
-    parser.add_argument("--output_dir", type=str, required=True)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument(
-        "--njob",
-        type=int,
-        default=1,
-        help="The number of jobs for each gpu",
-    )
-    parser.add_argument(
-        "--gpuid_list",
-        type=str,
-        default="",
-        help="The visible gpus",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        required=True,
-        action="append",
-    )
-    group.add_argument("--key_file", type=str_or_none)
-    parser.add_argument(
-        "--hotword",
-        type=str_or_none,
-        default=None,
-        help="hotword file path or hotwords seperated by space"
-    )
-    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-    group.add_argument(
-        "--mc",
-        type=bool,
-        default=False,
-        help="MultiChannel input",
-    )
-
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument(
-        "--vad_infer_config",
-        type=str,
-        help="VAD infer configuration",
-    )
-    group.add_argument(
-        "--vad_model_file",
-        type=str,
-        help="VAD model parameter file",
-    )
-    group.add_argument(
-        "--punc_infer_config",
-        type=str,
-        help="PUNC infer configuration",
-    )
-    group.add_argument(
-        "--punc_model_file",
-        type=str,
-        help="PUNC model parameter file",
-    )
-    group.add_argument(
-        "--cmvn_file",
-        type=str,
-        help="Global CMVN file",
-    )
-    group.add_argument(
-        "--asr_train_config",
-        type=str,
-        help="ASR training configuration",
-    )
-    group.add_argument(
-        "--asr_model_file",
-        type=str,
-        help="ASR model parameter file",
-    )
-    group.add_argument(
-        "--sv_model_file",
-        type=str,
-        help="SV model parameter file",
-    )
-    group.add_argument(
-        "--lm_train_config",
-        type=str,
-        help="LM training configuration",
-    )
-    group.add_argument(
-        "--lm_file",
-        type=str,
-        help="LM parameter file",
-    )
-    group.add_argument(
-        "--word_lm_train_config",
-        type=str,
-        help="Word LM training configuration",
-    )
-    group.add_argument(
-        "--word_lm_file",
-        type=str,
-        help="Word LM parameter file",
-    )
-    group.add_argument(
-        "--ngram_file",
-        type=str,
-        help="N-gram parameter file",
-    )
-    group.add_argument(
-        "--model_tag",
-        type=str,
-        help="Pretrained model tag. If specify this option, *_train_config and "
-             "*_file will be overwritten",
-    )
-    group.add_argument(
-        "--beam_search_config",
-        default={},
-        help="The keyword arguments for transducer beam search.",
-    )
-
-    group = parser.add_argument_group("Beam-search related")
-    group.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-    group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
-    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
-    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
-    group.add_argument(
-        "--maxlenratio",
-        type=float,
-        default=0.0,
-        help="Input length ratio to obtain max output length. "
-             "If maxlenratio=0.0 (default), it uses a end-detect "
-             "function "
-             "to automatically find maximum hypothesis lengths."
-             "If maxlenratio<0.0, its absolute value is interpreted"
-             "as a constant max output length",
-    )
-    group.add_argument(
-        "--minlenratio",
-        type=float,
-        default=0.0,
-        help="Input length ratio to obtain min output length",
-    )
-    group.add_argument(
-        "--ctc_weight",
-        type=float,
-        default=0.0,
-        help="CTC weight in joint decoding",
-    )
-    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
-    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
-    group.add_argument("--streaming", type=str2bool, default=False)
-    group.add_argument("--fake_streaming", type=str2bool, default=False)
-    group.add_argument("--full_utt", type=str2bool, default=False)
-    group.add_argument("--chunk_size", type=int, default=16)
-    group.add_argument("--left_context", type=int, default=16)
-    group.add_argument("--right_context", type=int, default=0)
-    group.add_argument(
-        "--display_partial_hypotheses",
-        type=bool,
-        default=False,
-        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
-    )
-
-    group = parser.add_argument_group("Dynamic quantization related")
-    group.add_argument(
-        "--quantize_asr_model",
-        type=bool,
-        default=False,
-        help="Apply dynamic quantization to ASR model.",
-    )
-    group.add_argument(
-        "--quantize_modules",
-        nargs="*",
-        default=None,
-        help="""Module names to apply dynamic quantization on.
-        The module names are provided as a list, where each name is separated
-        by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
-        Each specified name should be an attribute of 'torch.nn', e.g.:
-        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
-    )
-    group.add_argument(
-        "--quantize_dtype",
-        type=str,
-        default="qint8",
-        choices=["float16", "qint8"],
-        help="Dtype for dynamic quantization.",
-    )
-
-    group = parser.add_argument_group("Text converter related")
-    group.add_argument(
-        "--token_type",
-        type=str_or_none,
-        default=None,
-        choices=["char", "bpe", None],
-        help="The token type for ASR model. "
-             "If not given, refers from the training args",
-    )
-    group.add_argument(
-        "--bpemodel",
-        type=str_or_none,
-        default=None,
-        help="The model path of sentencepiece. "
-             "If not given, refers from the training args",
-    )
-    group.add_argument("--token_num_relax", type=int, default=1, help="")
-    group.add_argument("--decoding_ind", type=int, default=0, help="")
-    group.add_argument("--decoding_mode", type=str, default="model1", help="")
-    group.add_argument(
-        "--ctc_weight2",
-        type=float,
-        default=0.0,
-        help="CTC weight in joint decoding",
-    )
-    return parser
-
-
 def main(cmd=None):
     print(get_commandline_args(), file=sys.stderr)
+    from funasr.bin.argument import get_parser
     parser = get_parser()
     parser.add_argument(
         "--mode",
diff --git a/funasr/bin/inference_cli.py b/funasr/bin/inference_cli.py
new file mode 100644
index 0000000..f4c66f1
--- /dev/null
+++ b/funasr/bin/inference_cli.py
@@ -0,0 +1,139 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import os
+
+import logging
+import torch
+import numpy as np
+from funasr.utils.download_and_prepare_model import prepare_model
+
+from funasr.utils.types import str2bool
+
+def infer(task_name: str = "asr",
+          model: str = None,
+          # mode: str = None,
+          vad_model: str = None,
+          disable_vad: bool = False,
+          punc_model: str = None,
+          disable_punc: bool = False,
+          model_hub: str = "ms",
+          cache_dir: str = None,
+          **kwargs,
+          ):
+
+	# set logging messages
+	logging.basicConfig(
+		level=logging.ERROR,
+	)
+
+	model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
+	if task_name == "asr":
+		from funasr.bin.asr_inference_launch import inference_launch
+		
+		inference_pipeline = inference_launch(**kwargs)
+	elif task_name == "":
+		pipeline = 1
+	elif task_name == "":
+		pipeline = 2
+	elif task_name == "":
+		pipeline = 2
+	
+	def _infer_fn(input, **kwargs):
+		data_type = kwargs.get('data_type', 'sound')
+		data_path_and_name_and_type = [input, 'speech', data_type]
+		raw_inputs = None
+		if isinstance(input, torch.Tensor):
+			input = input.numpy()
+		if isinstance(input, np.ndarray):
+			data_path_and_name_and_type = None
+			raw_inputs = input
+		
+		return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
+	
+	return _infer_fn
+
+
+def main(cmd=None):
+	# print(get_commandline_args(), file=sys.stderr)
+	from funasr.bin.argument import get_parser
+	
+	parser = get_parser()
+	parser.add_argument('input', help='input file to transcribe')
+	parser.add_argument(
+	    "--task_name",
+	    type=str,
+	    default="asr",
+	    help="The decoding mode",
+	)
+	parser.add_argument(
+		"-m",
+	    "--model",
+	    type=str,
+	    default="paraformer-zh",
+	    help="The asr mode name",
+	)
+	parser.add_argument(
+		"-v",
+	    "--vad_model",
+	    type=str,
+	    default="fsmn-vad",
+	    help="vad model name",
+	)
+	parser.add_argument(
+		"-dv",
+	    "--disable_vad",
+	    type=str2bool,
+	    default=False,
+	    help="",
+	)
+	parser.add_argument(
+		"-p",
+	    "--punc_model",
+	    type=str,
+	    default="ct-punc",
+	    help="",
+	)
+	parser.add_argument(
+		"-dp",
+	    "--disable_punc",
+	    type=str2bool,
+	    default=False,
+	    help="",
+	)
+	parser.add_argument(
+	    "--batch_size_token",
+	    type=int,
+	    default=5000,
+	    help="",
+	)
+	parser.add_argument(
+	    "--batch_size_token_threshold_s",
+	    type=int,
+	    default=35,
+	    help="",
+	)
+	parser.add_argument(
+	    "--max_single_segment_time",
+	    type=int,
+	    default=5000,
+	    help="",
+	)
+	args = parser.parse_args(cmd)
+	kwargs = vars(args)
+	
+	# set logging messages
+	logging.basicConfig(
+		level=logging.ERROR,
+	)
+	logging.info("Decoding args: {}".format(kwargs))
+	
+	# kwargs["ncpu"] = 2 #os.cpu_count()
+	kwargs.pop("data_path_and_name_and_type")
+	print("args: {}".format(kwargs))
+	p = infer(**kwargs)
+	
+	res = p(**kwargs)
+	print(res)
diff --git a/funasr/utils/download_and_prepare_model.py b/funasr/utils/download_and_prepare_model.py
new file mode 100644
index 0000000..af8a3f3
--- /dev/null
+++ b/funasr/utils/download_and_prepare_model.py
@@ -0,0 +1,93 @@
+
+import os
+from pathlib import Path
+import logging
+
+name_maps_ms = {
+    "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+    "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
+    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
+    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
+    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
+    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
+}
+
+def prepare_model(
+	model: str = None,
+	# mode: str = None,
+	vad_model: str = None,
+	punc_model: str = None,
+	model_hub: str = "ms",
+	cache_dir: str = None,
+	**kwargs,
+):
+	if not Path(model).exists():
+		if model_hub == "ms" or model_hub == "modelscope":
+			from modelscope.utils.logger import get_logger
+			
+			logger = get_logger(log_level=logging.CRITICAL)
+			logger.setLevel(logging.CRITICAL)
+			try:
+				from modelscope.hub.snapshot_download import snapshot_download as download_tool
+				model = name_maps_ms[model] if model is not None else None
+				vad_model = name_maps_ms[vad_model] if vad_model is not None else None
+				punc_model = name_maps_ms[punc_model] if punc_model is not None else None
+			except:
+				raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
+				      "\npip3 install -U modelscope\n" \
+				      "For the users in China, you could install with the command:\n" \
+				      "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+
+			try:
+				model = download_tool(model, cache_dir=cache_dir, revision=kwargs.get("revision", None))
+				print("asr model have been downloaded to: {}".format(model))
+			except:
+				raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
+					model)
+		
+		elif model_hub == "hf" or model_hub == "huggingface":
+			download_tool = 0
+		else:
+			raise "model_hub must be on of ms or hf, but get {}".format(model_hub)
+
+		
+		if vad_model is not None and not Path(vad_model).exists():
+			vad_model = download_tool(vad_model, cache_dir=cache_dir)
+			print("vad_model have been downloaded to: {}".format(vad_model))
+		if punc_model is not None and not Path(punc_model).exists():
+			punc_model = download_tool(punc_model, cache_dir=cache_dir)
+			print("punc_model have been downloaded to: {}".format(punc_model))
+		
+		# asr
+		kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"),
+		               "asr_model_file": None if model is None else os.path.join(model, "model.pb"),
+		               "asr_train_config": None if model is None else os.path.join(model, "config.yaml"),
+		               })
+		mode = kwargs.get("mode", None)
+		if mode is None:
+			import json
+			json_file = os.path.join(model, 'configuration.json')
+			with open(json_file, 'r') as f:
+				config_data = json.load(f)
+				if config_data['task'] == "punctuation":
+					mode = config_data['model']['punc_model_config']['mode']
+				else:
+					mode = config_data['model']['model_config']['mode']
+		if vad_model is not None and "vad" not in mode:
+			mode = "paraformer_vad"
+		kwargs["mode"] = mode
+		# vad
+		kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"),
+		               "vad_model_file": None if vad_model is None else os.path.join(vad_model, "vad.pb"),
+		               "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"),
+		               })
+		# punc
+		kwargs.update({
+			"punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"),
+			"punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"),
+		})
+		
+		
+		return model, vad_model, punc_model, kwargs
diff --git a/funasr/version.txt b/funasr/version.txt
index b60d719..7ada0d3 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.8.4
+0.8.5
diff --git a/setup.py b/setup.py
index 069e394..dd485d3 100644
--- a/setup.py
+++ b/setup.py
@@ -129,4 +129,7 @@
         "License :: OSI Approved :: Apache Software License",
         "Topic :: Software Development :: Libraries :: Python Modules",
     ],
+    entry_points={"console_scripts": [
+        "funasr = funasr.bin.inference_cli:main",
+    ]},
 )

--
Gitblit v1.9.1