游雁
2023-06-21 1af8a233ce99b6c6a8a119eaa7363ebae1f2570f
funasr/bin/asr_infer.py
@@ -316,7 +316,7 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, device
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            scorers["lm"] = lm.lm
@@ -636,7 +636,7 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, device
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            scorers["lm"] = lm.lm
@@ -1120,7 +1120,7 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, device
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            lm.to(device)
            scorers["lm"] = lm.lm
@@ -1343,7 +1343,7 @@
        if lm_train_config is not None:
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, device
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            lm_scorer = lm.lm
        else:
@@ -1636,8 +1636,10 @@
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            if asr_train_args.frontend == 'wav_frontend':
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
            from funasr.tasks.sa_asr import frontend_choices
            if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend":
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
            else:
                frontend_class = frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
@@ -1659,7 +1661,7 @@
        # 2. Build Language model
        if lm_train_config is not None:
            lm, lm_train_args = build_model_from_file(
                lm_train_config, lm_file, None, device
                lm_train_config, lm_file, None, device, task_name="lm"
            )
            scorers["lm"] = lm.lm