yhliang
2023-06-16 e8528b8f6208cee52ed9c02ecfa9185f84706502
funasr/bin/asr_infer.py
@@ -1651,15 +1651,17 @@
        assert check_argument_types()
        
        # 1. Build ASR model
        from funasr.tasks.sa_asr import ASRTask
        from funasr.tasks.asr import ASRTaskSAASR
        scorers = {}
        asr_model, asr_train_args = ASRTask.build_model_from_file(
        asr_model, asr_train_args = ASRTaskSAASR.build_model_from_file(
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            if asr_train_args.frontend == 'wav_frontend':
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
            from funasr.tasks.sa_asr import frontend_choices
            if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
            else:
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()