huangmingming
2023-01-30 adcee8828ef5d78b575043954deb662a35e318f7
funasr/bin/asr_inference_launch.py
old mode 100755 new mode 100644
@@ -6,6 +6,7 @@
import logging
import os
import sys
from typing import Union, Dict, Any
from funasr.utils import config_argparse
from funasr.utils.cli_utils import get_commandline_args
@@ -74,6 +75,21 @@
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
        type=str,
        help="VAD infer configuration",
    )
    group.add_argument(
        "--vad_model_file",
        type=str,
        help="VAD model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global CMVN file",
    )
    group.add_argument(
        "--asr_train_config",
        type=str,
@@ -146,7 +162,7 @@
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.5,
        default=0.0,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
@@ -181,6 +197,33 @@
    return parser
def inference_launch(**kwargs):
    """Dispatch a ModelScope-style inference call to the backend for ``mode``.

    The backend module is imported lazily, so only the selected model
    family's module is loaded.

    Args:
        **kwargs: Keyword arguments forwarded unchanged (``mode`` included)
            to the selected backend's ``inference_modelscope``. ``mode`` must
            be one of "asr", "uniasr", "paraformer", "paraformer_vad_punc"
            or "vad".

    Returns:
        Whatever the selected ``inference_modelscope`` returns, or ``None``
        when ``mode`` is missing or not recognised (an error is logged).
    """
    import importlib

    # mode -> module providing ``inference_modelscope`` for that mode.
    backends = {
        "asr": "funasr.bin.asr_inference",
        "uniasr": "funasr.bin.asr_inference_uniasr",
        "paraformer": "funasr.bin.asr_inference_paraformer",
        "paraformer_vad_punc": "funasr.bin.asr_inference_paraformer_vad_punc",
        "vad": "funasr.bin.vad_inference",
    }

    if "mode" not in kwargs:
        # A misconfiguration, not progress info: log at ERROR so it is visible
        # under the default WARNING level instead of silently doing nothing.
        logging.error("Unknown decoding mode.")
        return None

    mode = kwargs["mode"]
    # Only string modes can match a key; this also avoids a TypeError from
    # dict lookup when an unhashable value is passed.
    module_name = backends.get(mode) if isinstance(mode, str) else None
    if module_name is None:
        logging.error("Unknown decoding mode: {}".format(mode))
        return None

    backend = importlib.import_module(module_name)
    return backend.inference_modelscope(**kwargs)
def main(cmd=None):
    print(get_commandline_args(), file=sys.stderr)
    parser = get_parser()
@@ -208,17 +251,7 @@
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
    if args.mode == "asr":
        from funasr.bin.asr_inference import inference
        inference(**kwargs)
    elif args.mode == "uniasr":
        from funasr.bin.asr_inference_uniasr import inference
        inference(**kwargs)
    elif args.mode == "paraformer":
        from funasr.bin.asr_inference_paraformer import inference
        inference(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(args.mode))
    inference_launch(**kwargs)
if __name__ == "__main__":