Merge pull request #423 from alibaba-damo-academy/dev_aky
Update the error calculator for RNN-T
| | |
| | | |
| | | if not self.training and (self.report_cer or self.report_wer): |
| | | if self.error_calculator is None: |
| | | from espnet2.asr_transducer.error_calculator import ErrorCalculator |
| | | from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator |
| | | |
| | | self.error_calculator = ErrorCalculator( |
| | | self.decoder, |
| | |
| | | report_wer=self.report_wer, |
| | | ) |
| | | |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len) |
| | | |
| | | return loss_transducer, cer_transducer, wer_transducer |
| | | |
| | |
| | | |
| | | if not self.training and (self.report_cer or self.report_wer): |
| | | if self.error_calculator is None: |
| | | from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator |
| | | |
| | | self.error_calculator = ErrorCalculator( |
| | | self.decoder, |
| | | self.joint_network, |
| | |
| | | report_wer=self.report_wer, |
| | | ) |
| | | |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target) |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len) |
| | | return loss_transducer, cer_transducer, wer_transducer |
| | | |
| | | return loss_transducer, None, None |
| | |
| | | self.report_wer = report_wer |
| | | |
| | | def __call__( |
| | | self, encoder_out: torch.Tensor, target: torch.Tensor |
| | | self, encoder_out: torch.Tensor, target: torch.Tensor, encoder_out_lens: torch.Tensor, |
| | | ) -> Tuple[Optional[float], Optional[float]]: |
| | | """Calculate sentence-level WER and/or CER score for Transducer model. |
| | | Args: |
| | | encoder_out: Encoder output sequences. (B, T, D_enc) |
| | | target: Target label ID sequences. (B, L) |
| | | encoder_out_lens: Encoder output sequence lengths. (B,) |
| | | Returns: |
| | | : Sentence-level CER score. |
| | | : Sentence-level WER score. |
| | |
| | | |
| | | encoder_out = encoder_out.to(next(self.decoder.parameters()).device) |
| | | |
| | | batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)] |
| | | batch_nbest = [ |
| | | self.beam_search(encoder_out[b][: encoder_out_lens[b]]) |
| | | for b in range(batchsize) |
| | | ] |
| | | pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest] |
| | | |
| | | char_pred, char_target = self.convert_to_char(pred, target) |