游雁
2024-02-21 e772c7eb9e5439aaff2f599e79f0b3c8fdca22c2
Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
merge
1个文件已修改
3个文件已删除
1865 ■■■■ 已修改文件
funasr/models/bat/attention.py 238 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/bat/cif_predictor.py 220 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/bat/conformer_chunk_encoder.py 701 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/bat/model.py 706 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/bat/attention.py
File was deleted
funasr/models/bat/cif_predictor.py
File was deleted
funasr/models/bat/conformer_chunk_encoder.py
File was deleted
funasr/models/bat/model.py
@@ -3,137 +3,145 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import time
import torch
import logging
import torch.nn as nn
from contextlib import contextmanager
from typing import Dict, Optional, Tuple
from distutils.version import LooseVersion
from typing import Dict, List, Optional, Tuple, Union
from torch.cuda.amp import autocast
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.register import tables
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.scorers.length_bonus import LengthBonus
from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer
# Resolve ``autocast`` by feature detection rather than by parsing the torch
# version string: ``distutils`` (the source of ``LooseVersion``) is deprecated
# and no longer ships with Python 3.12+.
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Older torch builds (the original gate was torch<1.6.0) have no
    # ``torch.cuda.amp``; there is nothing to do, so use a no-op stand-in.
    @contextmanager
    def autocast(enabled=True):
        """No-op replacement for ``torch.cuda.amp.autocast``.

        Args:
            enabled: Accepted for signature compatibility and ignored.
        """
        yield
class BATModel(nn.Module):
    """BATModel module definition.
    Args:
        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
        token_list: List of token
        frontend: Frontend module.
        specaug: SpecAugment module.
        normalize: Normalization module.
        encoder: Encoder module.
        decoder: Decoder module.
        joint_network: Joint Network module.
        transducer_weight: Weight of the Transducer loss.
        fastemit_lambda: FastEmit lambda value.
        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
        ignore_id: Initial padding ID.
        sym_space: Space symbol.
        sym_blank: Blank Symbol
        report_cer: Whether to report Character Error Rate during validation.
        report_wer: Whether to report Word Error Rate during validation.
        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
    """
