游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/modules/e2e_asr_common.py
@@ -18,7 +18,6 @@
import torch
from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
from funasr.models.joint_net.joint_network import JointNetwork
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
@@ -268,7 +267,7 @@
    def __init__(
        self,
        decoder: AbsDecoder,
        decoder,
        joint_network: JointNetwork,
        token_list: List[int],
        sym_space: str,
@@ -297,12 +296,13 @@
        self.report_wer = report_wer
    def __call__(
        self, encoder_out: torch.Tensor, target: torch.Tensor
        self, encoder_out: torch.Tensor, target: torch.Tensor, encoder_out_lens: torch.Tensor,
    ) -> Tuple[Optional[float], Optional[float]]:
        """Calculate sentence-level WER or/and CER score for Transducer model.
        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            encoder_out_lens: Encoder output sequences length. (B,)
        Returns:
            : Sentence-level CER score.
            : Sentence-level WER score.
@@ -313,7 +313,10 @@
        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
        batch_nbest = [
            self.beam_search(encoder_out[b][: encoder_out_lens[b]])
            for b in range(batchsize)
        ]
        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
        char_pred, char_target = self.convert_to_char(pred, target)