| | |
| | | else: |
| | | if pre_token_length[i] == 0: |
| | | yseq = torch.tensor( |
| | | [self.asr_model.sos] + [self.asr_model.eos], device=yseq.device |
| | | [self.asr_model.sos] + [self.asr_model.eos], device=pre_acoustic_embeds.device |
| | | ) |
| | | score = torch.tensor(0.0, device=yseq.device) |
| | | score = torch.tensor(0.0, device=pre_acoustic_embeds.device) |
| | | else: |
| | | yseq = am_scores.argmax(dim=-1) |
| | | score = am_scores.max(dim=-1)[0] |
| | |
| | | assert check_argument_types() |
| | | |
| | | # 1. Build ASR model |
| | | from funasr.tasks.sa_asr import ASRTask |
| | | from funasr.tasks.asr import ASRTaskSAASR |
| | | scorers = {} |
| | | asr_model, asr_train_args = ASRTask.build_model_from_file( |
| | | asr_model, asr_train_args = ASRTaskSAASR.build_model_from_file( |
| | | asr_train_config, asr_model_file, cmvn_file, device |
| | | ) |
| | | frontend = None |
| | | if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: |
| | | if asr_train_args.frontend == 'wav_frontend': |
| | | frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) |
| | | from funasr.tasks.sa_asr import frontend_choices |
| | | if asr_train_args.frontend == 'wav_frontend' or asr_train_args.frontend == "multichannelfrontend": |
| | | frontend_class = frontend_choices.get_class(asr_train_args.frontend) |
| | | frontend = frontend_class(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval() |
| | | else: |
| | | frontend_class = frontend_choices.get_class(asr_train_args.frontend) |
| | | frontend = frontend_class(**asr_train_args.frontend_conf).eval() |