shixian.shi
2024-01-12 c3c78fc5e790d48b3a2f9da79199320c06108d38
funasr/models/transducer/model.py
@@ -17,7 +17,7 @@
import numpy as np
import time
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
# from funasr.models.ctc import CTC
# from funasr.models.decoder.abs_decoder import AbsDecoder
@@ -39,12 +39,12 @@
from funasr.models.model_class_factory import *
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Nothing to do if torch<1.6.0
    @contextmanager
    def autocast(enabled=True):
        yield
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
@@ -52,525 +52,526 @@
class Transducer(nn.Module):
   """ESPnet2ASRTransducerModel module definition."""
    """ESPnet2ASRTransducerModel module definition."""
    def __init__(
        self,
        frontend: Optional[str] = None,
        frontend_conf: Optional[Dict] = None,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        joint_network: str = None,
        joint_network_conf: Optional[Dict] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        # report_cer: bool = True,
        # report_wer: bool = True,
        # sym_space: str = "<space>",
        # sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ):
        super().__init__()
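        # Sub-modules are resolved by name through the class factories imported
        # from funasr.models.model_class_factory and built from their conf dicts.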
        if frontend is not None:
            frontend_class = frontend_classes.get_class(frontend)
            frontend = frontend_class(**frontend_conf)
        if specaug is not None:
            specaug_class = specaug_classes.get_class(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = normalize_classes.get_class(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = encoder_classes.get_class(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        decoder_class = decoder_classes.get_class(decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **decoder_conf,
        )
        decoder_output_size = decoder.output_size
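        # The joint network fuses each encoder frame with each decoder state and
        # projects the pair to vocabulary logits.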
        joint_network_class = joint_network_classes.get_class(joint_network)
        joint_network = joint_network_class(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **joint_network_conf,
        )
        self.criterion_transducer = None
        self.error_calculator = None
        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda
        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network
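        # Label-smoothing CE criterion; note it is not referenced by the
        # transducer/CTC/LM losses defined in this file.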
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(
                encoder_out, encoder_out_lens, chunk_outs=None
            )
        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )
        # 3. Decoder
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)
        # 4. Joint Network
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )
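        # joint_out: (B, T, U, vocab) logits over the time x label alignment lattice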
        # 5. Losses
        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )
        loss_ctc, loss_lm = 0.0, 0.0
        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )
        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)
        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_lm_loss_weight * loss_lm
        )
        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
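            # Keep feature augmentation/normalization in fp32 even when the
            # surrounding training step runs under automatic mixed precision.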
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if getattr(self.encoder, "interctc_use_conditioning", False):
            # self.ctc is not set in this class's __init__; fall back to None defensively.
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=getattr(self, "ctc", None)
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.
        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)
        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word Error Rate for Transducer.
      """
      if self.criterion_transducer is None:
         try:
            from warp_rnnt import rnnt_loss as RNNTLoss
            self.criterion_transducer = RNNTLoss
         except ImportError:
            logging.error(
               "warp-rnnt was not installed."
               "Please consult the installation documentation."
            )
            exit(1)
      log_probs = torch.log_softmax(joint_out, dim=-1)
      loss_transducer = self.criterion_transducer(
         log_probs,
         target,
         t_len,
         u_len,
         reduction="mean",
         blank=self.blank_id,
         fastemit_lambda=self.fastemit_lambda,
         gather=True,
      )
      if not self.training and (self.report_cer or self.report_wer):
         if self.error_calculator is None:
            from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator
            self.error_calculator = ErrorCalculator(
               self.decoder,
               self.joint_network,
               self.token_list,
               self.sym_space,
               self.sym_blank,
               report_cer=self.report_cer,
               report_wer=self.report_wer,
            )
         cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
         return loss_transducer, cer_transducer, wer_transducer
      return loss_transducer, None, None
   def _calc_ctc_loss(
      self,
      encoder_out: torch.Tensor,
      target: torch.Tensor,
      t_len: torch.Tensor,
      u_len: torch.Tensor,
   ) -> torch.Tensor:
      """Compute CTC loss.
        """
        if self.criterion_transducer is None:
            try:
                from warp_rnnt import rnnt_loss as RNNTLoss
                self.criterion_transducer = RNNTLoss
            except ImportError:
                logging.error(
                    "warp-rnnt was not installed. "
                    "Please consult the installation documentation."
                )
                exit(1)
        log_probs = torch.log_softmax(joint_out, dim=-1)
        loss_transducer = self.criterion_transducer(
            log_probs,
            target,
            t_len,
            u_len,
            reduction="mean",
            blank=self.blank_id,
            fastemit_lambda=self.fastemit_lambda,
            gather=True,
        )
        report_cer = getattr(self, "report_cer", False)
        report_wer = getattr(self, "report_wer", False)
        if not self.training and (report_cer or report_wer):
            if self.error_calculator is None:
                from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator
                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    getattr(self, "token_list", None),
                    getattr(self, "sym_space", "<space>"),
                    getattr(self, "sym_blank", "<blank>"),
                    report_cer=report_cer,
                    report_wer=report_wer,
                )
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
            return loss_transducer, cer_transducer, wer_transducer
        return loss_transducer, None, None
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.
        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)
        Return:
            loss_ctc: CTC loss value.
      """
      ctc_in = self.ctc_lin(
         torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
      )
      ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
      target_mask = target != 0
      ctc_target = target[target_mask].cpu()
      with torch.backends.cudnn.flags(deterministic=True):
         loss_ctc = torch.nn.functional.ctc_loss(
            ctc_in,
            ctc_target,
            t_len,
            u_len,
            zero_infinity=True,
            reduction="sum",
         )
      loss_ctc /= target.size(0)
      return loss_ctc
   def _calc_lm_loss(
      self,
      decoder_out: torch.Tensor,
      target: torch.Tensor,
   ) -> torch.Tensor:
      """Compute LM loss.
        """
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
        target_mask = target != 0
        ctc_target = target[target_mask].cpu()
        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
                ctc_target,
                t_len,
                u_len,
                zero_infinity=True,
                reduction="sum",
            )
        loss_ctc /= target.size(0)
        return loss_ctc
    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute LM loss.
        Args:
            decoder_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)
        Return:
            loss_lm: LM loss value.
      """
      lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
      lm_target = target.view(-1).type(torch.int64)
      with torch.no_grad():
         true_dist = lm_loss_in.clone()
         true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
         # Ignore blank ID (0)
         ignore = lm_target == 0
         lm_target = lm_target.masked_fill(ignore, 0)
         true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
      loss_lm = torch.nn.functional.kl_div(
         torch.log_softmax(lm_loss_in, dim=1),
         true_dist,
         reduction="none",
      )
      loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
         0
      )
      return loss_lm
   def init_beam_search(self,
                        **kwargs,
                        ):
      from funasr.models.transformer.search import BeamSearch
      from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
      from funasr.models.transformer.scorers.length_bonus import LengthBonus
      # 1. Build ASR model
      scorers = {}
      if self.ctc != None:
         ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
         scorers.update(
            ctc=ctc
         )
      token_list = kwargs.get("token_list")
      scorers.update(
         length_bonus=LengthBonus(len(token_list)),
      )
        """
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
        lm_target = target.view(-1).type(torch.int64)
        with torch.no_grad():
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
            # Ignore blank ID (0)
            ignore = lm_target == 0
            lm_target = lm_target.masked_fill(ignore, 0)
            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
            reduction="none",
        )
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
            0
        )
        return loss_lm
    def init_beam_search(self, **kwargs):
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus
        # 1. Build ASR model
        scorers = {}
        if getattr(self, "ctc", None) is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 2),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if kwargs.get("decoding_ctc_weight", 0.0) == 1.0 else "full",
        )
        # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        # for scorer in scorers.values():
        #     if isinstance(scorer, torch.nn.Module):
        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        self.beam_search = beam_search
    def generate(
        self,
        data_in: list,
        data_lengths: list = None,
        key: list = None,
        tokenizer=None,
        **kwargs,
    ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        # init beamsearch
        is_use_ctc = (
            kwargs.get("decoding_ctc_weight", 0.0) > 0.00001
            and getattr(self, "ctc", None) is not None
        )
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        # c. Pass the encoder output to beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # drop sos/eos and blank ids (blank is assumed to be 0)
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
                results.append(result_i)
                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text
                    ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
        return results, meta_data
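# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# The registered names passed for encoder/decoder/joint_network below are
# placeholders: valid names depend on the class factories imported from
# funasr.models.model_class_factory, and the confs on the chosen classes.
#
#     model = Transducer(
#         encoder="<registered_encoder_name>",
#         encoder_conf={...},
#         decoder="<registered_decoder_name>",
#         decoder_conf={...},
#         joint_network="<registered_joint_network_name>",
#         joint_network_conf={...},
#         input_size=80,
#         vocab_size=4233,
#     )
#     speech = torch.randn(2, 200, 80)           # (batch, frames, feat_dim)
#     speech_lengths = torch.tensor([200, 180])
#     text = torch.randint(1, 4233, (2, 12))     # (batch, label_length)
#     text_lengths = torch.tensor([12, 10])
#     loss, stats, weight = model(speech, speech_lengths, text, text_lengths)
# ---------------------------------------------------------------------------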