From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 16:49:56 +0800
Subject: [PATCH] rnnt reorg

---
 funasr/modules/e2e_asr_common.py |  151 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 151 insertions(+), 0 deletions(-)

diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py
index 92f9079..9b5039c 100644
--- a/funasr/modules/e2e_asr_common.py
+++ b/funasr/modules/e2e_asr_common.py
@@ -6,6 +6,8 @@
 
 """Common functions for ASR."""
 
+from typing import List, Optional, Tuple
+
 import json
 import logging
 import sys
@@ -13,7 +15,11 @@
 from itertools import groupby
 import numpy as np
 import six
+import torch
 
+from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.joint_network import JointNetwork
 
 def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
     """End detection.
@@ -247,3 +253,148 @@
             word_eds.append(editdistance.eval(hyp_words, ref_words))
             word_ref_lens.append(len(ref_words))
         return float(sum(word_eds)) / sum(word_ref_lens)
+
+class ErrorCalculatorTransducer:
+    """Calculate CER and WER for transducer models.
+    Args:
+        decoder: Decoder module.
+        joint_network: Joint Network module.
+        token_list: List of token units.
+        sym_space: Space symbol.
+        sym_blank: Blank symbol.
+        report_cer: Whether to compute CER.
+        report_wer: Whether to compute WER.
+    """
+
+    def __init__(
+        self,
+        decoder: AbsDecoder,
+        joint_network: JointNetwork,
+        token_list: List[int],
+        sym_space: str,
+        sym_blank: str,
+        report_cer: bool = False,
+        report_wer: bool = False,
+    ) -> None:
+        """Construct an ErrorCalculatorTransducer object."""
+        super().__init__()
+
+        self.beam_search = BeamSearchTransducer(
+            decoder=decoder,
+            joint_network=joint_network,
+            beam_size=1,
+            search_type="default",
+            score_norm=False,
+        )
+
+        self.decoder = decoder
+
+        self.token_list = token_list
+        self.space = sym_space
+        self.blank = sym_blank
+
+        self.report_cer = report_cer
+        self.report_wer = report_wer
+
+    def __call__(
+        self, encoder_out: torch.Tensor, target: torch.Tensor
+    ) -> Tuple[Optional[float], Optional[float]]:
+        """Calculate sentence-level WER or/and CER score for Transducer model.
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            target: Target label ID sequences. (B, L)
+        Returns:
+            : Sentence-level CER score.
+            : Sentence-level WER score.
+        """
+        cer, wer = None, None
+
+        batchsize = int(encoder_out.size(0))
+
+        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
+
+        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
+        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
+
+        char_pred, char_target = self.convert_to_char(pred, target)
+
+        if self.report_cer:
+            cer = self.calculate_cer(char_pred, char_target)
+
+        if self.report_wer:
+            wer = self.calculate_wer(char_pred, char_target)
+
+        return cer, wer
+
+    def convert_to_char(
+        self, pred: torch.Tensor, target: torch.Tensor
+    ) -> Tuple[List, List]:
+        """Convert label ID sequences to character sequences.
+        Args:
+            pred: Prediction label ID sequences. (B, U)
+            target: Target label ID sequences. (B, L)
+        Returns:
+            char_pred: Prediction character sequences. (B, ?)
+            char_target: Target character sequences. (B, ?)
+        """
+        char_pred, char_target = [], []
+
+        for i, pred_i in enumerate(pred):
+            char_pred_i = [self.token_list[int(h)] for h in pred_i]
+            char_target_i = [self.token_list[int(r)] for r in target[i]]
+
+            char_pred_i = "".join(char_pred_i).replace(self.space, " ")
+            char_pred_i = char_pred_i.replace(self.blank, "")
+
+            char_target_i = "".join(char_target_i).replace(self.space, " ")
+            char_target_i = char_target_i.replace(self.blank, "")
+
+            char_pred.append(char_pred_i)
+            char_target.append(char_target_i)
+
+        return char_pred, char_target
+
+    def calculate_cer(
+        self, char_pred: torch.Tensor, char_target: torch.Tensor
+    ) -> float:
+        """Calculate sentence-level CER score.
+        Args:
+            char_pred: Prediction character sequences. (B, ?)
+            char_target: Target character sequences. (B, ?)
+        Returns:
+            : Average sentence-level CER score.
+        """
+        import editdistance
+
+        distances, lens = [], []
+
+        for i, char_pred_i in enumerate(char_pred):
+            pred = char_pred_i.replace(" ", "")
+            target = char_target[i].replace(" ", "")
+            distances.append(editdistance.eval(pred, target))
+            lens.append(len(target))
+
+        return float(sum(distances)) / sum(lens)
+
+    def calculate_wer(
+        self, char_pred: torch.Tensor, char_target: torch.Tensor
+    ) -> float:
+        """Calculate sentence-level WER score.
+        Args:
+            char_pred: Prediction character sequences. (B, ?)
+            char_target: Target character sequences. (B, ?)
+        Returns:
+            : Average sentence-level WER score
+        """
+        import editdistance
+
+        distances, lens = [], []
+
+        for i, char_pred_i in enumerate(char_pred):
+            pred = char_pred_i.replace("鈻�", " ").split()
+            target = char_target[i].replace("鈻�", " ").split()
+
+            distances.append(editdistance.eval(pred, target))
+            lens.append(len(target))
+
+        return float(sum(distances)) / sum(lens)

--
Gitblit v1.9.1