| | |
| | | |
| | | if not self.training and (self.report_cer or self.report_wer): |
| | | if self.error_calculator is None: |
| | | from espnet2.asr_transducer.error_calculator import ErrorCalculator |
| | | from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator |
| | | |
| | | self.error_calculator = ErrorCalculator( |
| | | self.decoder, |
| | |
| | | report_wer=self.report_wer, |
| | | ) |
| | | |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len) |
| | | |
| | | return loss_transducer, cer_transducer, wer_transducer |
| | | |
| | |
| | | sym_blank: str = "<blank>", |
| | | report_cer: bool = True, |
| | | report_wer: bool = True, |
| | | sym_sos: str = "<sos/eos>", |
| | | sym_eos: str = "<sos/eos>", |
| | | sym_sos: str = "<s>", |
| | | sym_eos: str = "</s>", |
| | | extract_feats_in_collect_stats: bool = True, |
| | | lsm_weight: float = 0.0, |
| | | length_normalized_loss: bool = False, |
| | |
| | | |
| | | if not self.training and (self.report_cer or self.report_wer): |
| | | if self.error_calculator is None: |
| | | from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator |
| | | |
| | | self.error_calculator = ErrorCalculator( |
| | | self.decoder, |
| | | self.joint_network, |
| | |
| | | report_wer=self.report_wer, |
| | | ) |
| | | |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len) |
| | | return loss_transducer, cer_transducer, wer_transducer |
| | | |
| | | return loss_transducer, None, None |