zhifu gao
2023-04-26 97ed4fada4770a11ee803e72020d793fcc5251c0
Merge pull request #423 from alibaba-damo-academy/dev_aky

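Update the error calculator for RNN-T: pass the encoder output lengths so beam search decodes only the first `encoder_out_lens[b]` frames of each utterance.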
2个文件已修改
16 ■■■■■ 已修改文件
funasr/models/e2e_asr_transducer.py 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/modules/e2e_asr_common.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_transducer.py
@@ -386,7 +386,7 @@
        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                from espnet2.asr_transducer.error_calculator import ErrorCalculator
                from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
                self.error_calculator = ErrorCalculator(
                    self.decoder,
@@ -398,7 +398,7 @@
                    report_wer=self.report_wer,
                )
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
            return loss_transducer, cer_transducer, wer_transducer
@@ -889,6 +889,8 @@
        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
@@ -899,7 +901,7 @@
                    report_wer=self.report_wer,
                )
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
            return loss_transducer, cer_transducer, wer_transducer
        return loss_transducer, None, None
funasr/modules/e2e_asr_common.py
@@ -296,12 +296,13 @@
        self.report_wer = report_wer
    def __call__(
        self, encoder_out: torch.Tensor, target: torch.Tensor
        self, encoder_out: torch.Tensor, target: torch.Tensor, encoder_out_lens: torch.Tensor,
    ) -> Tuple[Optional[float], Optional[float]]:
        """Calculate sentence-level WER or/and CER score for Transducer model.
        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            encoder_out_lens: Lengths of the encoder output sequences. (B,)
        Returns:
            : Sentence-level CER score.
            : Sentence-level WER score.
@@ -312,7 +313,10 @@
        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
        batch_nbest = [
            self.beam_search(encoder_out[b][: encoder_out_lens[b]])
            for b in range(batchsize)
        ]
        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
        char_pred, char_target = self.convert_to_char(pred, target)