游雁
2023-11-17 244c033fbaeae15faf8b0351365bdb7607b2e2bb
python cli
6个文件已修改
3个文件已添加
902 ■■■■■ 已修改文件
README.md 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
README_zh.md 9 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/__init__.py 127 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/argument.py 262 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py 258 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/inference_cli.py 139 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/download_and_prepare_model.py 93 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/version.txt 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
setup.py 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
README.md
@@ -76,6 +76,15 @@
FunASR supports inference and fine-tuning of models trained on industrial data for tens of thousands of hours. For more details, please refer to [modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html). It also supports training and fine-tuning of models on academic standard datasets. For more information, please refer to [egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html).
Below is a quick start tutorial. Test audio files ([Mandarin](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav), [English]()).
### Command-line usage
```shell
funasr --model paraformer-zh asr_example_zh.wav
```
Notes: Supports recognition of a single audio file, as well as a file list in Kaldi-style wav.scp format: `wav_id wav_path`
### Speech Recognition (Non-streaming)
```python
from funasr import infer
README_zh.md
@@ -70,6 +70,15 @@
FunASR支持数万小时工业数据训练的模型的推理和微调,详细信息可以参阅([modelscope_egs](https://alibaba-damo-academy.github.io/FunASR/en/modelscope_pipeline/quick_start.html));也支持学术标准数据集模型的训练和微调,详细信息可以参阅([egs](https://alibaba-damo-academy.github.io/FunASR/en/academic_recipe/asr_recipe.html))。
下面为快速上手教程,测试音频([中文](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav),[英文]())
### 可执行命令行
```shell
funasr --model paraformer-zh asr_example_zh.wav
```
注:支持单条音频文件识别,也支持文件列表,列表为kaldi风格wav.scp:`wav_id   wav_path`
### 非实时语音识别
```python
from funasr import infer
funasr/__init__.py
@@ -1,135 +1,10 @@
"""Initialize funasr package."""
import os
from pathlib import Path
import torch
import numpy as np
dirname = os.path.dirname(__file__)
version_file = os.path.join(dirname, "version.txt")
with open(version_file, "r") as f:
    __version__ = f.read().strip()
