From bdb8a99da425fca4813952718a62ba02ef06fa6e Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 26 四月 2023 10:34:48 +0800
Subject: [PATCH] update error calculator for rnnt
---
funasr/modules/e2e_asr_common.py | 8 ++++++--
funasr/models/e2e_asr_transducer.py | 8 +++++---
2 files changed, 11 insertions(+), 5 deletions(-)
diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py
index 657dd75..f8ba0f0 100644
--- a/funasr/models/e2e_asr_transducer.py
+++ b/funasr/models/e2e_asr_transducer.py
@@ -386,7 +386,7 @@
if not self.training and (self.report_cer or self.report_wer):
if self.error_calculator is None:
- from espnet2.asr_transducer.error_calculator import ErrorCalculator
+ from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
self.error_calculator = ErrorCalculator(
self.decoder,
@@ -398,7 +398,7 @@
report_wer=self.report_wer,
)
- cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+ cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
return loss_transducer, cer_transducer, wer_transducer
@@ -889,6 +889,8 @@
if not self.training and (self.report_cer or self.report_wer):
if self.error_calculator is None:
+ from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
+
self.error_calculator = ErrorCalculator(
self.decoder,
self.joint_network,
@@ -899,7 +901,7 @@
report_wer=self.report_wer,
)
- cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+ cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
return loss_transducer, cer_transducer, wer_transducer
return loss_transducer, None, None
diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py
index f430fcb..98006f9 100644
--- a/funasr/modules/e2e_asr_common.py
+++ b/funasr/modules/e2e_asr_common.py
@@ -296,12 +296,13 @@
self.report_wer = report_wer
def __call__(
- self, encoder_out: torch.Tensor, target: torch.Tensor
+ self, encoder_out: torch.Tensor, target: torch.Tensor, encoder_out_lens: torch.Tensor,
) -> Tuple[Optional[float], Optional[float]]:
"""Calculate sentence-level WER or/and CER score for Transducer model.
Args:
encoder_out: Encoder output sequences. (B, T, D_enc)
target: Target label ID sequences. (B, L)
+ encoder_out_lens: Encoder output sequences length. (B,)
Returns:
: Sentence-level CER score.
: Sentence-level WER score.
@@ -312,7 +313,10 @@
encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
- batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
+ batch_nbest = [
+ self.beam_search(encoder_out[b][: encoder_out_lens[b]])
+ for b in range(batchsize)
+ ]
pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
char_pred, char_target = self.convert_to_char(pred, target)
--
Gitblit v1.9.1