@tables.register("model_classes", "BAT")  # TODO: BAT training
class BAT(torch.nn.Module):
    def __init__(
        self,
        cif_weight: float = 1.0,
        frontend: Optional[str] = None,
        frontend_conf: Optional[Dict] = None,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        joint_network: str = None,
        joint_network_conf: Optional[Dict] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        extract_feats_in_collect_stats: bool = True,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        r_d: int = 5,
        r_u: int = 5,
        # report_cer: bool = True,
        # report_wer: bool = True,
        # sym_space: str = "<space>",
        # sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ) -> None:
        """Construct an BATModel object."""
    ):
        super().__init__()
        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
        self.blank_id = 0
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            **decoder_conf,
        )
        decoder_output_size = decoder.output_size
        joint_network_class = tables.joint_network_classes.get(joint_network)
        joint_network = joint_network_class(
            vocab_size,
            encoder_output_size,
            decoder_output_size,
            **joint_network_conf,
        )
        self.criterion_transducer = None
        self.error_calculator = None
        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda
        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.token_list = token_list.copy()
        self.sym_space = sym_space
        self.sym_blank = sym_blank
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network
        self.criterion_transducer = None
        self.error_calculator = None
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size(), vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda
        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
        self.report_cer = report_cer
        self.report_wer = report_wer
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.criterion_pre = torch.nn.L1Loss()
        self.predictor_weight = predictor_weight
        self.predictor = predictor
        self.cif_weight = cif_weight
        if self.cif_weight > 0:
            self.cif_output_layer = torch.nn.Linear(encoder.output_size(), vocab_size)
            self.criterion_cif = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )
        self.r_d = r_d
        self.r_u = r_u
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
        self.ctc = None
        self.ctc_weight = 0.0
    def forward(
        self,
        speech: torch.Tensor,
@@ -142,111 +150,167 @@
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Forward architecture and compute loss(es).
        """Encoder + Decoder + Calc loss
        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".
        Return:
            loss: Main loss value.
            stats: Task statistics.
            weight: Task weights.
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]
        text = text[:, : text_lengths.max()]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None:
            encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens,
                                                                                        chunk_outs=None)
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(encoder_out.device)
        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )
        # 3. Decoder
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=self.ignore_id)
        loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length), pre_token_length)
        if self.cif_weight > 0.0:
            cif_predict = self.cif_output_layer(pre_acoustic_embeds)
            loss_cif = self.criterion_cif(cif_predict, text)
        else:
            loss_cif = 0.0
        # 5. Losses
        boundary = torch.zeros((encoder_out.size(0), 4), dtype=torch.int64, device=encoder_out.device)
        boundary[:, 2] = u_len.long().detach()
        boundary[:, 3] = t_len.long().detach()
        pre_peak_index = torch.floor(pre_peak_index).long()
        s_begin = pre_peak_index - self.r_d
        T = encoder_out.size(1)
        B = encoder_out.size(0)
        U = decoder_out.size(1)
        mask = torch.arange(0, T, device=encoder_out.device).reshape(1, T).expand(B, T)
        mask = mask <= boundary[:, 3].reshape(B, 1) - 1
        s_begin_padding = boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1
        # handle the cases where `len(symbols) < s_range`
        s_begin_padding = torch.clamp(s_begin_padding, min=0)
        s_begin = torch.where(mask, s_begin, s_begin_padding)
        
        mask2 = s_begin <  boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1
        s_begin = torch.where(mask2, s_begin, boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1)
        s_begin = torch.clamp(s_begin, min=0)
        ranges = s_begin.reshape((B, T, 1)).expand((B, T, min(self.r_u+self.r_d, min(u_len)))) + torch.arange(min(self.r_d+self.r_u, min(u_len)), device=encoder_out.device)
        import fast_rnnt
        am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning(
            am=self.joint_network.lin_enc(encoder_out),
            lm=self.joint_network.lin_dec(decoder_out),
            ranges=ranges,
        # 4. Joint Network
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )
        logits = self.joint_network(am_pruned, lm_pruned, project_input=False)
        with torch.cuda.amp.autocast(enabled=False):
            loss_trans = fast_rnnt.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=target.long(),
                ranges=ranges,
                termination_symbol=self.blank_id,
                boundary=boundary,
                reduction="sum",
        # 5. Losses
        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )
        loss_ctc, loss_lm = 0.0, 0.0
        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )
        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)
        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_lm_loss_weight * loss_lm
        )
        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
        cer_trans, wer_trans = None, None
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.
        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)
        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word Error Rate for Transducer.
        """
        if self.criterion_transducer is None:
            try:
                from warp_rnnt import rnnt_loss as RNNTLoss
                self.criterion_transducer = RNNTLoss
            except ImportError:
                logging.error(
                    "warp-rnnt was not installed."
                    "Please consult the installation documentation."
                )
                exit(1)
        log_probs = torch.log_softmax(joint_out, dim=-1)
        loss_transducer = self.criterion_transducer(
            log_probs,
            target,
            t_len,
            u_len,
            reduction="mean",
            blank=self.blank_id,
            fastemit_lambda=self.fastemit_lambda,
            gather=True,
        )
        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator
                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
@@ -256,149 +320,13 @@
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )
            cer_trans, wer_trans = self.error_calculator(encoder_out, target, t_len)
        loss_ctc, loss_lm = 0.0, 0.0
        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )
        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)
        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_lm_loss_weight * loss_lm
            + self.predictor_weight * loss_pre
            + self.cif_weight * loss_cif
        )
        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            loss_pre=loss_pre.detach(),
            loss_cif=loss_cif.detach() if loss_cif > 0.0 else None,
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Collect features sequences and features lengths sequences.
        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".
        Return:
            {}: "feats": Features sequences. (B, T, D_feats),
                "feats_lengths": Features sequences lengths. (B,)
        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder speech sequences.
        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
        Return:
            encoder_out: Encoder outputs. (B, T, D_enc)
            encoder_out_lens: Encoder outputs lengths. (B,)
        """
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # 4. Forward encoder
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        return encoder_out, encoder_out_lens
    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features sequences and features sequences lengths.
        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
        Return:
            feats: Features sequences. (B, T, D_feats)
            feats_lengths: Features sequences lengths. (B,)
        """
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # for data-parallel
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths
            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len)
            return loss_transducer, cer_transducer, wer_transducer
        return loss_transducer, None, None
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
@@ -422,10 +350,10 @@
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
        target_mask = target != 0
        ctc_target = target[target_mask].cpu()
        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
@@ -436,9 +364,9 @@
                reduction="sum",
            )
        loss_ctc /= target.size(0)
        return loss_ctc
    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
@@ -456,17 +384,17 @@
        """
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
        lm_target = target.view(-1).type(torch.int64)
        with torch.no_grad():
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
            # Ignore blank ID (0)
            ignore = lm_target == 0
            lm_target = lm_target.masked_fill(ignore, 0)
            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
@@ -475,5 +403,117 @@
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
            0
        )
        return loss_lm
    def init_beam_search(self,
                         **kwargs,
                         ):
        # 1. Build ASR model
        scorers = {}
        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(
                ctc=ctc
            )
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        beam_search = BeamSearchTransducer(
            self.decoder,
            self.joint_network,
            kwargs.get("beam_size", 2),
            nbest=1,
        )
        # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        # for scorer in scorers.values():
        #     if isinstance(scorer, torch.nn.Module):
        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        self.beam_search = beam_search
    def inference(self,
                  data_in: list,
                  data_lengths: list=None,
                  key: list=None,
                  tokenizer=None,
                  **kwargs,
                  ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        # if self.beam_search is None and (is_use_lm or is_use_ctc):
        logging.info("enable beam_search")
        self.init_beam_search(**kwargs)
        self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        # c. Passed the encoder result and the beam search
        nbest_hyps = self.beam_search(encoder_out[0], is_final=True)
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq#[1:last_pos]
                else:
                    token_int = hyp.yseq#[1:last_pos].tolist()
                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
                results.append(result_i)
                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text
                    ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
        return results, meta_data