def prepare_model(
    model: str = None,
    # mode: str = None,
    vad_model: str = None,
    punc_model: str = None,
    model_hub: str = "ms",
    cache_dir: str = None,
    **kwargs,
):
    """Resolve model names/aliases to local directories and fill inference kwargs.

    When ``model`` is not an existing local path, the short alias (e.g.
    ``paraformer-zh``) is mapped through ``name_maps_ms`` and downloaded from
    the configured hub.  In all cases ``kwargs`` is updated with the ASR/VAD/
    PUNC file paths and the decoding ``mode`` (read from the model directory's
    ``configuration.json`` when not supplied).

    Args:
        model: Local model directory, or model alias/name on the hub.
        vad_model: Optional VAD model directory or alias.
        punc_model: Optional punctuation model directory or alias.
        model_hub: "ms"/"modelscope" or "hf"/"huggingface".
        cache_dir: Download cache directory.
        **kwargs: Extra options; updated in place with resolved file paths.

    Returns:
        Tuple of ``(model_dir, vad_model_dir, punc_model_dir, kwargs)``.

    Raises:
        ImportError: If modelscope is required but not installed.
        NotImplementedError: For the huggingface hub (not wired up yet).
        ValueError: For an unknown ``model_hub`` or unresolvable model name.
    """
    if not Path(model).exists():
        if model_hub in ("ms", "modelscope"):
            try:
                from modelscope.hub.snapshot_download import snapshot_download as download_tool
            except ImportError as e:
                # NOTE: `raise "<str>"` (as before) is itself a TypeError in
                # Python 3 — raise a real exception instead.
                raise ImportError(
                    "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n"
                    "\npip3 install -U modelscope\n"
                    "For the users in China, you could install with the command:\n"
                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                ) from e
            # Map short aliases to full modelscope ids; unknown names pass
            # through unchanged and fail below with a clear error.
            model = name_maps_ms.get(model, model) if model is not None else None
            vad_model = name_maps_ms.get(vad_model, vad_model) if vad_model is not None else None
            punc_model = name_maps_ms.get(punc_model, punc_model) if punc_model is not None else None
        elif model_hub in ("hf", "huggingface"):
            # The original code set download_tool = 0 and then called it.
            raise NotImplementedError("model_hub 'hf' is not supported yet")
        else:
            raise ValueError("model_hub must be one of ms or hf, but got {}".format(model_hub))
        try:
            model = download_tool(model, cache_dir=cache_dir, revision=kwargs.get("revision", None))
            print("model have been downloaded to: {}".format(model))
        except Exception as e:
            raise ValueError(
                "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model)
            ) from e
        if vad_model is not None and not Path(vad_model).exists():
            vad_model = download_tool(vad_model, cache_dir=cache_dir)
            print("model have been downloaded to: {}".format(vad_model))
        if punc_model is not None and not Path(punc_model).exists():
            punc_model = download_tool(punc_model, cache_dir=cache_dir)
            print("model have been downloaded to: {}".format(punc_model))
    # asr: these paths are filled even when `model` was already a local dir
    # (previously this block sat inside the `if not exists` branch, so local
    # paths made the function return None and broke the caller's unpacking).
    kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"),
                   "asr_model_file": None if model is None else os.path.join(model, "model.pb"),
                   "asr_train_config": None if model is None else os.path.join(model, "config.yaml"),
                   })
    mode = kwargs.get("mode", None)
    if mode is None:
        # Fall back to the decoding mode recorded in the model package.
        import json
        json_file = os.path.join(model, 'configuration.json')
        with open(json_file, 'r') as f:
            config_data = json.load(f)
            if config_data['task'] == "punctuation":
                mode = config_data['model']['punc_model_config']['mode']
            else:
                mode = config_data['model']['model_config']['mode']
    # Attaching a VAD model requires a vad-aware pipeline mode.
    if vad_model is not None and "vad" not in mode:
        mode = "paraformer_vad"
    kwargs["mode"] = mode
    # vad
    kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"),
                   "vad_model_file": None if vad_model is None else os.path.join(vad_model, "vad.pb"),
                   "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"),
                   })
    # punc
    kwargs.update({
        "punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"),
        "punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"),
    })
    return model, vad_model, punc_model, kwargs
# Mapping from the short model aliases accepted by `infer`/the CLI to the
# full modelscope model ids (all under the "damo" namespace).
name_maps_ms = {
    "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
}
def infer(task_name: str = "asr",
            model: str = None,
            # mode: str = None,
            vad_model: str = None,
            punc_model: str = None,
            model_hub: str = "ms",
            cache_dir: str = None,
            **kwargs,
          ):
    """Build and return an inference function for the given task.

    Args:
        task_name: Only "asr" is currently implemented.
        model / vad_model / punc_model: Model directories or hub aliases,
            resolved by ``prepare_model``.
        model_hub: "ms" (modelscope) or "hf" (huggingface).
        cache_dir: Download cache directory.
        **kwargs: Forwarded to ``prepare_model`` and the pipeline.

    Returns:
        A callable ``fn(input, **kwargs)`` accepting a wav path / wav.scp
        entry, a numpy array or a torch tensor.

    Raises:
        NotImplementedError: For any ``task_name`` other than "asr".
    """
    model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
    if task_name == "asr":
        from funasr.bin.asr_inference_launch import inference_launch
        inference_pipeline = inference_launch(**kwargs)
    else:
        # The previous placeholder branches (three identical `elif task_name == ""`)
        # never built a pipeline, so unsupported tasks crashed later with a
        # NameError on `inference_pipeline` — fail fast and clearly instead.
        raise NotImplementedError("task_name must be 'asr', but got {}".format(task_name))

    def _infer_fn(input, **kwargs):
        """Run the pipeline on a path/scp entry, numpy array or torch tensor."""
        data_type = kwargs.get('data_type', 'sound')
        data_path_and_name_and_type = [input, 'speech', data_type]
        raw_inputs = None
        if isinstance(input, torch.Tensor):
            # detach/cpu so CUDA or grad-tracking tensors convert safely
            input = input.detach().cpu().numpy()
        if isinstance(input, np.ndarray):
            data_path_and_name_and_type = None
            raw_inputs = input
        return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
    return _infer_fn
