| | |
| | | |
| | | if not self.training and (self.report_cer or self.report_wer): |
| | | if self.error_calculator is None: |
| | | from espnet2.asr_transducer.error_calculator import ErrorCalculator |
| | | from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator |
| | | |
| | | self.error_calculator = ErrorCalculator( |
| | | self.decoder, |
| | |
| | | report_wer=self.report_wer, |
| | | ) |
| | | |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len) |
| | | |
| | | return loss_transducer, cer_transducer, wer_transducer |
| | | |
| | |
| | | |
| | | if not self.training and (self.report_cer or self.report_wer): |
| | | if self.error_calculator is None: |
| | | from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator |
| | | |
| | | self.error_calculator = ErrorCalculator( |
| | | self.decoder, |
| | | self.joint_network, |
| | |
| | | report_wer=self.report_wer, |
| | | ) |
| | | |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len) |
| | | return loss_transducer, cer_transducer, wer_transducer |
| | | |
| | | return loss_transducer, None, None |