游雁
2023-02-14 1d4ab65c8bfebaecbcb0eec0064bae9a321cad75
funasr/bin/asr_inference_launch.py
old mode 100755 new mode 100644
@@ -76,6 +76,21 @@
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
        type=str,
        help="VAD infer configuration",
    )
    group.add_argument(
        "--vad_model_file",
        type=str,
        help="VAD model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global CMVN file",
    )
    group.add_argument(
        "--asr_train_config",
        type=str,
        help="ASR training configuration",
@@ -147,7 +162,7 @@
    group.add_argument(
        "--ctc_weight",
        type=float,
        default=0.5,
        default=0.0,
        help="CTC weight in joint decoding",
    )
    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
@@ -182,17 +197,50 @@
    return parser
def set_parameters(language: str = None,
                   sample_rate: Union[int, Dict[Any, int]] = None):
    """Override the module-level ASR language and/or sample rate.

    The two arguments are handled independently: a value of ``None`` (the
    default) leaves the corresponding module global exactly as it was,
    while any other value is stored as-is.

    Args:
        language: New value for the ``global_asr_language`` module global.
        sample_rate: New value for the ``global_sample_rate`` module global;
            annotated as either an int or a dict whose values are ints.
    """
    # A ``global`` declaration covers the whole function body, so both
    # names are declared once here instead of inside each branch.
    global global_asr_language, global_sample_rate

    if sample_rate is not None:
        global_sample_rate = sample_rate
    if language is not None:
        global_asr_language = language
def inference_launch(**kwargs):
    """Dispatch to the ``inference_modelscope`` backend named by ``mode``.

    The backend module is imported only after its mode has been matched,
    and it is called with every keyword argument received here (``mode``
    included).

    Args:
        **kwargs: Options forwarded unchanged to the selected backend.
            The ``mode`` entry selects the backend.

    Returns:
        The selected backend's return value, or ``None`` when ``mode`` is
        missing, is ``"paraformer_punc"``, or matches no known mode. In
        those cases an info-level message is logged.
    """
    if 'mode' not in kwargs:
        logging.info("Unknown decoding mode.")
        return None
    mode = kwargs['mode']

    if mode == "asr":
        from funasr.bin.asr_inference import inference_modelscope as run
        return run(**kwargs)

    if mode == "uniasr":
        from funasr.bin.asr_inference_uniasr import inference_modelscope as run
        return run(**kwargs)

    if mode == "uniasr_vad":
        from funasr.bin.asr_inference_uniasr_vad import inference_modelscope as run
        return run(**kwargs)

    if mode == "paraformer":
        from funasr.bin.asr_inference_paraformer import inference_modelscope as run
        return run(**kwargs)

    if mode == "paraformer_vad":
        from funasr.bin.asr_inference_paraformer_vad import inference_modelscope as run
        return run(**kwargs)

    if mode == "paraformer_punc":
        # No backend is wired up for this mode here; it is reported with
        # the same message as an unrecognised mode.
        logging.info("Unknown decoding mode: {}".format(mode))
        return None

    if mode == "paraformer_vad_punc":
        from funasr.bin.asr_inference_paraformer_vad_punc import inference_modelscope as run
        return run(**kwargs)

    if mode == "vad":
        from funasr.bin.vad_inference import inference_modelscope as run
        return run(**kwargs)

    if mode == "mfcca":
        from funasr.bin.asr_inference_mfcca import inference_modelscope as run
        return run(**kwargs)

    # Nothing matched: report the mode and give the caller None.
    logging.info("Unknown decoding mode: {}".format(mode))
    return None
def inference_launch(mode, **kwargs):
def inference_launch_funasr(**kwargs):
    if 'mode' in kwargs:
        mode = kwargs['mode']
    else:
        logging.info("Unknown decoding mode.")
        return None
    if mode == "asr":
        from funasr.bin.asr_inference import inference
        return inference(**kwargs)
@@ -202,6 +250,15 @@
    elif mode == "paraformer":
        from funasr.bin.asr_inference_paraformer import inference
        return inference(**kwargs)
    elif mode == "paraformer_vad_punc":
        from funasr.bin.asr_inference_paraformer_vad_punc import inference
        return inference(**kwargs)
    elif mode == "vad":
        from funasr.bin.vad_inference import inference
        return inference(**kwargs)
    elif mode == "mfcca":
        from funasr.bin.asr_inference_mfcca import inference_modelscope
        return inference_modelscope(**kwargs)
    else:
        logging.info("Unknown decoding mode: {}".format(mode))
        return None
@@ -234,7 +291,7 @@
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
    inference_launch(**kwargs)
    inference_launch_funasr(**kwargs)
if __name__ == "__main__":