# No script behavior here; the console entry point lives in
# funasr.bin.inference_cli (see setup.py console_scripts).
if __name__ == '__main__':
    pass
from funasr.bin.inference_cli import infer
funasr/bin/argument.py
New file
@@ -0,0 +1,262 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import sys
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import config_argparse
import argparse
def get_parser():
    """Build the shared argparse parser for ASR decoding entry points.

    Returns:
        config_argparse.ArgumentParser: parser with logging, input-data,
        model-configuration, beam-search, dynamic-quantization and
        text-converter option groups.
    """
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, default=None)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=1,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--njob",
        type=int,
        default=1,
        help="The number of jobs for each gpu",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=False,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    parser.add_argument(
        "--hotword",
        type=str_or_none,
        default=None,
        help="hotword file path or hotwords seperated by space"
    )
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
    group.add_argument(
        "--mc",
        # type=bool would treat any non-empty string (even "False") as True;
        # str2bool parses the flag value correctly.
        type=str2bool,
        default=False,
        help="MultiChannel input",
    )

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
        type=str,
        help="VAD infer configuration",
    )
    group.add_argument(
        "--vad_model_file",
        type=str,
        help="VAD model parameter file",
    )
    group.add_argument(
        "--punc_infer_config",
        type=str,
        help="PUNC infer configuration",
    )
    group.add_argument(
        "--punc_model_file",
        type=str,
        help="PUNC model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global CMVN file",
    )
    group.add_argument(
        "--asr_train_config",
        type=str,
        help="ASR training configuration",
    )
    group.add_argument(
        "--asr_model_file",
        type=str,
        help="ASR model parameter file",
    )
    group.add_argument(
        "--sv_model_file",
        type=str,
        help="SV model parameter file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
    )
    group.add_argument(
        "--lm_file",
        type=str,
        help="LM parameter file",
    )
    group.add_argument(
        "--word_lm_train_config",
        type=str,
        help="Word LM training configuration",
    )
    group.add_argument(
        "--word_lm_file",
        type=str,
        help="Word LM parameter file",
    )
    group.add_argument(
        "--ngram_file",
        type=str,
        help="N-gram parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
             "*_file will be overwritten",
    )
    group.add_argument(
        "--beam_search_config",
        # NOTE(review): no type= is given, so a CLI-supplied value stays a raw
        # string; only the dict default is usable as-is.
        default={},
        help="The keyword arguments for transducer beam search.",
    )

    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain max output length. "
             "If maxlenratio=0.0 (default), it uses a end-detect "
             "function "
             "to automatically find maximum hypothesis lengths."
             "If maxlenratio<0.0, its absolute value is interpreted"
             "as a constant max output length",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.0,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)
    group.add_argument("--fake_streaming", type=str2bool, default=False)
    group.add_argument("--full_utt", type=str2bool, default=False)
    group.add_argument("--chunk_size", type=int, default=16)
    group.add_argument("--left_context", type=int, default=16)
    group.add_argument("--right_context", type=int, default=0)
    group.add_argument(
        "--display_partial_hypotheses",
        # str2bool instead of bool: bool("False") is True under argparse.
        type=str2bool,
        default=False,
        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
    )

    group = parser.add_argument_group("Dynamic quantization related")
    group.add_argument(
        "--quantize_asr_model",
        # str2bool instead of bool: bool("False") is True under argparse.
        type=str2bool,
        default=False,
        help="Apply dynamic quantization to ASR model.",
    )
    group.add_argument(
        "--quantize_modules",
        nargs="*",
        default=None,
        help="""Module names to apply dynamic quantization on.
        The module names are provided as a list, where each name is separated
        by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
        Each specified name should be an attribute of 'torch.nn', e.g.:
        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
    )
    group.add_argument(
        "--quantize_dtype",
        type=str,
        default="qint8",
        choices=["float16", "qint8"],
        help="Dtype for dynamic quantization.",
    )

    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
        type=str_or_none,
        default=None,
        choices=["char", "bpe", None],
        help="The token type for ASR model. "
             "If not given, refers from the training args",
    )
    group.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model path of sentencepiece. "
             "If not given, refers from the training args",
    )
    group.add_argument("--token_num_relax", type=int, default=1, help="")
    group.add_argument("--decoding_ind", type=int, default=0, help="")
    group.add_argument("--decoding_mode", type=str, default="model1", help="")
    group.add_argument(
        "--ctc_weight2",
        type=float,
        default=0.0,
        help="CTC weight in joint decoding",
    )
    return parser
funasr/bin/asr_inference_launch.py
@@ -675,11 +675,13 @@
                beg_idx = end_idx
                batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
                batch = to_device(batch, device=device)
                # print("batch: ", speech_j.shape[0])
                beg_asr = time.time()
                results = speech2text(**batch)
                end_asr = time.time()
                # print("time cost asr: ", end_asr - beg_asr)
                if speech2text.device != "cpu":
                    print("batch: ", speech_j.shape[0])
                    print("time cost asr: ", end_asr - beg_asr)
                if len(results) < 1:
                    results = [["", [], [], [], [], [], []]]
@@ -2218,259 +2220,9 @@
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
def get_parser():
    """Build the argparse parser for standalone ASR decoding.

    NOTE(review): this largely duplicates funasr.bin.argument.get_parser but
    with stricter CLI defaults (`--output_dir` and
    `--data_path_and_name_and_type` are required; `--ngpu` defaults to 0) —
    those differences are preserved here.

    Returns:
        config_argparse.ArgumentParser: parser with logging, input-data,
        model-configuration, beam-search, quantization and text-converter
        option groups.
    """
    parser = config_argparse.ArgumentParser(
        description="ASR Decoding",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Note(kamo): Use '_' instead of '-' as separator.
    # '-' is confusing if written in yaml.
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )

    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument(
        "--ngpu",
        type=int,
        default=0,
        help="The number of gpus. 0 indicates CPU mode",
    )
    parser.add_argument(
        "--njob",
        type=int,
        default=1,
        help="The number of jobs for each gpu",
    )
    parser.add_argument(
        "--gpuid_list",
        type=str,
        default="",
        help="The visible gpus",
    )
    parser.add_argument("--seed", type=int, default=0, help="Random seed")
    parser.add_argument(
        "--dtype",
        default="float32",
        choices=["float16", "float32", "float64"],
        help="Data type",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="The number of workers used for DataLoader",
    )

    group = parser.add_argument_group("Input data related")
    group.add_argument(
        "--data_path_and_name_and_type",
        type=str2triple_str,
        required=True,
        action="append",
    )
    group.add_argument("--key_file", type=str_or_none)
    parser.add_argument(
        "--hotword",
        type=str_or_none,
        default=None,
        help="hotword file path or hotwords seperated by space"
    )
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
    group.add_argument(
        "--mc",
        # type=bool would treat any non-empty string (even "False") as True;
        # str2bool parses the flag value correctly.
        type=str2bool,
        default=False,
        help="MultiChannel input",
    )

    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
        type=str,
        help="VAD infer configuration",
    )
    group.add_argument(
        "--vad_model_file",
        type=str,
        help="VAD model parameter file",
    )
    group.add_argument(
        "--punc_infer_config",
        type=str,
        help="PUNC infer configuration",
    )
    group.add_argument(
        "--punc_model_file",
        type=str,
        help="PUNC model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global CMVN file",
    )
    group.add_argument(
        "--asr_train_config",
        type=str,
        help="ASR training configuration",
    )
    group.add_argument(
        "--asr_model_file",
        type=str,
        help="ASR model parameter file",
    )
    group.add_argument(
        "--sv_model_file",
        type=str,
        help="SV model parameter file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
    )
    group.add_argument(
        "--lm_file",
        type=str,
        help="LM parameter file",
    )
    group.add_argument(
        "--word_lm_train_config",
        type=str,
        help="Word LM training configuration",
    )
    group.add_argument(
        "--word_lm_file",
        type=str,
        help="Word LM parameter file",
    )
    group.add_argument(
        "--ngram_file",
        type=str,
        help="N-gram parameter file",
    )
    group.add_argument(
        "--model_tag",
        type=str,
        help="Pretrained model tag. If specify this option, *_train_config and "
             "*_file will be overwritten",
    )
    group.add_argument(
        "--beam_search_config",
        # NOTE(review): no type= is given, so a CLI-supplied value stays a raw
        # string; only the dict default is usable as-is.
        default={},
        help="The keyword arguments for transducer beam search.",
    )

    group = parser.add_argument_group("Beam-search related")
    group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="The batch size for inference",
    )
    group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
    group.add_argument(
        "--maxlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain max output length. "
             "If maxlenratio=0.0 (default), it uses a end-detect "
             "function "
             "to automatically find maximum hypothesis lengths."
             "If maxlenratio<0.0, its absolute value is interpreted"
             "as a constant max output length",
    )
    group.add_argument(
        "--minlenratio",
        type=float,
        default=0.0,
        help="Input length ratio to obtain min output length",
    )
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.0,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
    group.add_argument("--streaming", type=str2bool, default=False)
    group.add_argument("--fake_streaming", type=str2bool, default=False)
    group.add_argument("--full_utt", type=str2bool, default=False)
    group.add_argument("--chunk_size", type=int, default=16)
    group.add_argument("--left_context", type=int, default=16)
    group.add_argument("--right_context", type=int, default=0)
    group.add_argument(
        "--display_partial_hypotheses",
        # str2bool instead of bool: bool("False") is True under argparse.
        type=str2bool,
        default=False,
        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
    )

    group = parser.add_argument_group("Dynamic quantization related")
    group.add_argument(
        "--quantize_asr_model",
        # str2bool instead of bool: bool("False") is True under argparse.
        type=str2bool,
        default=False,
        help="Apply dynamic quantization to ASR model.",
    )
    group.add_argument(
        "--quantize_modules",
        nargs="*",
        default=None,
        help="""Module names to apply dynamic quantization on.
        The module names are provided as a list, where each name is separated
        by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
        Each specified name should be an attribute of 'torch.nn', e.g.:
        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
    )
    group.add_argument(
        "--quantize_dtype",
        type=str,
        default="qint8",
        choices=["float16", "qint8"],
        help="Dtype for dynamic quantization.",
    )

    group = parser.add_argument_group("Text converter related")
    group.add_argument(
        "--token_type",
        type=str_or_none,
        default=None,
        choices=["char", "bpe", None],
        help="The token type for ASR model. "
             "If not given, refers from the training args",
    )
    group.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The model path of sentencepiece. "
             "If not given, refers from the training args",
    )
    group.add_argument("--token_num_relax", type=int, default=1, help="")
    group.add_argument("--decoding_ind", type=int, default=0, help="")
    group.add_argument("--decoding_mode", type=str, default="model1", help="")
    group.add_argument(
        "--ctc_weight2",
        type=float,
        default=0.0,
        help="CTC weight in joint decoding",
    )
    return parser
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    from funasr.bin.argument import get_parser
    parser = get_parser()
    parser.add_argument(
        "--mode",
funasr/bin/inference_cli.py
New file
@@ -0,0 +1,139 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import os
import logging
import torch
import numpy as np
from funasr.utils.download_and_prepare_model import prepare_model
from funasr.utils.types import str2bool
def infer(task_name: str = "asr",
          model: str = None,
          # mode: str = None,
          vad_model: str = None,
          disable_vad: bool = False,
          punc_model: str = None,
          disable_punc: bool = False,
          model_hub: str = "ms",
          cache_dir: str = None,
          **kwargs,
          ):
    """Build and return an inference function for the given task.

    Args:
        task_name: Only "asr" is currently implemented.
        model / vad_model / punc_model: Model directories or hub aliases.
        disable_vad / disable_punc: Drop the VAD / punctuation stage even when
            a model name is supplied (the CLI always passes default names).
        model_hub: "ms" (modelscope) or "hf" (huggingface).
        cache_dir: Download cache directory.
        **kwargs: Forwarded to ``prepare_model`` and the pipeline.

    Returns:
        A callable ``fn(input, **kwargs)`` accepting a wav path / wav.scp
        entry, a numpy array or a torch tensor.

    Raises:
        NotImplementedError: For any ``task_name`` other than "asr".
    """
    # set logging messages
    logging.basicConfig(
        level=logging.ERROR,
    )
    # The disable flags were previously accepted but never consulted;
    # honor them by dropping the corresponding model before resolution.
    if disable_vad:
        vad_model = None
    if disable_punc:
        punc_model = None
    model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
    if task_name == "asr":
        from funasr.bin.asr_inference_launch import inference_launch
        inference_pipeline = inference_launch(**kwargs)
    else:
        # The previous placeholder branches (three identical `elif task_name == ""`)
        # never built a pipeline, so unsupported tasks crashed later with a
        # NameError on `inference_pipeline` — fail fast and clearly instead.
        raise NotImplementedError("task_name must be 'asr', but got {}".format(task_name))

    def _infer_fn(input, **kwargs):
        """Run the pipeline on a path/scp entry, numpy array or torch tensor."""
        data_type = kwargs.get('data_type', 'sound')
        data_path_and_name_and_type = [input, 'speech', data_type]
        raw_inputs = None
        if isinstance(input, torch.Tensor):
            # detach/cpu so CUDA or grad-tracking tensors convert safely
            input = input.detach().cpu().numpy()
        if isinstance(input, np.ndarray):
            data_path_and_name_and_type = None
            raw_inputs = input
        return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
    return _infer_fn
def main(cmd=None):
    """Console entry point (`funasr` command, see setup.py console_scripts).

    Parses the shared decoding arguments plus CLI-only options, builds the
    inference function via `infer`, and prints the transcription of the
    positional `input` (wav file or Kaldi-style wav.scp list).

    Args:
        cmd: Optional argv list for testing; defaults to sys.argv[1:].
    """
    # print(get_commandline_args(), file=sys.stderr)
    from funasr.bin.argument import get_parser
    parser = get_parser()
    # Positional audio input: single wav file or a wav.scp file list.
    parser.add_argument('input', help='input file to transcribe')
    parser.add_argument(
        "--task_name",
        type=str,
        default="asr",
        help="The decoding mode",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="paraformer-zh",
        help="The asr mode name",
    )
    parser.add_argument(
        "-v",
        "--vad_model",
        type=str,
        default="fsmn-vad",
        help="vad model name",
    )
    parser.add_argument(
        "-dv",
        "--disable_vad",
        type=str2bool,
        default=False,
        help="",
    )
    parser.add_argument(
        "-p",
        "--punc_model",
        type=str,
        default="ct-punc",
        help="",
    )
    parser.add_argument(
        "-dp",
        "--disable_punc",
        type=str2bool,
        default=False,
        help="",
    )
    parser.add_argument(
        "--batch_size_token",
        type=int,
        default=5000,
        help="",
    )
    parser.add_argument(
        "--batch_size_token_threshold_s",
        type=int,
        default=35,
        help="",
    )
    parser.add_argument(
        "--max_single_segment_time",
        type=int,
        default=5000,
        help="",
    )
    args = parser.parse_args(cmd)
    # The same kwargs dict is fed both to `infer` (model setup) and to the
    # returned callable (per-call options, including `input`).
    kwargs = vars(args)
    # set logging messages
    logging.basicConfig(
        level=logging.ERROR,
    )
    logging.info("Decoding args: {}".format(kwargs))

    # kwargs["ncpu"] = 2 #os.cpu_count()
    # Dropped because the CLI supplies the audio via the positional `input`,
    # not the --data_path_and_name_and_type triple used by the batch script.
    kwargs.pop("data_path_and_name_and_type")
    print("args: {}".format(kwargs))
    p = infer(**kwargs)
    res = p(**kwargs)
    print(res)
funasr/utils/download_and_prepare_model.py
New file
@@ -0,0 +1,93 @@
import os
from pathlib import Path
import logging
# Mapping from the short model aliases accepted by `prepare_model` to the
# full modelscope model ids (all under the "damo" namespace).
name_maps_ms = {
    "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
}
def prepare_model(
    model: str = None,
    # mode: str = None,
    vad_model: str = None,
    punc_model: str = None,
    model_hub: str = "ms",
    cache_dir: str = None,
    **kwargs,
):
    """Resolve model names/aliases to local directories and fill inference kwargs.

    When ``model`` is not an existing local path, the short alias (e.g.
    ``paraformer-zh``) is mapped through ``name_maps_ms`` and downloaded from
    the configured hub.  In all cases ``kwargs`` is updated with the ASR/VAD/
    PUNC file paths and the decoding ``mode`` (read from the model directory's
    ``configuration.json`` when not supplied).

    Args:
        model: Local model directory, or model alias/name on the hub.
        vad_model: Optional VAD model directory or alias.
        punc_model: Optional punctuation model directory or alias.
        model_hub: "ms"/"modelscope" or "hf"/"huggingface".
        cache_dir: Download cache directory.
        **kwargs: Extra options; updated in place with resolved file paths.

    Returns:
        Tuple of ``(model_dir, vad_model_dir, punc_model_dir, kwargs)``.

    Raises:
        ImportError: If modelscope is required but not installed.
        NotImplementedError: For the huggingface hub (not wired up yet).
        ValueError: For an unknown ``model_hub`` or unresolvable model name.
    """
    if not Path(model).exists():
        if model_hub in ("ms", "modelscope"):
            try:
                from modelscope.hub.snapshot_download import snapshot_download as download_tool
                from modelscope.utils.logger import get_logger
            except ImportError as e:
                # NOTE: `raise "<str>"` (as before) is itself a TypeError in
                # Python 3 — raise a real exception.  The logger import is also
                # guarded now; previously it ran before the try and crashed
                # without the install hint.
                raise ImportError(
                    "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n"
                    "\npip3 install -U modelscope\n"
                    "For the users in China, you could install with the command:\n"
                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
                ) from e
            # silence modelscope's own logging during download
            logger = get_logger(log_level=logging.CRITICAL)
            logger.setLevel(logging.CRITICAL)
            # Map short aliases to full modelscope ids; unknown names pass
            # through unchanged and fail below with a clear error.
            model = name_maps_ms.get(model, model) if model is not None else None
            vad_model = name_maps_ms.get(vad_model, vad_model) if vad_model is not None else None
            punc_model = name_maps_ms.get(punc_model, punc_model) if punc_model is not None else None
            try:
                model = download_tool(model, cache_dir=cache_dir, revision=kwargs.get("revision", None))
                print("asr model have been downloaded to: {}".format(model))
            except Exception as e:
                raise ValueError(
                    "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model)
                ) from e
        elif model_hub in ("hf", "huggingface"):
            # The original code set download_tool = 0 and later called it.
            raise NotImplementedError("model_hub 'hf' is not supported yet")
        else:
            raise ValueError("model_hub must be one of ms or hf, but got {}".format(model_hub))
        if vad_model is not None and not Path(vad_model).exists():
            vad_model = download_tool(vad_model, cache_dir=cache_dir)
            print("vad_model have been downloaded to: {}".format(vad_model))
        if punc_model is not None and not Path(punc_model).exists():
            punc_model = download_tool(punc_model, cache_dir=cache_dir)
            print("punc_model have been downloaded to: {}".format(punc_model))
    # asr: these paths are filled even when `model` was already a local dir
    # (previously this block sat inside the `if not exists` branch, so local
    # paths made the function return None and broke the caller's unpacking).
    kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"),
                   "asr_model_file": None if model is None else os.path.join(model, "model.pb"),
                   "asr_train_config": None if model is None else os.path.join(model, "config.yaml"),
                   })
    mode = kwargs.get("mode", None)
    if mode is None:
        # Fall back to the decoding mode recorded in the model package.
        import json
        json_file = os.path.join(model, 'configuration.json')
        with open(json_file, 'r') as f:
            config_data = json.load(f)
            if config_data['task'] == "punctuation":
                mode = config_data['model']['punc_model_config']['mode']
            else:
                mode = config_data['model']['model_config']['mode']
    # Attaching a VAD model requires a vad-aware pipeline mode.
    if vad_model is not None and "vad" not in mode:
        mode = "paraformer_vad"
    kwargs["mode"] = mode
    # vad
    kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"),
                   "vad_model_file": None if vad_model is None else os.path.join(vad_model, "vad.pb"),
                   "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"),
                   })
    # punc
    kwargs.update({
        "punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"),
        "punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"),
    })
    return model, vad_model, punc_model, kwargs
funasr/version.txt
@@ -1 +1 @@
0.8.4
0.8.5
setup.py
@@ -129,4 +129,7 @@
        "License :: OSI Approved :: Apache Software License",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    entry_points={"console_scripts": [
        "funasr = funasr.bin.inference_cli:main",
    ]},
)