zhifu gao
2024-01-24 e4035edb4620e483279923be3de9af9c35b9ce67
Funasr1.0 (#1297)

* fix add_file bug (#1296)

Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* funasr1.0 uniasr

* funasr1.0 uniasr

---------

Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
2 files modified, 3 files added, 1194 lines changed
examples/industrial_data_pretraining/uniasr/demo.py  | 29
examples/industrial_data_pretraining/uniasr/infer.sh | 11
funasr/models/transformer/model.py                   | 2
funasr/models/uniasr/beam_search.py                  | 496
funasr/models/uniasr/model.py                        | 656
examples/industrial_data_pretraining/uniasr/demo.py
New file
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from funasr import AutoModel
model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online", model_revision="v2.0.4",
                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  # vad_model_revision="v2.0.4",
                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  # punc_model_revision="v2.0.4",
                  )
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)
''' cannot be used currently
from funasr import AutoFrontend
frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
for batch_idx, fbank_dict in enumerate(fbanks):
    res = model.generate(**fbank_dict)
    print(res)
'''
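
The commented-out vad_model / punc_model arguments above show how a VAD front-end and a punctuation model can be chained around the recognizer. A minimal sketch of the same demo with those options enabled, assuming the hub model ID matches the local folder name used above:

from funasr import AutoModel

model = AutoModel(
    # hub ID inferred from the local folder name above; adjust if needed
    model="damo/speech_UniASR_asr_2pass-zh-cn-16k-common-vocab8358-tensorflow1-online",
    model_revision="v2.0.4",
    vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    vad_model_revision="v2.0.4",
    punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
    punc_model_revision="v2.0.4",
)
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav")
print(res)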
examples/industrial_data_pretraining/uniasr/infer.sh
New file
@@ -0,0 +1,11 @@
model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_revision="v2.0.4"
python funasr/bin/inference.py \
+model=${model} \
+model_revision=${model_revision} \
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
+output_dir="./outputs/debug" \
+device="cpu" \
funasr/models/transformer/model.py
@@ -348,7 +348,7 @@
        scorers["ngram"] = ngram
        
        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
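
The hunk above replaces the first `decoder=` line with the second. Without a default, `kwargs.get("decoding_ctc_weight")` returns None when the option is absent, and `1.0 - None` raises a TypeError; the `0.0` default keeps the weight computation well-defined. A minimal illustration:

kwargs = {}  # decoding_ctc_weight not supplied
# 1.0 - kwargs.get("decoding_ctc_weight")  # would raise TypeError: float - NoneType
decoder_weight = 1.0 - kwargs.get("decoding_ctc_weight", 0.0)
print(decoder_weight)  # 1.0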
funasr/models/uniasr/beam_search.py
New file
@@ -0,0 +1,496 @@
"""Beam search module."""
from itertools import chain
import logging
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Tuple
from typing import Union
import torch
from funasr.metrics.common import end_detect
from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
class Hypothesis(NamedTuple):
    """Hypothesis data type."""
    yseq: torch.Tensor
    score: Union[float, torch.Tensor] = 0
    scores: Dict[str, Union[float, torch.Tensor]] = dict()
    states: Dict[str, Any] = dict()
    def asdict(self) -> dict:
        """Convert data to JSON-friendly dict."""
        return self._replace(
            yseq=self.yseq.tolist(),
            score=float(self.score),
            scores={k: float(v) for k, v in self.scores.items()},
        )._asdict()
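
asdict() converts the tensor fields to plain Python types so hypotheses can be serialized; the end_detect call in forward() below consumes these dicts. A small sketch of the conversion:

import torch

h = Hypothesis(yseq=torch.tensor([1, 42, 2]), score=-3.5, scores={"decoder": -3.5})
print(h.asdict())
# {'yseq': [1, 42, 2], 'score': -3.5, 'scores': {'decoder': -3.5}, 'states': {}}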
class BeamSearchScama(torch.nn.Module):
    """Beam search implementation."""
    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos: int,
        eos: int,
        token_list: List[str] = None,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
    ):
        """Initialize beam search.
        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorer
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): Size of the vocabulary
            sos (int): Start of sequence id
            eos (int): End of sequence id
            token_list (list[str]): List of tokens for debug log
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`
        """
        super().__init__()
        # set scorers
        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()
        # this module dict is required for recursive cast
        # `self.to(device, dtype)` in `recog.py`
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            w = weights.get(k, 0)
            if w == 0 or v is None:
                continue
            assert isinstance(
                v, ScorerInterface
            ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v
        # set configurations
        self.sos = sos
        self.eos = eos
        self.token_list = token_list
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
        self.pre_beam_score_key = pre_beam_score_key
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )
    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        """Get an initial hypothesis data.
        Args:
            x (torch.Tensor): The encoder output feature
        Returns:
            List[Hypothesis]: The initial hypotheses.
        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0
        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
            )
        ]
    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Append new token to prefix tokens.
        Args:
            xs (torch.Tensor): The prefix token
            x (int): The new token to append
        Returns:
            torch.Tensor: New tensor containing xs + [x], with the dtype and device of xs
        """
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))
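
append_token keeps the growing prefix on the dtype and device of the existing sequence, for example:

import torch

xs = torch.tensor([1, 42])
print(BeamSearchScama.append_token(xs, 2))  # tensor([1, 42, 2])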
    def score_full(
        self, hyp: Hypothesis,
        x: torch.Tensor,
        x_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.
        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature
        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`
        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
        return scores, states
    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.
        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature
        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`
        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
        return scores, states
    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.
        Args:
            weighted_scores (torch.Tensor): The weighted sum of scores for each token.
                Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk
        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`
        """
        # no pre beam performed
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(self.beam_size)[1]
            return top_ids, top_ids
        # mask pruned in pre-beam not to select in topk
        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(self.beam_size)[1]
        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
        return top_ids, local_ids
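
The masking step in beam() excludes tokens pruned by the pre-beam: every score outside ids is overwritten with -inf, so the subsequent topk can only select surviving tokens. A toy re-enactment with beam_size=2:

import torch

weighted_scores = torch.tensor([0.1, 0.9, 0.3, 0.7])
ids = torch.tensor([1, 3])  # pre-beam survivors
tmp = weighted_scores[ids]
weighted_scores[:] = -float("inf")
weighted_scores[ids] = tmp
print(weighted_scores.topk(2)[1])  # tensor([1, 3])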
    @staticmethod
    def merge_scores(
        prev_scores: Dict[str, float],
        next_full_scores: Dict[str, torch.Tensor],
        full_idx: int,
        next_part_scores: Dict[str, torch.Tensor],
        part_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Merge scores for new hypothesis.
        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`
        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are scalar tensors by the scorers.
        """
        new_scores = dict()
        for k, v in next_full_scores.items():
            new_scores[k] = prev_scores[k] + v[full_idx]
        for k, v in next_part_scores.items():
            new_scores[k] = prev_scores[k] + v[part_idx]
        return new_scores
    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.
        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`
        Returns:
            Dict[str, Any]: The new state dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.
        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, d in self.part_scorers.items():
            new_states[k] = d.select_state(part_states[k], part_idx)
        return new_states
    def search(
        self, running_hyps: List[Hypothesis],
        x: torch.Tensor,
        x_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
    ) -> List[Hypothesis]:
        """Search new tokens for running hypotheses and encoded speech x.
        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)
        Returns:
            List[Hypothesis]: Best sorted hypotheses
        """
        best_hyps = []
        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
        for hyp in running_hyps:
            # scoring
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
            if self.do_pre_beam:
                pre_beam_scores = (
                    weighted_scores
                    if self.pre_beam_score_key == "full"
                    else scores[self.pre_beam_score_key]
                )
                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
            part_scores, part_states = self.score_partial(hyp, part_ids, x)
            for k in self.part_scorers:
                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
            # add previous hyp score
            weighted_scores += hyp.score
            # update hyps
            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
                # will be (2 x beam at most)
                best_hyps.append(
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(
                            hyp.scores, scores, j, part_scores, part_j
                        ),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )
            # sort and prune 2 x beam -> beam
            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
                : min(len(best_hyps), self.beam_size)
            ]
        return best_hyps
    def forward(
        self, x: torch.Tensor,
        scama_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        maxlen: int = None,
        minlen: int = 0,
    ) -> List[Hypothesis]:
        """Perform beam search.
        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses an end-detect function
                to automatically find maximum hypothesis lengths.
                If maxlenratio<0.0, its absolute value is interpreted
                as a constant max output length.
            minlenratio (float): Input length ratio to obtain min output length.
        Returns:
            list[Hypothesis]: N-best decoding results
        """
        if maxlen is None:
            # set length bounds
            if maxlenratio == 0:
                maxlen = x.shape[0]
            elif maxlenratio < 0:
                maxlen = -1 * int(maxlenratio)
            else:
                maxlen = max(1, int(maxlenratio * x.size(0)))
            minlen = int(minlenratio * x.size(0))
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))
        # main loop of prefix search
        running_hyps = self.init_hyp(x)
        ended_hyps = []
        for i in range(maxlen):
            logging.debug("position " + str(i))
            mask_enc = None
            if scama_mask is not None:
                token_num_predictor = scama_mask.size(1)
                token_id_slice = min(i, token_num_predictor-1)
                mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
                # if mask_enc.size(1) == 0:
                #     mask_enc = scama_mask[:, -2:-1, :]
                #     # mask_enc = torch.zeros_like(mask_enc)
            pre_acoustic_embeds_cur = None
            if pre_acoustic_embeds is not None:
                b, t, d = pre_acoustic_embeds.size()
                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
                pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
                token_id_slice = min(i, t)
                pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
            best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur)
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                logging.info(f"end detected at {i}")
                break
            if len(running_hyps) == 0:
                logging.info("no hypothesis. Finish decoding.")
                break
            else:
                logging.debug(f"remained hypotheses: {len(running_hyps)}")
        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            # retry with a smaller minlenratio; use keyword arguments so the
            # values are not misread as scama_mask / pre_acoustic_embeds
            return (
                []
                if minlenratio < 0.1
                else self.forward(
                    x, maxlenratio=maxlenratio, minlenratio=max(0.0, minlenratio - 0.1)
                )
            )
        # report the best result
        for hyp in nbest_hyps:
            yseq = "".join([self.token_list[x] for x in hyp.yseq])
            logging.debug("nbest: y: {}, yseq: {}, score: {}".format(hyp.yseq, yseq, hyp.score))
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps
    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: List[Hypothesis],
        ended_hyps: List[Hypothesis],
    ) -> List[Hypothesis]:
        """Perform post-processing of beam search iterations.
        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (float): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
        Returns:
            List[Hypothesis]: The new running hypotheses.
        """
        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
            )
        # add eos in the final loop to guarantee that some hypotheses end
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
            ]
        # add ended hypotheses to the final list, and remove them from the current hypotheses
        # (note: the number of running hyps can then drop below the beam size)
        remained_hyps = []
        for hyp in running_hyps:
            if hyp.yseq[-1] == self.eos:
                # e.g., Word LM needs to add final <eos> score
                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                    s = d.final_score(hyp.states[k])
                    hyp.scores[k] += s
                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        return remained_hyps
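
For reference, a hedged sketch of how this class gets wired up, mirroring init_beam_search() in funasr/models/uniasr/model.py below; decoder, ctc, token_list, sos and eos are assumed to come from a loaded UniASR model:

from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
from funasr.models.transformer.scorers.length_bonus import LengthBonus

# decoder, ctc, token_list, sos, eos: taken from a loaded UniASR model (assumed)
scorers = dict(
    decoder=decoder,
    ctc=CTCPrefixScorer(ctc=ctc, eos=eos),
    length_bonus=LengthBonus(len(token_list)),
)
weights = dict(decoder=0.7, ctc=0.3, lm=0.0, ngram=0.0, length_bonus=0.0)
beam_search = BeamSearchScama(
    beam_size=5,
    weights=weights,
    scorers=scorers,
    sos=sos,
    eos=eos,
    vocab_size=len(token_list),
    token_list=token_list,
)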
funasr/models/uniasr/model.py
@@ -14,14 +14,13 @@
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.scama.utils import sequence_mask
@tables.register("model_classes", "UniASR")
class UniASR(torch.nn.Module):
@@ -31,19 +30,37 @@
    def __init__(
        self,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        encoder_conf: dict = None,
        encoder2: str = None,
        encoder2_conf: dict = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        decoder_conf: dict = None,
        decoder2: str = None,
        decoder2_conf: dict = None,
        predictor: str = None,
        predictor_conf: Optional[Dict] = None,
        predictor_conf: dict = None,
        predictor_bias: int = 0,
        predictor_weight: float = 0.0,
        predictor2: str = None,
        predictor2_conf: dict = None,
        predictor2_bias: int = 0,
        predictor2_weight: float = 0.0,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        ctc2: str = None,
        ctc2_conf: dict = None,
        ctc2_weight: float = 0.5,
        decoder_attention_chunk_type: str = 'chunk',
        decoder_attention_chunk_type2: str = 'chunk',
        stride_conv=None,
        stride_conv_conf: dict = None,
        loss_weight_model1: float = 0.5,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
@@ -52,60 +69,72 @@
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        # report_cer: bool = True,
        # report_wer: bool = True,
        # sym_space: str = "<space>",
        # sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        # predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 0,
        sampling_ratio: float = 0.2,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        use_1st_decoder_loss: bool = False,
        encoder1_encoder2_joint_training: bool = True,
        **kwargs,
        
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
        super().__init__()
        self.blank_id = 0
        self.sos = 1
        self.eos = 2
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **decoder_conf,
        )
        predictor_class = tables.predictor_classes.get(predictor)
        predictor = predictor_class(**predictor_conf)
        from funasr.models.transformer.utils.subsampling import Conv1dSubsampling
        stride_conv = Conv1dSubsampling(**stride_conv_conf, idim=input_size + encoder_output_size,
                                        odim=input_size + encoder_output_size)
        stride_conv_output_size = stride_conv.output_size()
        encoder_class = tables.encoder_classes.get(encoder2)
        encoder2 = encoder_class(input_size=stride_conv_output_size, **encoder2_conf)
        encoder2_output_size = encoder2.output_size()
        decoder_class = tables.decoder_classes.get(decoder2)
        decoder2 = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder2_output_size,
            **decoder2_conf,
        )
        predictor_class = tables.predictor_classes.get(predictor2)
        predictor2 = predictor_class(**predictor2_conf)
        self.blank_id = blank_id
        self.sos = sos
        self.eos = eos
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.interctc_weight = interctc_weight
        self.token_list = token_list.copy()
        self.ctc2_weight = ctc2_weight
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.encoder = encoder
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.error_calculator = None
        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.decoder = decoder
        self.ctc = None
        self.ctc2 = None
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
@@ -113,22 +142,13 @@
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.step_cur = 0
        self.encoder1_encoder2_joint_training = kwargs.get("encoder1_encoder2_joint_training", True)
        if self.encoder.overlap_chunk_cls is not None:
            from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
            self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
@@ -136,14 +156,10 @@
        self.encoder2 = encoder2
        self.decoder2 = decoder2
        self.ctc_weight2 = ctc_weight2
        if ctc_weight2 == 0.0:
            self.ctc2 = None
        else:
            self.ctc2 = ctc2
        self.interctc_weight2 = interctc_weight2
        self.ctc2_weight = ctc2_weight
        self.predictor2 = predictor2
        self.predictor_weight2 = predictor_weight2
        self.predictor2_weight = predictor2_weight
        self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
        self.stride_conv = stride_conv
        self.loss_weight_model1 = loss_weight_model1
@@ -152,10 +168,10 @@
            self.build_scama_mask_for_cross_attention_decoder_fn2 = build_scama_mask_for_cross_attention_decoder
            self.decoder_attention_chunk_type2 = decoder_attention_chunk_type2
        self.enable_maas_finetune = enable_maas_finetune
        self.freeze_encoder2 = freeze_encoder2
        self.encoder1_encoder2_joint_training = encoder1_encoder2_joint_training
        self.length_normalized_loss = length_normalized_loss
        self.enable_maas_finetune = kwargs.get("enable_maas_finetune", False)
        self.freeze_encoder2 = kwargs.get("freeze_encoder2", False)
        self.beam_search = None
    def forward(
        self,
@@ -163,7 +179,7 @@
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        decoding_ind: int = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
@@ -172,19 +188,14 @@
                        text: (Batch, Length)
                        text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
        decoding_ind = kwargs.get("decoding_ind", None)
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
        speech = speech[:, :speech_lengths.max()]
        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
        # 1. Encoder
@@ -194,10 +205,6 @@
        else:
            speech_raw, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
@@ -210,62 +217,12 @@
            # 1. CTC branch
            if self.enable_maas_finetune:
                with torch.no_grad():
                    if self.ctc_weight != 0.0:
                        if self.encoder.overlap_chunk_cls is not None:
                            encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
                                                                                                                encoder_out_lens,
                                                                                                                chunk_outs=None)
                        loss_ctc, cer_ctc = self._calc_ctc_loss(
                            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                        )
                        # Collect CTC branch stats
                        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
                        stats["cer_ctc"] = cer_ctc
                    loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                        encoder_out, encoder_out_lens, text, text_lengths
                    )
                    # Intermediate CTC (optional)
                    loss_interctc = 0.0
                    if self.interctc_weight != 0.0 and intermediate_outs is not None:
                        for layer_idx, intermediate_out in intermediate_outs:
                            # we assume intermediate_out has the same length & padding
                            # as those of encoder_out
                            if self.encoder.overlap_chunk_cls is not None:
                                encoder_out_ctc, encoder_out_lens_ctc = \
                                    self.encoder.overlap_chunk_cls.remove_chunk(
                                        intermediate_out,
                                        encoder_out_lens,
                                        chunk_outs=None)
                            loss_ic, cer_ic = self._calc_ctc_loss(
                                encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                            )
                            loss_interctc = loss_interctc + loss_ic
                            # Collect Intermediate CTC stats
                            stats["loss_interctc_layer{}".format(layer_idx)] = (
                                loss_ic.detach() if loss_ic is not None else None
                            )
                            stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
                        loss_interctc = loss_interctc / len(intermediate_outs)
                        # calculate whole encoder loss
                        loss_ctc = (
                                    1 - self.interctc_weight
                                ) * loss_ctc + self.interctc_weight * loss_interctc
                    # 2b. Attention decoder branch
                    if self.ctc_weight != 1.0:
                        loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                            encoder_out, encoder_out_lens, text, text_lengths
                        )
                    # 3. CTC-Att loss definition
                    if self.ctc_weight == 0.0:
                        loss = loss_att + loss_pre * self.predictor_weight
                    elif self.ctc_weight == 1.0:
                        loss = loss_ctc
                    else:
                        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
                    loss = loss_att + loss_pre * self.predictor_weight
                    # Collect Attn branch stats
                    stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -274,62 +231,13 @@
                    stats["wer"] = wer_att
                    stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
            else:
                if self.ctc_weight != 0.0:
                    if self.encoder.overlap_chunk_cls is not None:
                        encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
                                                                                                            encoder_out_lens,
                                                                                                            chunk_outs=None)
                    loss_ctc, cer_ctc = self._calc_ctc_loss(
                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                    )
                loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                    encoder_out, encoder_out_lens, text, text_lengths
                )
                    # Collect CTC branch stats
                    stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
                    stats["cer_ctc"] = cer_ctc
                    # Intermediate CTC (optional)
                loss_interctc = 0.0
                if self.interctc_weight != 0.0 and intermediate_outs is not None:
                    for layer_idx, intermediate_out in intermediate_outs:
                        # we assume intermediate_out has the same length & padding
                        # as those of encoder_out
                        if self.encoder.overlap_chunk_cls is not None:
                            encoder_out_ctc, encoder_out_lens_ctc = \
                                self.encoder.overlap_chunk_cls.remove_chunk(
                                    intermediate_out,
                                    encoder_out_lens,
                                    chunk_outs=None)
                        loss_ic, cer_ic = self._calc_ctc_loss(
                            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                        )
                        loss_interctc = loss_interctc + loss_ic
                        # Collect Intermediate CTC stats
                        stats["loss_interctc_layer{}".format(layer_idx)] = (
                            loss_ic.detach() if loss_ic is not None else None
                        )
                        stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
                    loss_interctc = loss_interctc / len(intermediate_outs)
                    # calculate whole encoder loss
                    loss_ctc = (
                                1 - self.interctc_weight
                            ) * loss_ctc + self.interctc_weight * loss_interctc
                # 2b. Attention decoder branch
                if self.ctc_weight != 1.0:
                    loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss(
                        encoder_out, encoder_out_lens, text, text_lengths
                    )
                # 3. CTC-Att loss definition
                if self.ctc_weight == 0.0:
                    loss = loss_att + loss_pre * self.predictor_weight
                elif self.ctc_weight == 1.0:
                    loss = loss_ctc
                else:
                    loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
                loss = loss_att + loss_pre * self.predictor_weight
                # Collect Attn branch stats
                stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -354,67 +262,14 @@
            if isinstance(encoder_out, tuple):
                intermediate_outs = encoder_out[1]
                encoder_out = encoder_out[0]
            # CTC2
            if self.ctc_weight2 != 0.0:
                if self.encoder2.overlap_chunk_cls is not None:
                    encoder_out_ctc, encoder_out_lens_ctc = \
                        self.encoder2.overlap_chunk_cls.remove_chunk(
                            encoder_out,
                            encoder_out_lens,
                            chunk_outs=None,
                        )
                loss_ctc, cer_ctc = self._calc_ctc_loss2(
                    encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                )
                # Collect CTC branch stats
                stats["loss_ctc2"] = loss_ctc.detach() if loss_ctc is not None else None
                stats["cer_ctc2"] = cer_ctc
            # Intermediate CTC (optional)
            loss_interctc = 0.0
            if self.interctc_weight2 != 0.0 and intermediate_outs is not None:
                for layer_idx, intermediate_out in intermediate_outs:
                    # we assume intermediate_out has the same length & padding
                    # as those of encoder_out
                    if self.encoder2.overlap_chunk_cls is not None:
                        encoder_out_ctc, encoder_out_lens_ctc = \
                            self.encoder2.overlap_chunk_cls.remove_chunk(
                                intermediate_out,
                                encoder_out_lens,
                                chunk_outs=None)
                    loss_ic, cer_ic = self._calc_ctc_loss2(
                        encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
                    )
                    loss_interctc = loss_interctc + loss_ic
            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
                encoder_out, encoder_out_lens, text, text_lengths
            )
                    # Collect Intermediate CTC stats
                    stats["loss_interctc_layer{}2".format(layer_idx)] = (
                        loss_ic.detach() if loss_ic is not None else None
                    )
                    stats["cer_interctc_layer{}2".format(layer_idx)] = cer_ic
                loss_interctc = loss_interctc / len(intermediate_outs)
                # calculate whole encoder loss
                loss_ctc = (
                               1 - self.interctc_weight2
                           ) * loss_ctc + self.interctc_weight2 * loss_interctc
            # 2b. Attention decoder branch
            if self.ctc_weight2 != 1.0:
                loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_predictor_loss2(
                    encoder_out, encoder_out_lens, text, text_lengths
                )
            # 3. CTC-Att loss definition
            if self.ctc_weight2 == 0.0:
                loss = loss_att + loss_pre * self.predictor_weight2
            elif self.ctc_weight2 == 1.0:
                loss = loss_ctc
            else:
                loss = self.ctc_weight2 * loss_ctc + (
                    1 - self.ctc_weight2) * loss_att + loss_pre * self.predictor_weight2
            loss = loss_att + loss_pre * self.predictor2_weight
            # Collect Attn branch stats
            stats["loss_att2"] = loss_att.detach() if loss_att is not None else None
@@ -422,6 +277,7 @@
            stats["cer2"] = cer_att
            stats["wer2"] = wer_att
            stats["loss_pre2"] = loss_pre.detach().cpu() if loss_pre is not None else None
        loss2 = loss
        loss = loss1 * self.loss_weight_model1 + loss2 * (1 - self.loss_weight_model1)
@@ -456,60 +312,30 @@
        return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                        speech: (Batch, Length, ...)
                        speech_lengths: (Batch, )
        """
        ind = kwargs.get("ind", 0)
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
            # 2. Data augmentation
            # Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)
            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        speech_raw = feats.clone().to(feats.device)
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        speech_raw = speech.clone().to(speech.device)
        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                feats, feats_lengths, ctc=self.ctc, ind=ind
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, ind=ind)
        intermediate_outs = None
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths, ind=ind)
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return speech_raw, encoder_out, encoder_out_lens
@@ -519,28 +345,15 @@
        encoder_out_lens: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        **kwargs,
    ):
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                        speech: (Batch, Length, ...)
                        speech_lengths: (Batch, )
        """
        # with autocast(False):
        #     # 1. Extract feats
        #     feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        #
        #     # 2. Data augmentation
        #     if self.specaug is not None and self.training:
        #         feats, feats_lengths = self.specaug(feats, feats_lengths)
        #
        #     # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
        #     if self.normalize is not None:
        #         feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        # if self.preencoder is not None:
        #     feats, feats_lengths = self.preencoder(feats, feats_lengths)
        ind = kwargs.get("ind", 0)
        encoder_out_rm, encoder_out_lens_rm = self.encoder.overlap_chunk_cls.remove_chunk(
            encoder_out,
            encoder_out_lens,
@@ -557,55 +370,14 @@
        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder2.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder2(
                speech, speech_lengths, ctc=self.ctc2, ind=ind
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
        intermediate_outs = None
        encoder_out, encoder_out_lens, _ = self.encoder2(speech, speech_lengths, ind=ind)
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        # # Post-encoder, e.g. NLU
        # if self.postencoder is not None:
        #     encoder_out, encoder_out_lens = self.postencoder(
        #         encoder_out, encoder_out_lens
        #     )
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape
        # for data-parallel
        speech = speech[:, : speech_lengths.max()]
        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths
    def nll(
        self,
@@ -1024,36 +796,152 @@
        return pre_acoustic_embeds, pre_token_length, predictor_alignments, predictor_alignments_len, scama_mask
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
    def init_beam_search(self,
                         **kwargs,
                         ):
        from funasr.models.uniasr.beam_search import BeamSearchScama
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus
        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
        decoding_mode = kwargs.get("decoding_mode", "model1")
        if decoding_mode == "model1":
            decoder = self.decoder
        else:
            decoder = self.decoder2
        # 1. Build ASR model
        scorers = {}
        if self.ctc is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(
                ctc=ctc
            )
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=decoder,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearchScama(
            beam_size=kwargs.get("beam_size", 5),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )
        self.beam_search = beam_search
    def _calc_ctc_loss2(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  **kwargs,
                  ):
        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc2.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
        decoding_model = kwargs.get("decoding_model", "normal")
        token_num_relax = kwargs.get("token_num_relax", 5)
        if decoding_model == "fast":
            decoding_ind = 0
            decoding_mode = "model1"
        elif decoding_model == "offline":
            decoding_ind = 1
            decoding_mode = "model2"
        else:
            decoding_ind = 0
            decoding_mode = "model2"
        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(decoding_mode=decoding_mode, **kwargs)
            self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                # keep lengths as a tensor so the .to(device=...) call below succeeds
                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64)
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                            data_type=kwargs.get("data_type", "sound"),
                                                            tokenizer=tokenizer)
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        speech_raw = speech.clone().to(device=kwargs["device"])
        # Encoder
        _, encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=decoding_ind)
        if decoding_mode == "model1":
            predictor_outs = self.calc_predictor_mask(encoder_out, encoder_out_lens)
        else:
            encoder_out, encoder_out_lens = self.encode2(encoder_out, encoder_out_lens, speech_raw, speech_lengths, ind=decoding_ind)
            predictor_outs = self.calc_predictor_mask2(encoder_out, encoder_out_lens)
        scama_mask = predictor_outs[4]
        pre_token_length = predictor_outs[1]
        pre_acoustic_embeds = predictor_outs[0]
        maxlen = pre_token_length.sum().item() + token_num_relax
        minlen = max(0, pre_token_length.sum().item() - token_num_relax)
        # c. Pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0], scama_mask=scama_mask, pre_acoustic_embeds=pre_acoustic_embeds, maxlenratio=0.0,
            minlenratio=0.0, maxlen=int(maxlen), minlen=int(minlen),
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        for hyp in nbest_hyps:
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != 0, token_int))
            # Change integer-ids to tokens
            token = tokenizer.ids2tokens(token_int)
            text_postprocessed = tokenizer.tokens2text(token)
            if not hasattr(tokenizer, "bpemodel"):
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
            result_i = {"key": key[0], "text": text_postprocessed}
            results.append(result_i)
        return results, meta_data
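
A hedged usage sketch tying inference() back to the demo above: decoding_model selects the pass ("fast" decodes with the streaming model1 branch, "offline" with the two-pass model2 branch, anything else runs model2 on the streaming encoder index), assuming AutoModel.generate forwards extra kwargs to UniASR.inference():

# reuses the `model` built in demo.py above (assumed)
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    decoding_model="offline",
)
print(res)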