yhliang
2023-04-27 32d2b3ec153e53176da710ebcc0aba5669effd8a
funasr/models/e2e_asr_transducer.py
@@ -386,7 +386,7 @@
        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                from espnet2.asr_transducer.error_calculator import ErrorCalculator
                from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
                self.error_calculator = ErrorCalculator(
                    self.decoder,
@@ -398,7 +398,7 @@
                    report_wer=self.report_wer,
                )
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
            return loss_transducer, cer_transducer, wer_transducer
@@ -531,8 +531,8 @@
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        sym_sos: str = "<sos/eos>",
        sym_eos: str = "<sos/eos>",
        sym_sos: str = "<s>",
        sym_eos: str = "</s>",
        extract_feats_in_collect_stats: bool = True,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
@@ -889,6 +889,8 @@
        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
@@ -899,7 +901,7 @@
                    report_wer=self.report_wer,
                )
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
            return loss_transducer, cer_transducer, wer_transducer
        return loss_transducer, None, None