From fc606ceef3aa5a1dbca795a43147c0aa9ddf0b34 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期二, 14 三月 2023 20:42:08 +0800
Subject: [PATCH] rnnt

---
 funasr/models_transducer/activation.py                          |  213 +
 funasr/models_transducer/decoder/rnn_decoder.py                 |  259 +
 funasr/bin/asr_train_transducer.py                              |   46 
 funasr/models_transducer/encoder/blocks/__init__.py             |    0 
 funasr/bin/asr_inference_rnnt.py                                | 1297 ++++----
 funasr/models_transducer/encoder/modules/__init__.py            |    0 
 funasr/models_transducer/encoder/modules/attention.py           |  246 +
 funasr/models_transducer/encoder/sanm_encoder.py                |  835 +++++
 funasr/models_transducer/encoder/modules/multi_blocks.py        |  105 
 funasr/models_transducer/espnet_transducer_model_unified.py     |  588 ++++
 funasr/models_transducer/beam_search_transducer.py              |  705 ++++
 funasr/models_transducer/decoder/stateless_decoder.py           |  157 +
 funasr/models_transducer/espnet_transducer_model.py             |  484 +++
 funasr/models_transducer/__init__.py                            |    0 
 funasr/models_transducer/encoder/blocks/conformer.py            |  198 +
 funasr/models_transducer/encoder/blocks/linear_input.py         |   52 
 funasr/models_transducer/encoder/modules/normalization.py       |  170 +
 funasr/models_transducer/utils.py                               |  200 +
 funasr/models_transducer/encoder/__init__.py                    |    0 
 funasr/models_transducer/espnet_transducer_model_uni_asr.py     |  485 +++
 funasr/models_transducer/encoder/blocks/branchformer.py         |  178 +
 funasr/models_transducer/encoder/building.py                    |  352 ++
 funasr/models_transducer/decoder/abs_decoder.py                 |  110 
 funasr/models_transducer/encoder/blocks/conv1d.py               |  221 +
 funasr/models_transducer/encoder/blocks/conv_input.py           |  226 +
 funasr/models_transducer/encoder/modules/convolution.py         |  196 +
 funasr/models_transducer/joint_network.py                       |   62 
 funasr/models_transducer/error_calculator.py                    |  170 +
 funasr/models_transducer/encoder/encoder.py                     |  294 ++
 funasr/models_transducer/encoder/modules/positional_encoding.py |   91 
 funasr/models_transducer/decoder/__init__.py                    |    0 
 funasr/models_transducer/encoder/validation.py                  |  171 +
 funasr/tasks/asr_transducer.py                                  |  487 +++
 funasr/bin/asr_inference_launch.py                              |   43 
 34 files changed, 7,945 insertions(+), 696 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 1fae766..b9be3e2 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -131,6 +131,11 @@
         help="Pretrained model tag. If specify this option, *_train_config and "
              "*_file will be overwritten",
     )
+    group.add_argument(
+        "--beam_search_config",
+        default={},
+        help="The keyword arguments for transducer beam search.",
+    )
 
     group = parser.add_argument_group("Beam-search related")
     group.add_argument(
@@ -168,6 +173,41 @@
     group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
     group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
     group.add_argument("--streaming", type=str2bool, default=False)
+    group.add_argument("--simu_streaming", type=str2bool, default=False)
+    group.add_argument("--chunk_size", type=int, default=16)
+    group.add_argument("--left_context", type=int, default=16)
+    group.add_argument("--right_context", type=int, default=0)
+    group.add_argument(
+        "--display_partial_hypotheses",
+        type=bool,
+        default=False,
+        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
+    )    
+   
+    group = parser.add_argument_group("Dynamic quantization related")
+    group.add_argument(
+        "--quantize_asr_model",
+        type=bool,
+        default=False,
+        help="Apply dynamic quantization to ASR model.",
+    )
+    group.add_argument(
+        "--quantize_modules",
+        nargs="*",
+        default=None,
+        help="""Module names to apply dynamic quantization on.
+        The module names are provided as a list, where each name is separated
+        by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
+        Each specified name should be an attribute of 'torch.nn', e.g.:
+        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
+    )
+    group.add_argument(
+        "--quantize_dtype",
+        type=str,
+        default="qint8",
+        choices=["float16", "qint8"],
+        help="Dtype for dynamic quantization.",
+    )    
 
     group = parser.add_argument_group("Text converter related")
     group.add_argument(
@@ -262,6 +302,9 @@
     elif mode == "mfcca":
         from funasr.bin.asr_inference_mfcca import inference_modelscope
         return inference_modelscope(**kwargs)
+    elif mode == "rnnt":
+        from funasr.bin.asr_inference_rnnt import inference
+        return inference(**kwargs)
     else:
         logging.info("Unknown decoding mode: {}".format(mode))
         return None
diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index 6cd7061..f651f11 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -1,151 +1,145 @@
 #!/usr/bin/env python3
+
+""" Inference class definition for Transducer models."""
+
+from __future__ import annotations
+
 import argparse
 import logging
+import math
 import sys
-import time
-import copy
-import os
-import codecs
-import tempfile
-import requests
 from pathlib import Path
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
-from typeguard import check_argument_types
+from packaging.version import parse as V
+from typeguard import check_argument_types, check_return_type
 
+from funasr.models_transducer.beam_search_transducer import (
+    BeamSearchTransducer,
+    Hypothesis,
+)
+from funasr.models_transducer.utils import TooShortUttError
 from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+from funasr.tasks.asr_transducer import ASRTransducerTask
 from funasr.tasks.lm import LMTask
 from funasr.text.build_tokenizer import build_tokenizer
 from funasr.text.token_id_converter import TokenIDConverter
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
+from funasr.utils.types import str2bool, str2triple_str, str_or_none
 from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
 
 
 class Speech2Text:
-    """Speech2Text class
-
-    Examples:
-            >>> import soundfile
-            >>> speech2text = Speech2Text("asr_config.yml", "asr.pth")
-            >>> audio, rate = soundfile.read("speech.wav")
-            >>> speech2text(audio)
-            [(text, token, token_int, hypothesis object), ...]
-
+    """Speech2Text class for Transducer models.
+    Args:
+        asr_train_config: ASR model training config path.
+        asr_model_file: ASR model path.
+        beam_search_config: Beam search config path.
+        lm_train_config: Language Model training config path.
+        lm_file: Language Model config path.
+        token_type: Type of token units.
+        bpemodel: BPE model path.
+        device: Device to use for inference.
+        beam_size: Size of beam during search.
+        dtype: Data type.
+        lm_weight: Language model weight.
+        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+        quantize_modules: List of module names to apply dynamic quantization on.
+        quantize_dtype: Dynamic quantization data type.
+        nbest: Number of final hypothesis.
+        streaming: Whether to perform chunk-by-chunk inference.
+        chunk_size: Number of frames in chunk AFTER subsampling.
+        left_context: Number of frames in left context AFTER subsampling.
+        right_context: Number of frames in right context AFTER subsampling.
+        display_partial_hypotheses: Whether to display partial hypotheses.
     """
 
     def __init__(
-            self,
-            asr_train_config: Union[Path, str] = None,
-            asr_model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            lm_train_config: Union[Path, str] = None,
-            lm_file: Union[Path, str] = None,
-            token_type: str = None,
-            bpemodel: str = None,
-            device: str = "cpu",
-            maxlenratio: float = 0.0,
-            minlenratio: float = 0.0,
-            dtype: str = "float32",
-            beam_size: int = 20,
-            ctc_weight: float = 0.5,
-            lm_weight: float = 1.0,
-            ngram_weight: float = 0.9,
-            penalty: float = 0.0,
-            nbest: int = 1,
-            frontend_conf: dict = None,
-            hotword_list_or_file: str = None,
-            **kwargs,
-    ):
+        self,
+        asr_train_config: Union[Path, str] = None,
+        asr_model_file: Union[Path, str] = None,
+        beam_search_config: Dict[str, Any] = None,
+        lm_train_config: Union[Path, str] = None,
+        lm_file: Union[Path, str] = None,
+        token_type: str = None,
+        bpemodel: str = None,
+        device: str = "cpu",
+        beam_size: int = 5,
+        dtype: str = "float32",
+        lm_weight: float = 1.0,
+        quantize_asr_model: bool = False,
+        quantize_modules: List[str] = None,
+        quantize_dtype: str = "qint8",
+        nbest: int = 1,
+        streaming: bool = False,
+        simu_streaming: bool = False,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+        display_partial_hypotheses: bool = False,
+    ) -> None:
+        """Construct a Speech2Text object."""
+        super().__init__()
+
         assert check_argument_types()
 
-        # 1. Build ASR model
-        scorers = {}
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
-        )
-        frontend = None
-        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
-        logging.info("asr_model: {}".format(asr_model))
-        logging.info("asr_train_args: {}".format(asr_train_args))
-        asr_model.to(dtype=getattr(torch, dtype)).eval()
-
-        if asr_model.ctc != None:
-            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
-            scorers.update(
-                ctc=ctc
-            )
-        token_list = asr_model.token_list
-        scorers.update(
-            length_bonus=LengthBonus(len(token_list)),
+        asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
+            asr_train_config, asr_model_file, device
         )
 
-        # 2. Build Language model
+        if quantize_asr_model:
+            if quantize_modules is not None:
+                if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
+                    raise ValueError(
+                        "Only 'Linear' and 'LSTM' modules are currently supported"
+                        " by PyTorch and in --quantize_modules"
+                    )
+
+                q_config = set([getattr(torch.nn, q) for q in quantize_modules])
+            else:
+                q_config = {torch.nn.Linear}
+
+            if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
+                raise ValueError(
+                    "float16 dtype for dynamic quantization is not supported with torch"
+                    " version < 1.5.0. Switching to qint8 dtype instead."
+                )
+            q_dtype = getattr(torch, quantize_dtype)
+
+            asr_model = torch.quantization.quantize_dynamic(
+                asr_model, q_config, dtype=q_dtype
+            ).eval()
+        else:
+            asr_model.to(dtype=getattr(torch, dtype)).eval()
+
         if lm_train_config is not None:
             lm, lm_train_args = LMTask.build_model_from_file(
                 lm_train_config, lm_file, device
             )
-            scorers["lm"] = lm.lm
-
-        # 3. Build ngram model
-        # ngram is not supported now
-        ngram = None
-        scorers["ngram"] = ngram
+            lm_scorer = lm.lm
+        else:
+            lm_scorer = None
 
         # 4. Build BeamSearch object
-        # transducer is not supported now
-        beam_search_transducer = None
+        if beam_search_config is None:
+            beam_search_config = {}
 
-        weights = dict(
-            decoder=1.0 - ctc_weight,
-            ctc=ctc_weight,
-            lm=lm_weight,
-            ngram=ngram_weight,
-            length_bonus=penalty,
-        )
-        beam_search = BeamSearch(
-            beam_size=beam_size,
-            weights=weights,
-            scorers=scorers,
-            sos=asr_model.sos,
-            eos=asr_model.eos,
-            vocab_size=len(token_list),
-            token_list=token_list,
-            pre_beam_score_key=None if ctc_weight == 1.0 else "full",
+        beam_search = BeamSearchTransducer(
+            asr_model.decoder,
+            asr_model.joint_network,
+            beam_size,
+            lm=lm_scorer,
+            lm_weight=lm_weight,
+            nbest=nbest,
+            **beam_search_config,
         )
 
-        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-        for scorer in scorers.values():
-            if isinstance(scorer, torch.nn.Module):
-                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+        token_list = asr_model.token_list
 
-        logging.info(f"Decoding device={device}, dtype={dtype}")
-
-        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
             token_type = asr_train_args.token_type
         if bpemodel is None:
@@ -165,439 +159,397 @@
 
         self.asr_model = asr_model
         self.asr_train_args = asr_train_args
+        self.device = device
+        self.dtype = dtype
+        self.nbest = nbest
+
         self.converter = converter
         self.tokenizer = tokenizer
 
-        # 6. [Optional] Build hotword list from str, local file or url
-        self.hotword_list = None
-        self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
-
-        is_use_lm = lm_weight != 0.0 and lm_file is not None
-        if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
-            beam_search = None
         self.beam_search = beam_search
-        logging.info(f"Beam_search: {self.beam_search}")
-        self.beam_search_transducer = beam_search_transducer
-        self.maxlenratio = maxlenratio
-        self.minlenratio = minlenratio
-        self.device = device
-        self.dtype = dtype
-        self.nbest = nbest
-        self.frontend = frontend
-        self.encoder_downsampling_factor = 1
-        if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
-            self.encoder_downsampling_factor = 4
+        self.streaming = streaming
+        self.simu_streaming = simu_streaming
+        self.chunk_size = max(chunk_size, 0)
+        self.left_context = max(left_context, 0)
+        self.right_context = max(right_context, 0)
 
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ):
-        """Inference
+        if not streaming or chunk_size == 0:
+            self.streaming = False
+            self.asr_model.encoder.dynamic_chunk_training = False
+        
+        if not simu_streaming or chunk_size == 0:
+            self.simu_streaming = False
+            self.asr_model.encoder.dynamic_chunk_training = False
 
-        Args:
-                speech: Input speech data
-        Returns:
-                text, token, token_int, hyp
+        self.n_fft = asr_train_args.frontend_conf.get("n_fft", 512)
+        self.hop_length = asr_train_args.frontend_conf.get("hop_length", 128)
 
-        """
-        assert check_argument_types()
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-            self.asr_model.frontend = None
+        if asr_train_args.frontend_conf.get("win_length", None) is not None:
+            self.frontend_window_size = asr_train_args.frontend_conf["win_length"]
         else:
-            feats = speech
-            feats_len = speech_lengths
-        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
-        batch = {"speech": feats, "speech_lengths": feats_len}
+            self.frontend_window_size = self.n_fft
 
-        # a. To device
-        batch = to_device(batch, device=self.device)
-
-        # b. Forward Encoder
-        enc, enc_len = self.asr_model.encode(**batch)
-        if isinstance(enc, tuple):
-            enc = enc[0]
-        # assert len(enc) == 1, len(enc)
-        enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
-
-        predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
-                                                                        predictor_outs[2], predictor_outs[3]
-        pre_token_length = pre_token_length.round().long()
-        if torch.max(pre_token_length) < 1:
-            return []
-        if not isinstance(self.asr_model, ContextualParaformer):
-            if self.hotword_list:
-                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
-            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
-            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-        else:
-            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
-            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
-        results = []
-        b, n, d = decoder_out.size()
-        for i in range(b):
-            x = enc[i, :enc_len[i], :]
-            am_scores = decoder_out[i, :pre_token_length[i], :]
-            if self.beam_search is not None:
-                nbest_hyps = self.beam_search(
-                    x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
-                )
-
-                nbest_hyps = nbest_hyps[: self.nbest]
-            else:
-                yseq = am_scores.argmax(dim=-1)
-                score = am_scores.max(dim=-1)[0]
-                score = torch.sum(score, dim=-1)
-                # pad with mask tokens to ensure compatibility with sos/eos tokens
-                yseq = torch.tensor(
-                    [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
-                )
-                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-
-            for hyp in nbest_hyps:
-                assert isinstance(hyp, (Hypothesis)), type(hyp)
-
-                # remove sos/eos and get results
-                last_pos = -1
-                if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
-                else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
-
-                # remove blank symbol id, which is assumed to be 0
-                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
-                # Change integer-ids to tokens
-                token = self.converter.ids2tokens(token_int)
-
-                if self.tokenizer is not None:
-                    text = self.tokenizer.tokens2text(token)
-                else:
-                    text = None
-
-                results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
-
-        # assert check_return_type(results)
-        return results
-
-    def generate_hotwords_list(self, hotword_list_or_file):
-        # for None
-        if hotword_list_or_file is None:
-            hotword_list = None
-        # for local txt inputs
-        elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
-            logging.info("Attempting to parse hotwords from local txt...")
-            hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hotword_str_list.append(hw)
-                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-                hotword_list.append([self.asr_model.sos])
-                hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                         .format(hotword_list_or_file, hotword_str_list))
-        # for url, download and generate txt
-        elif hotword_list_or_file.startswith('http'):
-            logging.info("Attempting to parse hotwords from url...")
-            work_dir = tempfile.TemporaryDirectory().name
-            if not os.path.exists(work_dir):
-                os.makedirs(work_dir)
-            text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
-            local_file = requests.get(hotword_list_or_file)
-            open(text_file_path, "wb").write(local_file.content)
-            hotword_list_or_file = text_file_path
-            hotword_list = []
-            hotword_str_list = []
-            with codecs.open(hotword_list_or_file, 'r') as fin:
-                for line in fin.readlines():
-                    hw = line.strip()
-                    hotword_str_list.append(hw)
-                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-                hotword_list.append([self.asr_model.sos])
-                hotword_str_list.append('<s>')
-            logging.info("Initialized hotword list from file: {}, hotword list: {}."
-                         .format(hotword_list_or_file, hotword_str_list))
-        # for text str input
-        elif not hotword_list_or_file.endswith('.txt'):
-            logging.info("Attempting to parse hotwords as str...")
-            hotword_list = []
-            hotword_str_list = []
-            for hw in hotword_list_or_file.strip().split():
-                hotword_str_list.append(hw)
-                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
-            hotword_list.append([self.asr_model.sos])
-            hotword_str_list.append('<s>')
-            logging.info("Hotword list: {}.".format(hotword_str_list))
-        else:
-            hotword_list = None
-        return hotword_list
-
-class Speech2TextExport:
-    """Speech2TextExport class
-
-    """
-
-    def __init__(
-            self,
-            asr_train_config: Union[Path, str] = None,
-            asr_model_file: Union[Path, str] = None,
-            cmvn_file: Union[Path, str] = None,
-            lm_train_config: Union[Path, str] = None,
-            lm_file: Union[Path, str] = None,
-            token_type: str = None,
-            bpemodel: str = None,
-            device: str = "cpu",
-            maxlenratio: float = 0.0,
-            minlenratio: float = 0.0,
-            dtype: str = "float32",
-            beam_size: int = 20,
-            ctc_weight: float = 0.5,
-            lm_weight: float = 1.0,
-            ngram_weight: float = 0.9,
-            penalty: float = 0.0,
-            nbest: int = 1,
-            frontend_conf: dict = None,
-            hotword_list_or_file: str = None,
-            **kwargs,
-    ):
-
-        # 1. Build ASR model
-        asr_model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, device
+        self.window_size = self.chunk_size + self.right_context
+        self._raw_ctx = self.asr_model.encoder.get_encoder_input_raw_size(
+            self.window_size, self.hop_length
         )
-        frontend = None
-        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+        self._ctx = self.asr_model.encoder.get_encoder_input_size(
+            self.window_size
+        )
+       
 
-        logging.info("asr_model: {}".format(asr_model))
-        logging.info("asr_train_args: {}".format(asr_train_args))
-        asr_model.to(dtype=getattr(torch, dtype)).eval()
+        #self.last_chunk_length = (
+        #    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+        #) * self.hop_length
 
-        token_list = asr_model.token_list
+        self.last_chunk_length = (
+            self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+        )
+        self.reset_inference_cache()
 
+    def reset_inference_cache(self) -> None:
+        """Reset Speech2Text parameters."""
+        self.frontend_cache = None
 
+        self.asr_model.encoder.reset_streaming_cache(
+            self.left_context, device=self.device
+        )
+        self.beam_search.reset_inference_cache()
 
-        logging.info(f"Decoding device={device}, dtype={dtype}")
+        self.num_processed_frames = torch.tensor([[0]], device=self.device)
 
-        # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
-        if token_type is None:
-            token_type = asr_train_args.token_type
-        if bpemodel is None:
-            bpemodel = asr_train_args.bpemodel
-
-        if token_type is None:
-            tokenizer = None
-        elif token_type == "bpe":
-            if bpemodel is not None:
-                tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
-            else:
-                tokenizer = None
-        else:
-            tokenizer = build_tokenizer(token_type=token_type)
-        converter = TokenIDConverter(token_list=token_list)
-        logging.info(f"Text tokenizer: {tokenizer}")
-
-        # self.asr_model = asr_model
-        self.asr_train_args = asr_train_args
-        self.converter = converter
-        self.tokenizer = tokenizer
-
-        self.device = device
-        self.dtype = dtype
-        self.nbest = nbest
-        self.frontend = frontend
-
-        model = Paraformer_export(asr_model, onnx=False)
-        self.asr_model = model
-        
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ):
-        """Inference
-
+    def apply_frontend(
+        self, speech: torch.Tensor, is_final: bool = False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward frontend.
         Args:
-                speech: Input speech data
+            speech: Speech data. (S)
+            is_final: Whether speech corresponds to the final (or only) chunk of data.
         Returns:
-                text, token, token_int, hyp
+            feats: Features sequence. (1, T_in, F)
+            feats_lengths: Features sequence length. (1, T_in, F)
+        """
+        if self.frontend_cache is not None:
+            speech = torch.cat([self.frontend_cache["waveform_buffer"], speech], dim=0)
 
+        if is_final:
+            if self.streaming and speech.size(0) < self.last_chunk_length:
+                pad = torch.zeros(
+                    self.last_chunk_length - speech.size(0), dtype=speech.dtype
+                )
+                speech = torch.cat([speech, pad], dim=0)
+
+            speech_to_process = speech
+            waveform_buffer = None
+        else:
+            n_frames = (
+                speech.size(0) - (self.frontend_window_size - self.hop_length)
+            ) // self.hop_length
+
+            n_residual = (
+                speech.size(0) - (self.frontend_window_size - self.hop_length)
+            ) % self.hop_length
+
+            speech_to_process = speech.narrow(
+                0,
+                0,
+                (self.frontend_window_size - self.hop_length)
+                + n_frames * self.hop_length,
+            )
+
+            waveform_buffer = speech.narrow(
+                0,
+                speech.size(0)
+                - (self.frontend_window_size - self.hop_length)
+                - n_residual,
+                (self.frontend_window_size - self.hop_length) + n_residual,
+            ).clone()
+
+        speech_to_process = speech_to_process.unsqueeze(0).to(
+            getattr(torch, self.dtype)
+        )
+        lengths = speech_to_process.new_full(
+            [1], dtype=torch.long, fill_value=speech_to_process.size(1)
+        )
+        batch = {"speech": speech_to_process, "speech_lengths": lengths}
+        batch = to_device(batch, device=self.device)
+
+        feats, feats_lengths = self.asr_model._extract_feats(**batch)
+        if self.asr_model.normalize is not None:
+            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+        if is_final:
+            if self.frontend_cache is None:
+                pass
+            else:
+                feats = feats.narrow(
+                    1,
+                    math.ceil(
+                        math.ceil(self.frontend_window_size / self.hop_length) / 2
+                    ),
+                    feats.size(1)
+                    - math.ceil(
+                        math.ceil(self.frontend_window_size / self.hop_length) / 2
+                    ),
+                )
+        else:
+            if self.frontend_cache is None:
+                feats = feats.narrow(
+                    1,
+                    0,
+                    feats.size(1)
+                    - math.ceil(
+                        math.ceil(self.frontend_window_size / self.hop_length) / 2
+                    ),
+                )
+            else:
+                feats = feats.narrow(
+                    1,
+                    math.ceil(
+                        math.ceil(self.frontend_window_size / self.hop_length) / 2
+                    ),
+                    feats.size(1)
+                    - 2
+                    * math.ceil(
+                        math.ceil(self.frontend_window_size / self.hop_length) / 2
+                    ),
+                )
+
+        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+        if is_final:
+            self.frontend_cache = None
+        else:
+            self.frontend_cache = {"waveform_buffer": waveform_buffer}
+
+        return feats, feats_lengths
+
+    @torch.no_grad()
+    def streaming_decode(
+        self,
+        speech: Union[torch.Tensor, np.ndarray],
+        is_final: bool = True,
+    ) -> List[Hypothesis]:
+        """Speech2Text streaming call.
+        Args:
+            speech: Chunk of speech data. (S)
+            is_final: Whether speech corresponds to the final chunk of data.
+        Returns:
+            nbest_hypothesis: N-best hypothesis.
+        """
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        if is_final:
+            if self.streaming and speech.size(0) < self.last_chunk_length:
+                pad = torch.zeros(
+                    self.last_chunk_length - speech.size(0), speech.size(1),  dtype=speech.dtype
+                )
+                speech = torch.cat([speech, pad], dim=0) #feats, feats_length = self.apply_frontend(speech, is_final=is_final)
+
+        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+        if self.asr_model.normalize is not None:
+            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+        feats = to_device(feats, device=self.device)
+        feats_lengths = to_device(feats_lengths, device=self.device)
+        enc_out = self.asr_model.encoder.chunk_forward(
+            feats,
+            feats_lengths,
+            self.num_processed_frames,
+            chunk_size=self.chunk_size,
+            left_context=self.left_context,
+            right_context=self.right_context,
+        )
+        nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
+
+        self.num_processed_frames += self.chunk_size
+
+        if is_final:
+            self.reset_inference_cache()
+
+        return nbest_hyps
+
+    @torch.no_grad()
+    def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]:
+        """Speech2Text call.
+        Args:
+            speech: Speech data. (S)
+        Returns:
+            nbest_hypothesis: N-best hypothesis.
         """
         assert check_argument_types()
 
-        # Input as audio signal
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-            self.asr_model.frontend = None
-        else:
-            feats = speech
-            feats_len = speech_lengths
-
-        enc_len_batch_total = feats_len.sum()
-        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
-        batch = {"speech": feats, "speech_lengths": feats_len}
-
-        # a. To device
-        batch = to_device(batch, device=self.device)
-
-        decoder_outs = self.asr_model(**batch)
-        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
         
+        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+        if self.asr_model.normalize is not None:
+            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+        feats = to_device(feats, device=self.device)
+        feats_lengths = to_device(feats_lengths, device=self.device)
+        enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context, self.right_context)
+        nbest_hyps = self.beam_search(enc_out[0])
+
+        return nbest_hyps
+
+    @torch.no_grad()
+    def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]:
+        """Speech2Text call.
+        Args:
+            speech: Speech data. (S)
+        Returns:
+            nbest_hypothesis: N-best hypothesis.
+        """
+        assert check_argument_types()
+
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        
+        # lengths: (1,)
+        # feats, feats_length = self.apply_frontend(speech)
+        feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+        # lengths: (1,)
+        feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+        # print(feats.shape)
+        # print(feats_lengths)
+        if self.asr_model.normalize is not None:
+            feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+        feats = to_device(feats, device=self.device)
+        feats_lengths = to_device(feats_lengths, device=self.device)
+
+        enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
+
+        nbest_hyps = self.beam_search(enc_out[0])
+
+        return nbest_hyps
+
+    def hypotheses_to_results(self, nbest_hyps: List[Hypothesis]) -> List[Any]:
+        """Build partial or final results from the hypotheses.
+        Args:
+            nbest_hyps: N-best hypothesis.
+        Returns:
+            results: Results containing different representation for the hypothesis.
+        """
         results = []
-        b, n, d = decoder_out.size()
-        for i in range(b):
-            am_scores = decoder_out[i, :ys_pad_lens[i], :]
 
-            yseq = am_scores.argmax(dim=-1)
-            score = am_scores.max(dim=-1)[0]
-            score = torch.sum(score, dim=-1)
-            # pad with mask tokens to ensure compatibility with sos/eos tokens
-            yseq = torch.tensor(
-                yseq.tolist(), device=yseq.device
-            )
-            nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+        for hyp in nbest_hyps:
+            token_int = list(filter(lambda x: x != 0, hyp.yseq))
 
-            for hyp in nbest_hyps:
-                assert isinstance(hyp, (Hypothesis)), type(hyp)
+            token = self.converter.ids2tokens(token_int)
 
-                # remove sos/eos and get results
-                last_pos = -1
-                if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
-                else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
+            if self.tokenizer is not None:
+                text = self.tokenizer.tokens2text(token)
+            else:
+                text = None
+            results.append((text, token, token_int, hyp))
 
-                # remove blank symbol id, which is assumed to be 0
-                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
-                # Change integer-ids to tokens
-                token = self.converter.ids2tokens(token_int)
-
-                if self.tokenizer is not None:
-                    text = self.tokenizer.tokens2text(token)
-                else:
-                    text = None
-
-                results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
+            assert check_return_type(results)
 
         return results
+
+    @staticmethod
+    def from_pretrained(
+        model_tag: Optional[str] = None,
+        **kwargs: Optional[Any],
+    ) -> Speech2Text:
+        """Build Speech2Text instance from the pretrained model.
+        Args:
+            model_tag: Model tag of the pretrained models.
+        Return:
+            : Speech2Text instance.
+        """
+        if model_tag is not None:
+            try:
+                from espnet_model_zoo.downloader import ModelDownloader
+
+            except ImportError:
+                logging.error(
+                    "`espnet_model_zoo` is not installed. "
+                    "Please install via `pip install -U espnet_model_zoo`."
+                )
+                raise
+            d = ModelDownloader()
+            kwargs.update(**d.download_and_unpack(model_tag))
+
+        return Speech2Text(**kwargs)
 
 
 def inference(
-        maxlenratio: float,
-        minlenratio: float,
-        batch_size: int,
-        beam_size: int,
-        ngpu: int,
-        ctc_weight: float,
-        lm_weight: float,
-        penalty: float,
-        log_level: Union[int, str],
-        data_path_and_name_and_type,
-        asr_train_config: Optional[str],
-        asr_model_file: Optional[str],
-        cmvn_file: Optional[str] = None,
-        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        token_type: Optional[str] = None,
-        key_file: Optional[str] = None,
-        word_lm_train_config: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        streaming: bool = False,
-        output_dir: Optional[str] = None,
-        dtype: str = "float32",
-        seed: int = 0,
-        ngram_weight: float = 0.9,
-        nbest: int = 1,
-        num_workers: int = 1,
-
-        **kwargs,
-):
-    inference_pipeline = inference_modelscope(
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
-        batch_size=batch_size,
-        beam_size=beam_size,
-        ngpu=ngpu,
-        ctc_weight=ctc_weight,
-        lm_weight=lm_weight,
-        penalty=penalty,
-        log_level=log_level,
-        asr_train_config=asr_train_config,
-        asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
-        raw_inputs=raw_inputs,
-        lm_train_config=lm_train_config,
-        lm_file=lm_file,
-        token_type=token_type,
-        key_file=key_file,
-        word_lm_train_config=word_lm_train_config,
-        bpemodel=bpemodel,
-        allow_variable_data_keys=allow_variable_data_keys,
-        streaming=streaming,
-        output_dir=output_dir,
-        dtype=dtype,
-        seed=seed,
-        ngram_weight=ngram_weight,
-        nbest=nbest,
-        num_workers=num_workers,
-
-        **kwargs,
-    )
-    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-
-def inference_modelscope(
-        maxlenratio: float,
-        minlenratio: float,
-        batch_size: int,
-        beam_size: int,
-        ngpu: int,
-        ctc_weight: float,
-        lm_weight: float,
-        penalty: float,
-        log_level: Union[int, str],
-        # data_path_and_name_and_type,
-        asr_train_config: Optional[str],
-        asr_model_file: Optional[str],
-        cmvn_file: Optional[str] = None,
-        lm_train_config: Optional[str] = None,
-        lm_file: Optional[str] = None,
-        token_type: Optional[str] = None,
-        key_file: Optional[str] = None,
-        word_lm_train_config: Optional[str] = None,
-        bpemodel: Optional[str] = None,
-        allow_variable_data_keys: bool = False,
-        dtype: str = "float32",
-        seed: int = 0,
-        ngram_weight: float = 0.9,
-        nbest: int = 1,
-        num_workers: int = 1,
-        output_dir: Optional[str] = None,
-        param_dict: dict = None,
-        **kwargs,
-):
+    output_dir: str,
+    batch_size: int,
+    dtype: str,
+    beam_size: int,
+    ngpu: int,
+    seed: int,
+    lm_weight: float,
+    nbest: int,
+    num_workers: int,
+    log_level: Union[int, str],
+    data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+    asr_train_config: Optional[str],
+    asr_model_file: Optional[str],
+    beam_search_config: Optional[dict],
+    lm_train_config: Optional[str],
+    lm_file: Optional[str],
+    model_tag: Optional[str],
+    token_type: Optional[str],
+    bpemodel: Optional[str],
+    key_file: Optional[str],
+    allow_variable_data_keys: bool,
+    quantize_asr_model: Optional[bool],
+    quantize_modules: Optional[List[str]],
+    quantize_dtype: Optional[str],
+    streaming: Optional[bool],
+    simu_streaming: Optional[bool],
+    chunk_size: Optional[int],
+    left_context: Optional[int],
+    right_context: Optional[int],
+    display_partial_hypotheses: bool,
+    **kwargs,
+) -> None:
+    """Transducer model inference.
+    Args:
+        output_dir: Output directory path.
+        batch_size: Batch decoding size.
+        dtype: Data type.
+        beam_size: Beam size.
+        ngpu: Number of GPUs.
+        seed: Random number generator seed.
+        lm_weight: Weight of language model.
+        nbest: Number of final hypothesis.
+        num_workers: Number of workers.
+        log_level: Level of verbose for logs.
+        data_path_and_name_and_type:
+        asr_train_config: ASR model training config path.
+        asr_model_file: ASR model path.
+        beam_search_config: Beam search config path.
+        lm_train_config: Language Model training config path.
+        lm_file: Language Model path.
+        model_tag: Model tag.
+        token_type: Type of token units.
+        bpemodel: BPE model path.
+        key_file: File key.
+        allow_variable_data_keys: Whether to allow variable data keys.
+        quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+        quantize_modules: List of module names to apply dynamic quantization on.
+        quantize_dtype: Dynamic quantization data type.
+        streaming: Whether to perform chunk-by-chunk inference.
+        chunk_size: Number of frames in chunk AFTER subsampling.
+        left_context: Number of frames in left context AFTER subsampling.
+        right_context: Number of frames in right context AFTER subsampling.
+        display_partial_hypotheses: Whether to display partial hypotheses.
+    """
     assert check_argument_types()
 
-    if word_lm_train_config is not None:
-        raise NotImplementedError("Word LM is not implemented")
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
     if ngpu > 1:
         raise NotImplementedError("only single GPU decoding is supported")
 
@@ -605,19 +557,11 @@
         level=log_level,
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
     )
-    
-    export_mode = False
-    if param_dict is not None:
-        hotword_list_or_file = param_dict.get('hotword')
-        export_mode = param_dict.get("export_mode", False)
-    else:
-        hotword_list_or_file = None
 
-    if ngpu >= 1 and torch.cuda.is_available():
+    if ngpu >= 1:
         device = "cuda"
     else:
         device = "cpu"
-        batch_size = 1
 
     # 1. Set random-seed
     set_all_random_seed(seed)
@@ -626,144 +570,105 @@
     speech2text_kwargs = dict(
         asr_train_config=asr_train_config,
         asr_model_file=asr_model_file,
-        cmvn_file=cmvn_file,
+        beam_search_config=beam_search_config,
         lm_train_config=lm_train_config,
         lm_file=lm_file,
         token_type=token_type,
         bpemodel=bpemodel,
         device=device,
-        maxlenratio=maxlenratio,
-        minlenratio=minlenratio,
         dtype=dtype,
         beam_size=beam_size,
-        ctc_weight=ctc_weight,
         lm_weight=lm_weight,
-        ngram_weight=ngram_weight,
-        penalty=penalty,
         nbest=nbest,
-        hotword_list_or_file=hotword_list_or_file,
+        quantize_asr_model=quantize_asr_model,
+        quantize_modules=quantize_modules,
+        quantize_dtype=quantize_dtype,
+        streaming=streaming,
+        simu_streaming=simu_streaming,
+        chunk_size=chunk_size,
+        left_context=left_context,
+        right_context=right_context,
     )
-    if export_mode:
-        speech2text = Speech2TextExport(**speech2text_kwargs)
-    else:
-        speech2text = Speech2Text(**speech2text_kwargs)
+    speech2text = Speech2Text.from_pretrained(
+        model_tag=model_tag,
+        **speech2text_kwargs,
+    )
 
-    def _forward(
-            data_path_and_name_and_type,
-            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
-            output_dir_v2: Optional[str] = None,
-            fs: dict = None,
-            param_dict: dict = None,
-            **kwargs,
-    ):
+    # 3. Build data-iterator
+    loader = ASRTransducerTask.build_streaming_iterator(
+        data_path_and_name_and_type,
+        dtype=dtype,
+        batch_size=batch_size,
+        key_file=key_file,
+        num_workers=num_workers,
+        preprocess_fn=ASRTransducerTask.build_preprocess_fn(
+            speech2text.asr_train_args, False
+        ),
+        collate_fn=ASRTransducerTask.build_collate_fn(
+            speech2text.asr_train_args, False
+        ),
+        allow_variable_data_keys=allow_variable_data_keys,
+        inference=True,
+    )
 
-        hotword_list_or_file = None
-        if param_dict is not None:
-            hotword_list_or_file = param_dict.get('hotword')
-        if 'hotword' in kwargs:
-            hotword_list_or_file = kwargs['hotword']
-        if hotword_list_or_file is not None or 'hotword' in kwargs:
-            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
-        cache = None
-        if 'cache' in param_dict:
-            cache = param_dict['cache']
-        # 3. Build data-iterator
-        if data_path_and_name_and_type is None and raw_inputs is not None:
-            if isinstance(raw_inputs, torch.Tensor):
-                raw_inputs = raw_inputs.numpy()
-            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = ASRTask.build_streaming_iterator(
-            data_path_and_name_and_type,
-            dtype=dtype,
-            fs=fs,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
-        )
-
-        forward_time_total = 0.0
-        length_total = 0.0
-        finish_count = 0
-        file_count = 1
-        # 7 .Start for-loop
-        # FIXME(kamo): The output format should be discussed about
-        asr_result_list = []
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path is not None:
-            writer = DatadirWriter(output_path)
-        else:
-            writer = None
-
+    # 4 .Start for-loop
+    with DatadirWriter(output_dir) as writer:
         for keys, batch in loader:
             assert isinstance(batch, dict), type(batch)
             assert all(isinstance(s, str) for s in keys), keys
+
             _bs = len(next(iter(batch.values())))
             assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
+            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+            assert len(batch.keys()) == 1
 
-            logging.info("decoding, utt_id: {}".format(keys))
-            # N-best list of (text, token, token_int, hyp_object)
+            try:
+                if speech2text.streaming:
+                    speech = batch["speech"]
 
-            time_beg = time.time()
-            results = speech2text(cache=cache, **batch)
-            if len(results) < 1:
-                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-                results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
-            time_end = time.time()
-            forward_time = time_end - time_beg
-            lfr_factor = results[0][-1]
-            length = results[0][-2]
-            forward_time_total += forward_time
-            length_total += length
-            rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor))
-            logging.info(rtf_cur)
+                    _steps = len(speech) // speech2text._ctx
+                    _end = 0
+                    for i in range(_steps):
+                        _end = (i + 1) * speech2text._ctx
 
-            for batch_id in range(_bs):
-                result = [results[batch_id][:-2]]
+                        speech2text.streaming_decode(
+                            speech[i * speech2text._ctx : _end], is_final=False
+                        )
 
-                key = keys[batch_id]
-                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result):
-                    # Create a directory: outdir/{n}best_recog
-                    if writer is not None:
-                        ibest_writer = writer[f"{n}best_recog"]
+                    final_hyps = speech2text.streaming_decode(
+                        speech[_end : len(speech)], is_final=True
+                    )
+                elif speech2text.simu_streaming:
+                    final_hyps = speech2text.simu_streaming_decode(**batch)
+                else:
+                    final_hyps = speech2text(**batch)
 
-                        # Write the result to each file
-                        ibest_writer["token"][key] = " ".join(token)
-                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                        ibest_writer["score"][key] = str(hyp.score)
-                        ibest_writer["rtf"][key] = rtf_cur
+                results = speech2text.hypotheses_to_results(final_hyps)
+            except TooShortUttError as e:
+                logging.warning(f"Utterance {keys} {e}")
+                hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
+                results = [[" ", ["<space>"], [2], hyp]] * nbest
 
-                    if text is not None:
-                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-                        item = {'key': key, 'value': text_postprocessed}
-                        asr_result_list.append(item)
-                        finish_count += 1
-                        # asr_utils.print_progress(finish_count / file_count)
-                        if writer is not None:
-                            ibest_writer["text"][key] = text_postprocessed
+            key = keys[0]
+            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+                ibest_writer = writer[f"{n}best_recog"]
 
-                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
-        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
-        logging.info(rtf_avg)
-        if writer is not None:
-            ibest_writer["rtf"]["rtf_avf"] = rtf_avg
-        return asr_result_list
+                ibest_writer["token"][key] = " ".join(token)
+                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                ibest_writer["score"][key] = str(hyp.score)
 
-    return _forward
+                if text is not None:
+                    ibest_writer["text"][key] = text
 
 
 def get_parser():
+    """Get Transducer model inference parser."""
+
     parser = config_argparse.ArgumentParser(
-        description="ASR Decoding",
+        description="ASR Transducer Decoding",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    # Note(kamo): Use '_' instead of '-' as separator.
-    # '-' is confusing if written in yaml.
     parser.add_argument(
         "--log_level",
         type=lambda x: x.upper(),
@@ -792,17 +697,12 @@
         default=1,
         help="The number of workers used for DataLoader",
     )
-    parser.add_argument(
-        "--hotword",
-        type=str_or_none,
-        default=None,
-        help="hotword file path or hotwords seperated by space"
-    )
+
     group = parser.add_argument_group("Input data related")
     group.add_argument(
         "--data_path_and_name_and_type",
         type=str2triple_str,
-        required=False,
+        required=True,
         action="append",
     )
     group.add_argument("--key_file", type=str_or_none)
@@ -820,11 +720,6 @@
         help="ASR model parameter file",
     )
     group.add_argument(
-        "--cmvn_file",
-        type=str,
-        help="Global cmvn file",
-    )
-    group.add_argument(
         "--lm_train_config",
         type=str,
         help="LM training configuration",
@@ -835,25 +730,10 @@
         help="LM parameter file",
     )
     group.add_argument(
-        "--word_lm_train_config",
-        type=str,
-        help="Word LM training configuration",
-    )
-    group.add_argument(
-        "--word_lm_file",
-        type=str,
-        help="Word LM parameter file",
-    )
-    group.add_argument(
-        "--ngram_file",
-        type=str,
-        help="N-gram parameter file",
-    )
-    group.add_argument(
         "--model_tag",
         type=str,
         help="Pretrained model tag. If specify this option, *_train_config and "
-             "*_file will be overwritten",
+        "*_file will be overwritten",
     )
 
     group = parser.add_argument_group("Beam-search related")
@@ -864,42 +744,13 @@
         help="The batch size for inference",
     )
     group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
-    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
-    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
-    group.add_argument(
-        "--maxlenratio",
-        type=float,
-        default=0.0,
-        help="Input length ratio to obtain max output length. "
-             "If maxlenratio=0.0 (default), it uses a end-detect "
-             "function "
-             "to automatically find maximum hypothesis lengths."
-             "If maxlenratio<0.0, its absolute value is interpreted"
-             "as a constant max output length",
-    )
-    group.add_argument(
-        "--minlenratio",
-        type=float,
-        default=0.0,
-        help="Input length ratio to obtain min output length",
-    )
-    group.add_argument(
-        "--ctc_weight",
-        type=float,
-        default=0.5,
-        help="CTC weight in joint decoding",
-    )
+    group.add_argument("--beam_size", type=int, default=5, help="Beam size")
     group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
-    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
-    group.add_argument("--streaming", type=str2bool, default=False)
-
     group.add_argument(
-        "--frontend_conf",
-        default=None,
-        help="",
+        "--beam_search_config",
+        default={},
+        help="The keyword arguments for transducer beam search.",
     )
-    group.add_argument("--raw_inputs", type=list, default=None)
-    # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
 
     group = parser.add_argument_group("Text converter related")
     group.add_argument(
@@ -908,14 +759,77 @@
         default=None,
         choices=["char", "bpe", None],
         help="The token type for ASR model. "
-             "If not given, refers from the training args",
+        "If not given, refers from the training args",
     )
     group.add_argument(
         "--bpemodel",
         type=str_or_none,
         default=None,
         help="The model path of sentencepiece. "
-             "If not given, refers from the training args",
+        "If not given, refers from the training args",
+    )
+
+    group = parser.add_argument_group("Dynamic quantization related")
+    parser.add_argument(
+        "--quantize_asr_model",
+        type=bool,
+        default=False,
+        help="Apply dynamic quantization to ASR model.",
+    )
+    parser.add_argument(
+        "--quantize_modules",
+        nargs="*",
+        default=None,
+        help="""Module names to apply dynamic quantization on.
+        The module names are provided as a list, where each name is separated
+        by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
+        Each specified name should be an attribute of 'torch.nn', e.g.:
+        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
+    )
+    parser.add_argument(
+        "--quantize_dtype",
+        type=str,
+        default="qint8",
+        choices=["float16", "qint8"],
+        help="Dtype for dynamic quantization.",
+    )
+
+    group = parser.add_argument_group("Streaming related")
+    parser.add_argument(
+        "--streaming",
+        type=bool,
+        default=False,
+        help="Whether to perform chunk-by-chunk inference.",
+    )
+    parser.add_argument(
+        "--simu_streaming",
+        type=bool,
+        default=False,
+        help="Whether to simulate chunk-by-chunk inference.",
+    )
+    parser.add_argument(
+        "--chunk_size",
+        type=int,
+        default=16,
+        help="Number of frames in chunk AFTER subsampling.",
+    )
+    parser.add_argument(
+        "--left_context",
+        type=int,
+        default=32,
+        help="Number of frames in left context of the chunk AFTER subsampling.",
+    )
+    parser.add_argument(
+        "--right_context",
+        type=int,
+        default=0,
+        help="Number of frames in right context of the chunk AFTER subsampling.",
+    )
+    parser.add_argument(
+        "--display_partial_hypotheses",
+        type=bool,
+        default=False,
+        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
     )
 
     return parser
@@ -923,24 +837,15 @@
 
 def main(cmd=None):
     print(get_commandline_args(), file=sys.stderr)
+
     parser = get_parser()
     args = parser.parse_args(cmd)
-    param_dict = {'hotword': args.hotword}
     kwargs = vars(args)
+
     kwargs.pop("config", None)
-    kwargs['param_dict'] = param_dict
     inference(**kwargs)
 
 
 if __name__ == "__main__":
     main()
 
-    # from modelscope.pipelines import pipeline
-    # from modelscope.utils.constant import Tasks
-    #
-    # inference_16k_pipline = pipeline(
-    #     task=Tasks.auto_speech_recognition,
-    #     model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
-    #
-    # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
-    # print(rec_result)
diff --git a/funasr/bin/asr_train_transducer.py b/funasr/bin/asr_train_transducer.py
new file mode 100755
index 0000000..9b6d287
--- /dev/null
+++ b/funasr/bin/asr_train_transducer.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+
+import os
+
+from funasr.tasks.asr_transducer import ASRTransducerTask
+
+
+# for ASR Training
+def parse_args():
+    parser = ASRTransducerTask.get_parser()
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args=None, cmd=None):
+    # for ASR Training
+    ASRTransducerTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # setup local gpu_id
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+    assert args.num_worker_count == 1
+
+    # re-compute batch size: when dataset type is small
+    if args.dataset_type == "small":
+        if args.batch_size is not None:
+            args.batch_size = args.batch_size * args.ngpu
+        if args.batch_bins is not None:
+            args.batch_bins = args.batch_bins * args.ngpu
+
+    main(args=args)
diff --git a/funasr/models_transducer/__init__.py b/funasr/models_transducer/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models_transducer/__init__.py
diff --git a/funasr/models_transducer/activation.py b/funasr/models_transducer/activation.py
new file mode 100644
index 0000000..82cda12
--- /dev/null
+++ b/funasr/models_transducer/activation.py
@@ -0,0 +1,213 @@
+"""Activation functions for Transducer."""
+
+import torch
+from packaging.version import parse as V
+
+
+def get_activation(
+    activation_type: str,
+    ftswish_threshold: float = -0.2,
+    ftswish_mean_shift: float = 0.0,
+    hardtanh_min_val: int = -1.0,
+    hardtanh_max_val: int = 1.0,
+    leakyrelu_neg_slope: float = 0.01,
+    smish_alpha: float = 1.0,
+    smish_beta: float = 1.0,
+    softplus_beta: float = 1.0,
+    softplus_threshold: int = 20,
+    swish_beta: float = 1.0,
+) -> torch.nn.Module:
+    """Return activation function.
+
+    Args:
+        activation_type: Activation function type.
+        ftswish_threshold: Threshold value for FTSwish activation formulation.
+        ftswish_mean_shift: Mean shifting value for FTSwish activation formulation.
+        hardtanh_min_val: Minimum value of the linear region range for HardTanh.
+        hardtanh_max_val: Maximum value of the linear region range for HardTanh.
+        leakyrelu_neg_slope: Negative slope value for LeakyReLU activation formulation.
+        smish_alpha: Alpha value for Smish activation fomulation.
+        smish_beta: Beta value for Smish activation formulation.
+        softplus_beta: Beta value for softplus activation formulation in Mish.
+        softplus_threshold: Values above this revert to a linear function in Mish.
+        swish_beta: Beta value for Swish variant formulation.
+
+    Returns:
+        : Activation function.
+
+    """
+    torch_version = V(torch.__version__)
+
+    activations = {
+        "ftswish": (
+            FTSwish,
+            {"threshold": ftswish_threshold, "mean_shift": ftswish_mean_shift},
+        ),
+        "hardtanh": (
+            torch.nn.Hardtanh,
+            {"min_val": hardtanh_min_val, "max_val": hardtanh_max_val},
+        ),
+        "leaky_relu": (torch.nn.LeakyReLU, {"negative_slope": leakyrelu_neg_slope}),
+        "mish": (
+            Mish,
+            {
+                "softplus_beta": softplus_beta,
+                "softplus_threshold": softplus_threshold,
+                "use_builtin": torch_version >= V("1.9"),
+            },
+        ),
+        "relu": (torch.nn.ReLU, {}),
+        "selu": (torch.nn.SELU, {}),
+        "smish": (Smish, {"alpha": smish_alpha, "beta": smish_beta}),
+        "swish": (
+            Swish,
+            {"beta": swish_beta, "use_builtin": torch_version >= V("1.8")},
+        ),
+        "tanh": (torch.nn.Tanh, {}),
+        "identity": (torch.nn.Identity, {}),
+    }
+
+    act_func, act_args = activations[activation_type]
+
+    return act_func(**act_args)
+
+
+class FTSwish(torch.nn.Module):
+    """Flatten-T Swish activation definition.
+
+    FTSwish(x) = x * sigmoid(x) + threshold
+                  where FTSwish(x) < 0 = threshold
+
+    Reference: https://arxiv.org/abs/1812.06247
+
+    Args:
+        threshold: Threshold value for FTSwish activation formulation. (threshold < 0)
+        mean_shift: Mean shifting value for FTSwish activation formulation.
+                       (applied only if != 0, disabled by default)
+
+    """
+
+    def __init__(self, threshold: float = -0.2, mean_shift: float = 0) -> None:
+        super().__init__()
+
+        assert threshold < 0, "FTSwish threshold parameter should be < 0."
+
+        self.threshold = threshold
+        self.mean_shift = mean_shift
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward computation."""
+        x = (x * torch.sigmoid(x)) + self.threshold
+        x = torch.where(x >= 0, x, torch.tensor([self.threshold], device=x.device))
+
+        if self.mean_shift != 0:
+            x.sub_(self.mean_shift)
+
+        return x
+
+
+class Mish(torch.nn.Module):
+    """Mish activation definition.
+
+    Mish(x) = x * tanh(softplus(x))
+
+    Reference: https://arxiv.org/abs/1908.08681.
+
+    Args:
+        softplus_beta: Beta value for softplus activation formulation.
+                         (Usually 0 > softplus_beta >= 2)
+        softplus_threshold: Values above this revert to a linear function.
+                         (Usually 10 > softplus_threshold >= 20)
+        use_builtin: Whether to use PyTorch activation function if available.
+
+    """
+
+    def __init__(
+        self,
+        softplus_beta: float = 1.0,
+        softplus_threshold: int = 20,
+        use_builtin: bool = False,
+    ) -> None:
+        super().__init__()
+
+        if use_builtin:
+            self.mish = torch.nn.Mish()
+        else:
+            self.tanh = torch.nn.Tanh()
+            self.softplus = torch.nn.Softplus(
+                beta=softplus_beta, threshold=softplus_threshold
+            )
+
+            self.mish = lambda x: x * self.tanh(self.softplus(x))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward computation."""
+        return self.mish(x)
+
+
+class Smish(torch.nn.Module):
+    """Smish activation definition.
+
+    Smish(x) = (alpha * x) * tanh(log(1 + sigmoid(beta * x)))
+                 where alpha > 0 and beta > 0
+
+    Reference: https://www.mdpi.com/2079-9292/11/4/540/htm.
+
+    Args:
+        alpha: Alpha value for Smish activation fomulation.
+                 (Usually, alpha = 1. If alpha <= 0, set value to 1).
+        beta: Beta value for Smish activation formulation.
+                (Usually, beta = 1. If beta <= 0, set value to 1).
+
+    """
+
+    def __init__(self, alpha: float = 1.0, beta: float = 1.0) -> None:
+        super().__init__()
+
+        self.tanh = torch.nn.Tanh()
+
+        self.alpha = alpha if alpha > 0 else 1
+        self.beta = beta if beta > 0 else 1
+
+        self.smish = lambda x: (self.alpha * x) * self.tanh(
+            torch.log(1 + torch.sigmoid((self.beta * x)))
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward computation."""
+        return self.smish(x)
+
+
+class Swish(torch.nn.Module):
+    """Swish activation definition.
+
+    Swish(x) = (beta * x) * sigmoid(x)
+                 where beta = 1 defines standard Swish activation.
+
+    References:
+        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
+        E-swish variant: https://arxiv.org/abs/1801.07145.
+
+    Args:
+        beta: Beta parameter for E-Swish.
+                (beta >= 1. If beta < 1, use standard Swish).
+        use_builtin: Whether to use PyTorch function if available.
+
+    """
+
+    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
+        super().__init__()
+
+        self.beta = beta
+
+        if beta > 1:
+            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
+        else:
+            if use_builtin:
+                self.swish = torch.nn.SiLU()
+            else:
+                self.swish = lambda x: x * torch.sigmoid(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward computation."""
+        return self.swish(x)
diff --git a/funasr/models_transducer/beam_search_transducer.py b/funasr/models_transducer/beam_search_transducer.py
new file mode 100644
index 0000000..8e234e4
--- /dev/null
+++ b/funasr/models_transducer/beam_search_transducer.py
@@ -0,0 +1,705 @@
+"""Search algorithms for Transducer models."""
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.models_transducer.joint_network import JointNetwork
+
+
+@dataclass
+class Hypothesis:
+    """Default hypothesis definition for Transducer search algorithms.
+
+    Args:
+        score: Total log-probability.
+        yseq: Label sequence as integer ID sequence.
+        dec_state: RNNDecoder or StatelessDecoder state.
+                     ((N, 1, D_dec), (N, 1, D_dec) or None) or None
+        lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None
+
+    """
+
+    score: float
+    yseq: List[int]
+    dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None
+    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None
+
+
+@dataclass
+class ExtendedHypothesis(Hypothesis):
+    """Extended hypothesis definition for NSC beam search and mAES.
+
+    Args:
+        : Hypothesis dataclass arguments.
+        dec_out: Decoder output sequence. (B, D_dec)
+        lm_score: Log-probabilities of the LM for given label. (vocab_size)
+
+    """
+
+    dec_out: torch.Tensor = None
+    lm_score: torch.Tensor = None
+
+
+class BeamSearchTransducer:
+    """Beam search implementation for Transducer.
+
+    Args:
+        decoder: Decoder module.
+        joint_network: Joint network module.
+        beam_size: Size of the beam.
+        lm: LM class.
+        lm_weight: LM weight for soft fusion.
+        search_type: Search algorithm to use during inference.
+        max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
+        u_max: Maximum expected target sequence length. (ALSD)
+        nstep: Number of maximum expansion steps at each time step. (mAES)
+        expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
+        expansion_beta:
+             Number of additional candidates for expanded hypotheses selection. (mAES)
+        score_norm: Normalize final scores by length.
+        nbest: Number of final hypothesis.
+        streaming: Whether to perform chunk-by-chunk beam search.
+
+    """
+
+    def __init__(
+        self,
+        decoder: AbsDecoder,
+        joint_network: JointNetwork,
+        beam_size: int,
+        lm: Optional[torch.nn.Module] = None,
+        lm_weight: float = 0.1,
+        search_type: str = "default",
+        max_sym_exp: int = 3,
+        u_max: int = 50,
+        nstep: int = 2,
+        expansion_gamma: float = 2.3,
+        expansion_beta: int = 2,
+        score_norm: bool = False,
+        nbest: int = 1,
+        streaming: bool = False,
+    ) -> None:
+        """Construct a BeamSearchTransducer object."""
+        super().__init__()
+
+        self.decoder = decoder
+        self.joint_network = joint_network
+
+        self.vocab_size = decoder.vocab_size
+
+        assert beam_size <= self.vocab_size, (
+            "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
+            % (
+                beam_size,
+                self.vocab_size,
+            )
+        )
+        self.beam_size = beam_size
+
+        if search_type == "default":
+            self.search_algorithm = self.default_beam_search
+        elif search_type == "tsd":
+            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
+                max_sym_exp
+            )
+            self.max_sym_exp = max_sym_exp
+
+            self.search_algorithm = self.time_sync_decoding
+        elif search_type == "alsd":
+            assert not streaming, "ALSD is not available in streaming mode."
+
+            assert u_max >= 0, "u_max should be a positive integer, a portion of max_T."
+            self.u_max = u_max
+
+            self.search_algorithm = self.align_length_sync_decoding
+        elif search_type == "maes":
+            assert self.vocab_size >= beam_size + expansion_beta, (
+                "beam_size (%d) + expansion_beta (%d) "
+                " should be smaller than or equal to vocab size (%d)."
+                % (beam_size, expansion_beta, self.vocab_size)
+            )
+            self.max_candidates = beam_size + expansion_beta
+
+            self.nstep = nstep
+            self.expansion_gamma = expansion_gamma
+
+            self.search_algorithm = self.modified_adaptive_expansion_search
+        else:
+            raise NotImplementedError(
+                "Specified search type (%s) is not supported." % search_type
+            )
+
+        self.use_lm = lm is not None
+
+        if self.use_lm:
+            assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported."
+
+            self.sos = self.vocab_size - 1
+
+            self.lm = lm
+            self.lm_weight = lm_weight
+
+        self.score_norm = score_norm
+        self.nbest = nbest
+
+        self.reset_inference_cache()
+
+    def __call__(
+        self,
+        enc_out: torch.Tensor,
+        is_final: bool = True,
+    ) -> List[Hypothesis]:
+        """Perform beam search.
+
+        Args:
+            enc_out: Encoder output sequence. (T, D_enc)
+            is_final: Whether enc_out is the final chunk of data.
+
+        Returns:
+            nbest_hyps: N-best decoding results
+
+        """
+        self.decoder.set_device(enc_out.device)
+
+        hyps = self.search_algorithm(enc_out)
+
+        if is_final:
+            self.reset_inference_cache()
+
+            return self.sort_nbest(hyps)
+
+        self.search_cache = hyps
+
+        return hyps
+
+    def reset_inference_cache(self) -> None:
+        """Reset cache for decoder scoring and streaming."""
+        self.decoder.score_cache = {}
+        self.search_cache = None
+
+    def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
+        """Sort in-place hypotheses by score or score given sequence length.
+
+        Args:
+            hyps: Hypothesis.
+
+        Return:
+            hyps: Sorted hypothesis.
+
+        """
+        if self.score_norm:
+            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
+        else:
+            hyps.sort(key=lambda x: x.score, reverse=True)
+
+        return hyps[: self.nbest]
+
+    def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
+        """Recombine hypotheses with same label ID sequence.
+
+        Args:
+            hyps: Hypotheses.
+
+        Returns:
+            final: Recombined hypotheses.
+
+        """
+        final = {}
+
+        for hyp in hyps:
+            str_yseq = "_".join(map(str, hyp.yseq))
+
+            if str_yseq in final:
+                final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score)
+            else:
+                final[str_yseq] = hyp
+
+        return [*final.values()]
+
+    def select_k_expansions(
+        self,
+        hyps: List[ExtendedHypothesis],
+        topk_idx: torch.Tensor,
+        topk_logp: torch.Tensor,
+    ) -> List[ExtendedHypothesis]:
+        """Return K hypotheses candidates for expansion from a list of hypothesis.
+
+        K candidates are selected according to the extended hypotheses probabilities
+        and a prune-by-value method. Where K is equal to beam_size + beta.
+
+        Args:
+            hyps: Hypotheses.
+            topk_idx: Indices of candidates hypothesis.
+            topk_logp: Log-probabilities of candidates hypothesis.
+
+        Returns:
+            k_expansions: Best K expansion hypotheses candidates.
+
+        """
+        k_expansions = []
+
+        for i, hyp in enumerate(hyps):
+            hyp_i = [
+                (int(k), hyp.score + float(v))
+                for k, v in zip(topk_idx[i], topk_logp[i])
+            ]
+            k_best_exp = max(hyp_i, key=lambda x: x[1])[1]
+
+            k_expansions.append(
+                sorted(
+                    filter(
+                        lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
+                    ),
+                    key=lambda x: x[1],
+                    reverse=True,
+                )
+            )
+
+        return k_expansions
+
+    def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor:
+        """Make batch of inputs with left padding for LM scoring.
+
+        Args:
+            hyps_seq: Hypothesis sequences.
+
+        Returns:
+            : Padded batch of sequences.
+
+        """
+        max_len = max([len(h) for h in hyps_seq])
+
+        return torch.LongTensor(
+            [[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq],
+            device=self.decoder.device,
+        )
+
+    def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
+        """Beam search implementation without prefix search.
+
+        Modified from https://arxiv.org/pdf/1211.3711.pdf
+
+        Args:
+            enc_out: Encoder output sequence. (T, D)
+
+        Returns:
+            nbest_hyps: N-best hypothesis.
+
+        """
+        beam_k = min(self.beam_size, (self.vocab_size - 1))
+        max_t = len(enc_out)
+
+        if self.search_cache is not None:
+            kept_hyps = self.search_cache
+        else:
+            kept_hyps = [
+                Hypothesis(
+                    score=0.0,
+                    yseq=[0],
+                    dec_state=self.decoder.init_state(1),
+                )
+            ]
+
+        for t in range(max_t):
+            hyps = kept_hyps
+            kept_hyps = []
+
+            while True:
+                max_hyp = max(hyps, key=lambda x: x.score)
+                hyps.remove(max_hyp)
+
+                label = torch.full(
+                    (1, 1),
+                    max_hyp.yseq[-1],
+                    dtype=torch.long,
+                    device=self.decoder.device,
+                )
+                dec_out, state = self.decoder.score(
+                    label,
+                    max_hyp.yseq,
+                    max_hyp.dec_state,
+                )
+
+                logp = torch.log_softmax(
+                    self.joint_network(enc_out[t : t + 1, :], dec_out),
+                    dim=-1,
+                ).squeeze(0)
+                top_k = logp[1:].topk(beam_k, dim=-1)
+
+                kept_hyps.append(
+                    Hypothesis(
+                        score=(max_hyp.score + float(logp[0:1])),
+                        yseq=max_hyp.yseq,
+                        dec_state=max_hyp.dec_state,
+                        lm_state=max_hyp.lm_state,
+                    )
+                )
+
+                if self.use_lm:
+                    lm_scores, lm_state = self.lm.score(
+                        torch.LongTensor(
+                            [self.sos] + max_hyp.yseq[1:], device=self.decoder.device
+                        ),
+                        max_hyp.lm_state,
+                        None,
+                    )
+                else:
+                    lm_state = max_hyp.lm_state
+
+                for logp, k in zip(*top_k):
+                    score = max_hyp.score + float(logp)
+
+                    if self.use_lm:
+                        score += self.lm_weight * lm_scores[k + 1]
+
+                    hyps.append(
+                        Hypothesis(
+                            score=score,
+                            yseq=max_hyp.yseq + [int(k + 1)],
+                            dec_state=state,
+                            lm_state=lm_state,
+                        )
+                    )
+
+                hyps_max = float(max(hyps, key=lambda x: x.score).score)
+                kept_most_prob = sorted(
+                    [hyp for hyp in kept_hyps if hyp.score > hyps_max],
+                    key=lambda x: x.score,
+                )
+                if len(kept_most_prob) >= self.beam_size:
+                    kept_hyps = kept_most_prob
+                    break
+
+        return kept_hyps
+    
+    def align_length_sync_decoding(
+        self,
+        enc_out: torch.Tensor,
+    ) -> List[Hypothesis]:
+        """Alignment-length synchronous beam search implementation.
+
+        Based on https://ieeexplore.ieee.org/document/9053040
+
+        Args:
+            h: Encoder output sequences. (T, D)
+
+        Returns:
+            nbest_hyps: N-best hypothesis.
+
+        """
+        t_max = int(enc_out.size(0))
+        u_max = min(self.u_max, (t_max - 1))
+
+        B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))]
+        final = []
+
+        if self.use_lm:
+            B[0].lm_state = self.lm.zero_state()
+
+        for i in range(t_max + u_max):
+            A = []
+
+            B_ = []
+            B_enc_out = []
+            for hyp in B:
+                u = len(hyp.yseq) - 1
+                t = i - u
+
+                if t > (t_max - 1):
+                    continue
+
+                B_.append(hyp)
+                B_enc_out.append((t, enc_out[t]))
+
+            if B_:
+                beam_enc_out = torch.stack([b[1] for b in B_enc_out])
+                beam_dec_out, beam_state = self.decoder.batch_score(B_)
+
+                beam_logp = torch.log_softmax(
+                    self.joint_network(beam_enc_out, beam_dec_out),
+                    dim=-1,
+                )
+                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)
+
+                if self.use_lm:
+                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
+                        self.create_lm_batch_inputs([b.yseq for b in B_]),
+                        [b.lm_state for b in B_],
+                        None,
+                    )
+
+                for i, hyp in enumerate(B_):
+                    new_hyp = Hypothesis(
+                        score=(hyp.score + float(beam_logp[i, 0])),
+                        yseq=hyp.yseq[:],
+                        dec_state=hyp.dec_state,
+                        lm_state=hyp.lm_state,
+                    )
+
+                    A.append(new_hyp)
+
+                    if B_enc_out[i][0] == (t_max - 1):
+                        final.append(new_hyp)
+
+                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
+                        new_hyp = Hypothesis(
+                            score=(hyp.score + float(logp)),
+                            yseq=(hyp.yseq[:] + [int(k)]),
+                            dec_state=self.decoder.select_state(beam_state, i),
+                            lm_state=hyp.lm_state,
+                        )
+
+                        if self.use_lm:
+                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
+                            new_hyp.lm_state = beam_lm_states[i]
+
+                        A.append(new_hyp)
+
+                B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
+                B = self.recombine_hyps(B)
+
+        if final:
+            return final
+
+        return B
+
+    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
+        """Time synchronous beam search implementation.
+
+        Based on https://ieeexplore.ieee.org/document/9053040
+
+        Args:
+            enc_out: Encoder output sequence. (T, D)
+
+        Returns:
+            nbest_hyps: N-best hypothesis.
+
+        """
+        if self.search_cache is not None:
+            B = self.search_cache
+        else:
+            B = [
+                Hypothesis(
+                    yseq=[0],
+                    score=0.0,
+                    dec_state=self.decoder.init_state(1),
+                )
+            ]
+
+            if self.use_lm:
+                B[0].lm_state = self.lm.zero_state()
+
+        for enc_out_t in enc_out:
+            A = []
+            C = B
+
+            enc_out_t = enc_out_t.unsqueeze(0)
+
+            for v in range(self.max_sym_exp):
+                D = []
+
+                beam_dec_out, beam_state = self.decoder.batch_score(C)
+
+                beam_logp = torch.log_softmax(
+                    self.joint_network(enc_out_t, beam_dec_out),
+                    dim=-1,
+                )
+                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)
+
+                seq_A = [h.yseq for h in A]
+
+                for i, hyp in enumerate(C):
+                    if hyp.yseq not in seq_A:
+                        A.append(
+                            Hypothesis(
+                                score=(hyp.score + float(beam_logp[i, 0])),
+                                yseq=hyp.yseq[:],
+                                dec_state=hyp.dec_state,
+                                lm_state=hyp.lm_state,
+                            )
+                        )
+                    else:
+                        dict_pos = seq_A.index(hyp.yseq)
+
+                        A[dict_pos].score = np.logaddexp(
+                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
+                        )
+
+                if v < (self.max_sym_exp - 1):
+                    if self.use_lm:
+                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
+                            self.create_lm_batch_inputs([c.yseq for c in C]),
+                            [c.lm_state for c in C],
+                            None,
+                        )
+
+                    for i, hyp in enumerate(C):
+                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
+                            new_hyp = Hypothesis(
+                                score=(hyp.score + float(logp)),
+                                yseq=(hyp.yseq + [int(k)]),
+                                dec_state=self.decoder.select_state(beam_state, i),
+                                lm_state=hyp.lm_state,
+                            )
+
+                            if self.use_lm:
+                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
+                                new_hyp.lm_state = beam_lm_states[i]
+
+                            D.append(new_hyp)
+
+                C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size]
+
+            B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
+
+        return B
+
+    def modified_adaptive_expansion_search(
+        self,
+        enc_out: torch.Tensor,
+    ) -> List[ExtendedHypothesis]:
+        """Modified version of Adaptive Expansion Search (mAES).
+
+        Based on AES (https://ieeexplore.ieee.org/document/9250505) and
+                 NSC (https://arxiv.org/abs/2201.05420).
+
+        Args:
+            enc_out: Encoder output sequence. (T, D_enc)
+
+        Returns:
+            nbest_hyps: N-best hypothesis.
+
+        """
+        if self.search_cache is not None:
+            kept_hyps = self.search_cache
+        else:
+            init_tokens = [
+                ExtendedHypothesis(
+                    yseq=[0],
+                    score=0.0,
+                    dec_state=self.decoder.init_state(1),
+                )
+            ]
+
+            beam_dec_out, beam_state = self.decoder.batch_score(
+                init_tokens,
+            )
+
+            if self.use_lm:
+                beam_lm_scores, beam_lm_states = self.lm.batch_score(
+                    self.create_lm_batch_inputs([h.yseq for h in init_tokens]),
+                    [h.lm_state for h in init_tokens],
+                    None,
+                )
+
+                lm_state = beam_lm_states[0]
+                lm_score = beam_lm_scores[0]
+            else:
+                lm_state = None
+                lm_score = None
+
+            kept_hyps = [
+                ExtendedHypothesis(
+                    yseq=[0],
+                    score=0.0,
+                    dec_state=self.decoder.select_state(beam_state, 0),
+                    dec_out=beam_dec_out[0],
+                    lm_state=lm_state,
+                    lm_score=lm_score,
+                )
+            ]
+
+        for enc_out_t in enc_out:
+            hyps = kept_hyps
+            kept_hyps = []
+
+            beam_enc_out = enc_out_t.unsqueeze(0)
+
+            list_b = []
+            for n in range(self.nstep):
+                beam_dec_out = torch.stack([h.dec_out for h in hyps])
+
+                beam_logp, beam_idx = torch.log_softmax(
+                    self.joint_network(beam_enc_out, beam_dec_out),
+                    dim=-1,
+                ).topk(self.max_candidates, dim=-1)
+
+                k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp)
+
+                list_exp = []
+                for i, hyp in enumerate(hyps):
+                    for k, new_score in k_expansions[i]:
+                        new_hyp = ExtendedHypothesis(
+                            yseq=hyp.yseq[:],
+                            score=new_score,
+                            dec_out=hyp.dec_out,
+                            dec_state=hyp.dec_state,
+                            lm_state=hyp.lm_state,
+                            lm_score=hyp.lm_score,
+                        )
+
+                        if k == 0:
+                            list_b.append(new_hyp)
+                        else:
+                            new_hyp.yseq.append(int(k))
+
+                            if self.use_lm:
+                                new_hyp.score += self.lm_weight * float(hyp.lm_score[k])
+
+                            list_exp.append(new_hyp)
+
+                if not list_exp:
+                    kept_hyps = sorted(
+                        self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True
+                    )[: self.beam_size]
+
+                    break
+                else:
+                    beam_dec_out, beam_state = self.decoder.batch_score(
+                        list_exp,
+                    )
+
+                    if self.use_lm:
+                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
+                            self.create_lm_batch_inputs([h.yseq for h in list_exp]),
+                            [h.lm_state for h in list_exp],
+                            None,
+                        )
+
+                    if n < (self.nstep - 1):
+                        for i, hyp in enumerate(list_exp):
+                            hyp.dec_out = beam_dec_out[i]
+                            hyp.dec_state = self.decoder.select_state(beam_state, i)
+
+                            if self.use_lm:
+                                hyp.lm_state = beam_lm_states[i]
+                                hyp.lm_score = beam_lm_scores[i]
+
+                        hyps = list_exp[:]
+                    else:
+                        beam_logp = torch.log_softmax(
+                            self.joint_network(beam_enc_out, beam_dec_out),
+                            dim=-1,
+                        )
+
+                        for i, hyp in enumerate(list_exp):
+                            hyp.score += float(beam_logp[i, 0])
+
+                            hyp.dec_out = beam_dec_out[i]
+                            hyp.dec_state = self.decoder.select_state(beam_state, i)
+
+                            if self.use_lm:
+                                hyp.lm_state = beam_lm_states[i]
+                                hyp.lm_score = beam_lm_scores[i]
+
+                        kept_hyps = sorted(
+                            self.recombine_hyps(list_b + list_exp),
+                            key=lambda x: x.score,
+                            reverse=True,
+                        )[: self.beam_size]
+
+        return kept_hyps
diff --git a/funasr/models_transducer/decoder/__init__.py b/funasr/models_transducer/decoder/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models_transducer/decoder/__init__.py
diff --git a/funasr/models_transducer/decoder/abs_decoder.py b/funasr/models_transducer/decoder/abs_decoder.py
new file mode 100644
index 0000000..5b4a335
--- /dev/null
+++ b/funasr/models_transducer/decoder/abs_decoder.py
@@ -0,0 +1,110 @@
+"""Abstract decoder definition for Transducer models."""
+
+from abc import ABC, abstractmethod
+from typing import Any, List, Optional, Tuple
+
+import torch
+
+
+class AbsDecoder(torch.nn.Module, ABC):
+    """Abstract decoder module."""
+
+    @abstractmethod
+    def forward(self, labels: torch.Tensor) -> torch.Tensor:
+        """Encode source label sequences.
+
+        Args:
+            labels: Label ID sequences. (B, L)
+
+        Returns:
+            dec_out: Decoder output sequences. (B, T, D_dec)
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def score(
+        self,
+        label: torch.Tensor,
+        label_sequence: List[int],
+        dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]],
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
+        """One-step forward hypothesis.
+
+        Args:
+            label: Previous label. (1, 1)
+            label_sequence: Current label sequence.
+            dec_state: Previous decoder hidden states.
+                         ((N, 1, D_dec), (N, 1, D_dec) or None) or None
+
+        Returns:
+            dec_out: Decoder output sequence. (1, D_dec) or (1, D_emb)
+            dec_state: Decoder hidden states.
+                         ((N, 1, D_dec), (N, 1, D_dec) or None) or None
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def batch_score(
+        self,
+        hyps: List[Any],
+    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
+        """One-step forward hypotheses.
+
+        Args:
+            hyps: Hypotheses.
+
+        Returns:
+            dec_out: Decoder output sequences. (B, D_dec) or (B, D_emb)
+            states: Decoder hidden states.
+                      ((N, B, D_dec), (N, B, D_dec) or None) or None
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def set_device(self, device: torch.Tensor) -> None:
+        """Set GPU device to use.
+
+        Args:
+            device: Device ID.
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def init_state(
+        self, batch_size: int
+    ) -> Optional[Tuple[torch.Tensor, Optional[torch.tensor]]]:
+        """Initialize decoder states.
+
+        Args:
+            batch_size: Batch size.
+
+        Returns:
+            : Initial decoder hidden states.
+                ((N, B, D_dec), (N, B, D_dec) or None) or None
+
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def select_state(
+        self,
+        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
+        idx: int = 0,
+    ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+        """Get specified ID state from batch of states, if provided.
+
+        Args:
+            states: Decoder hidden states.
+                      ((N, B, D_dec), (N, B, D_dec) or None) or None
+            idx: State ID to extract.
+
+        Returns:
+            : Decoder hidden state for given ID.
+                ((N, 1, D_dec), (N, 1, D_dec) or None) or None
+
+        """
+        raise NotImplementedError
diff --git a/funasr/models_transducer/decoder/rnn_decoder.py b/funasr/models_transducer/decoder/rnn_decoder.py
new file mode 100644
index 0000000..04c3228
--- /dev/null
+++ b/funasr/models_transducer/decoder/rnn_decoder.py
@@ -0,0 +1,259 @@
+"""RNN decoder definition for Transducer models."""
+
+from typing import List, Optional, Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.models_transducer.beam_search_transducer import Hypothesis
+from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.models.specaug.specaug import SpecAug
+
+class RNNDecoder(AbsDecoder):
+    """RNN decoder module.
+
+    Args:
+        vocab_size: Vocabulary size.
+        embed_size: Embedding size.
+        hidden_size: Hidden size..
+        rnn_type: Decoder layers type.
+        num_layers: Number of decoder layers.
+        dropout_rate: Dropout rate for decoder layers.
+        embed_dropout_rate: Dropout rate for embedding layer.
+        embed_pad: Embedding padding symbol ID.
+
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embed_size: int = 256,
+        hidden_size: int = 256,
+        rnn_type: str = "lstm",
+        num_layers: int = 1,
+        dropout_rate: float = 0.0,
+        embed_dropout_rate: float = 0.0,
+        embed_pad: int = 0,
+    ) -> None:
+        """Construct a RNNDecoder object."""
+        super().__init__()
+
+        assert check_argument_types()
+
+        if rnn_type not in ("lstm", "gru"):
+            raise ValueError(f"Not supported: rnn_type={rnn_type}")
+
+        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
+        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)
+
+        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
+
+        self.rnn = torch.nn.ModuleList(
+            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
+        )
+
+        for _ in range(1, num_layers):
+            self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
+
+        self.dropout_rnn = torch.nn.ModuleList(
+            [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
+        )
+
+        self.dlayers = num_layers
+        self.dtype = rnn_type
+
+        self.output_size = hidden_size
+        self.vocab_size = vocab_size
+
+        self.device = next(self.parameters()).device
+        self.score_cache = {}
+    
+    def forward(
+        self,
+        labels: torch.Tensor,
+        label_lens: torch.Tensor,
+        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
+    ) -> torch.Tensor:
+        """Encode source label sequences.
+
+        Args:
+            labels: Label ID sequences. (B, L)
+            states: Decoder hidden states.
+                      ((N, B, D_dec), (N, B, D_dec) or None) or None
+
+        Returns:
+            dec_out: Decoder output sequences. (B, U, D_dec)
+
+        """
+        if states is None:
+            states = self.init_state(labels.size(0))
+
+        dec_embed = self.dropout_embed(self.embed(labels))
+        dec_out, states = self.rnn_forward(dec_embed, states)
+        return dec_out
+
+    def rnn_forward(
+        self,
+        x: torch.Tensor,
+        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+        """Encode source label sequences.
+
+        Args:
+            x: RNN input sequences. (B, D_emb)
+            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+
+        Returns:
+            x: RNN output sequences. (B, D_dec)
+            (h_next, c_next): Decoder hidden states.
+                                (N, B, D_dec), (N, B, D_dec) or None)
+
+        """
+        h_prev, c_prev = state
+        h_next, c_next = self.init_state(x.size(0))
+
+        for layer in range(self.dlayers):
+            if self.dtype == "lstm":
+                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
+                    layer
+                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
+            else:
+                x, h_next[layer : layer + 1] = self.rnn[layer](
+                    x, hx=h_prev[layer : layer + 1]
+                )
+
+            x = self.dropout_rnn[layer](x)
+
+        return x, (h_next, c_next)
+
+    def score(
+        self,
+        label: torch.Tensor,
+        label_sequence: List[int],
+        dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+        """One-step forward hypothesis.
+
+        Args:
+            label: Previous label. (1, 1)
+            label_sequence: Current label sequence.
+            dec_state: Previous decoder hidden states.
+                         ((N, 1, D_dec), (N, 1, D_dec) or None)
+
+        Returns:
+            dec_out: Decoder output sequence. (1, D_dec)
+            dec_state: Decoder hidden states.
+                         ((N, 1, D_dec), (N, 1, D_dec) or None)
+
+        """
+        str_labels = "_".join(map(str, label_sequence))
+
+        if str_labels in self.score_cache:
+            dec_out, dec_state = self.score_cache[str_labels]
+        else:
+            dec_embed = self.embed(label)
+            dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)
+
+            self.score_cache[str_labels] = (dec_out, dec_state)
+
+        return dec_out[0], dec_state
+
+    def batch_score(
+        self,
+        hyps: List[Hypothesis],
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+        """One-step forward hypotheses.
+
+        Args:
+            hyps: Hypotheses.
+
+        Returns:
+            dec_out: Decoder output sequences. (B, D_dec)
+            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+
+        """
+        labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
+        dec_embed = self.embed(labels)
+
+        states = self.create_batch_states([h.dec_state for h in hyps])
+        dec_out, states = self.rnn_forward(dec_embed, states)
+
+        return dec_out.squeeze(1), states
+
+    def set_device(self, device: torch.device) -> None:
+        """Set GPU device to use.
+
+        Args:
+            device: Device ID.
+
+        """
+        self.device = device
+
+    def init_state(
+        self, batch_size: int
+    ) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
+        """Initialize decoder states.
+
+        Args:
+            batch_size: Batch size.
+
+        Returns:
+            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+
+        """
+        h_n = torch.zeros(
+            self.dlayers,
+            batch_size,
+            self.output_size,
+            device=self.device,
+        )
+
+        if self.dtype == "lstm":
+            c_n = torch.zeros(
+                self.dlayers,
+                batch_size,
+                self.output_size,
+                device=self.device,
+            )
+
+            return (h_n, c_n)
+
+        return (h_n, None)
+
+    def select_state(
+        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Get specified ID state from decoder hidden states.
+
+        Args:
+            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+            idx: State ID to extract.
+
+        Returns:
+            : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
+
+        """
+        return (
+            states[0][:, idx : idx + 1, :],
+            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
+        )
+
+    def create_batch_states(
+        self,
+        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Create decoder hidden states.
+
+        Args:
+            new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
+
+        Returns:
+            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+
+        """
+        return (
+            torch.cat([s[0] for s in new_states], dim=1),
+            torch.cat([s[1] for s in new_states], dim=1)
+            if self.dtype == "lstm"
+            else None,
+        )
diff --git a/funasr/models_transducer/decoder/stateless_decoder.py b/funasr/models_transducer/decoder/stateless_decoder.py
new file mode 100644
index 0000000..07c8f51
--- /dev/null
+++ b/funasr/models_transducer/decoder/stateless_decoder.py
@@ -0,0 +1,157 @@
+"""Stateless decoder definition for Transducer models."""
+
+from typing import List, Optional, Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.models_transducer.beam_search_transducer import Hypothesis
+from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.models.specaug.specaug import SpecAug
+
+class StatelessDecoder(AbsDecoder):
+    """Stateless Transducer decoder module.
+
+    Args:
+        vocab_size: Output size.
+        embed_size: Embedding size.
+        embed_dropout_rate: Dropout rate for embedding layer.
+        embed_pad: Embed/Blank symbol ID.
+
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embed_size: int = 256,
+        embed_dropout_rate: float = 0.0,
+        embed_pad: int = 0,
+        use_embed_mask: bool = False,
+    ) -> None:
+        """Construct a StatelessDecoder object."""
+        super().__init__()
+
+        assert check_argument_types()
+
+        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
+        self.embed_dropout_rate = torch.nn.Dropout(p=embed_dropout_rate)
+
+        self.output_size = embed_size
+        self.vocab_size = vocab_size
+
+        self.device = next(self.parameters()).device
+        self.score_cache = {}
+
+        self.use_embed_mask = use_embed_mask
+        if self.use_embed_mask:
+            self._embed_mask = SpecAug(
+                time_mask_width_range=3,
+                num_time_mask=1,
+                apply_freq_mask=False,
+                apply_time_warp=False
+            )
+
+
+    def forward(
+        self,
+        labels: torch.Tensor,
+        label_lens: torch.Tensor,
+        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
+    ) -> torch.Tensor:
+        """Encode source label sequences.
+
+        Args:
+            labels: Label ID sequences. (B, L)
+            states: Decoder hidden states. None
+
+        Returns:
+            dec_embed: Decoder output sequences. (B, U, D_emb)
+
+        """
+        dec_embed = self.embed_dropout_rate(self.embed(labels))
+        if self.use_embed_mask and self.training:
+            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
+
+        return dec_embed
+
+    def score(
+        self,
+        label: torch.Tensor,
+        label_sequence: List[int],
+        state: None,
+    ) -> Tuple[torch.Tensor, None]:
+        """One-step forward hypothesis.
+
+        Args:
+            label: Previous label. (1, 1)
+            label_sequence: Current label sequence.
+            state: Previous decoder hidden states. None
+
+        Returns:
+            dec_out: Decoder output sequence. (1, D_emb)
+            state: Decoder hidden states. None
+
+        """
+        str_labels = "_".join(map(str, label_sequence))
+
+        if str_labels in self.score_cache:
+            dec_embed = self.score_cache[str_labels]
+        else:
+            dec_embed = self.embed(label)
+
+            self.score_cache[str_labels] = dec_embed
+
+        return dec_embed[0], None
+
+    def batch_score(
+        self,
+        hyps: List[Hypothesis],
+    ) -> Tuple[torch.Tensor, None]:
+        """One-step forward hypotheses.
+
+        Args:
+            hyps: Hypotheses.
+
+        Returns:
+            dec_out: Decoder output sequences. (B, D_dec)
+            states: Decoder hidden states. None
+
+        """
+        labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
+        dec_embed = self.embed(labels)
+
+        return dec_embed.squeeze(1), None
+
+    def set_device(self, device: torch.device) -> None:
+        """Set GPU device to use.
+
+        Args:
+            device: Device ID.
+
+        """
+        self.device = device
+
+    def init_state(self, batch_size: int) -> None:
+        """Initialize decoder states.
+
+        Args:
+            batch_size: Batch size.
+
+        Returns:
+            : Initial decoder hidden states. None
+
+        """
+        return None
+
+    def select_state(self, states: Optional[torch.Tensor], idx: int) -> None:
+        """Get specified ID state from decoder hidden states.
+
+        Args:
+            states: Decoder hidden states. None
+            idx: State ID to extract.
+
+        Returns:
+            : Decoder hidden state for given ID. None
+
+        """
+        return None
diff --git a/funasr/models_transducer/encoder/__init__.py b/funasr/models_transducer/encoder/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models_transducer/encoder/__init__.py
diff --git a/funasr/models_transducer/encoder/blocks/__init__.py b/funasr/models_transducer/encoder/blocks/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models_transducer/encoder/blocks/__init__.py
diff --git a/funasr/models_transducer/encoder/blocks/branchformer.py b/funasr/models_transducer/encoder/blocks/branchformer.py
new file mode 100644
index 0000000..ba0b25d
--- /dev/null
+++ b/funasr/models_transducer/encoder/blocks/branchformer.py
@@ -0,0 +1,178 @@
+"""Branchformer block for Transducer encoder."""
+
+from typing import Dict, Optional, Tuple
+
+import torch
+
+
+class Branchformer(torch.nn.Module):
+    """Branchformer module definition.
+
+    Reference: https://arxiv.org/pdf/2207.02971.pdf
+
+    Args:
+        block_size: Input/output size.
+        linear_size: Linear layers' hidden size.
+        self_att: Self-attention module instance.
+        conv_mod: Convolution module instance.
+        norm_class: Normalization class.
+        norm_args: Normalization module arguments.
+        dropout_rate: Dropout rate.
+
+    """
+
+    def __init__(
+        self,
+        block_size: int,
+        linear_size: int,
+        self_att: torch.nn.Module,
+        conv_mod: torch.nn.Module,
+        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_args: Dict = {},
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Construct a Branchformer object."""
+        super().__init__()
+
+        self.self_att = self_att
+        self.conv_mod = conv_mod
+
+        self.channel_proj1 = torch.nn.Sequential(
+            torch.nn.Linear(block_size, linear_size), torch.nn.GELU()
+        )
+        self.channel_proj2 = torch.nn.Linear(linear_size // 2, block_size)
+
+        self.merge_proj = torch.nn.Linear(block_size + block_size, block_size)
+
+        self.norm_self_att = norm_class(block_size, **norm_args)
+        self.norm_mlp = norm_class(block_size, **norm_args)
+        self.norm_final = norm_class(block_size, **norm_args)
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+        self.block_size = block_size
+        self.linear_size = linear_size
+        self.cache = None
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset self-attention and convolution modules cache for streaming.
+
+        Args:
+            left_context: Number of left frames during chunk-by-chunk inference.
+            device: Device to use for cache tensor.
+
+        """
+        self.cache = [
+            torch.zeros(
+                (1, left_context, self.block_size),
+                device=device,
+            ),
+            torch.zeros(
+                (
+                    1,
+                    self.linear_size // 2,
+                    self.conv_mod.kernel_size - 1,
+                ),
+                device=device,
+            ),
+        ]
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+
+        Args:
+            x: Branchformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T)
+            chunk_mask: Chunk mask. (T_2, T_2)
+
+        Returns:
+            x: Branchformer output sequences. (B, T, D_block)
+            mask: Source mask. (B, T)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+
+        """
+        x1 = x
+        x2 = x
+
+        x1 = self.norm_self_att(x1)
+
+        x1 = self.dropout(
+            self.self_att(x1, x1, x1, pos_enc, mask=mask, chunk_mask=chunk_mask)
+        )
+
+        x2 = self.norm_mlp(x2)
+
+        x2 = self.channel_proj1(x2)
+        x2, _ = self.conv_mod(x2)
+        x2 = self.channel_proj2(x2)
+
+        x2 = self.dropout(x2)
+
+        x = x + self.dropout(self.merge_proj(torch.cat([x1, x2], dim=-1)))
+
+        x = self.norm_final(x)
+
+        return x, mask, pos_enc
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        left_context: int = 0,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode chunk of input sequence.
+
+        Args:
+            x: Branchformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T_2)
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+
+        Returns:
+            x: Branchformer output sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+
+        """
+        x1 = x
+        x2 = x
+
+        x1 = self.norm_self_att(x1)
+
+        if left_context > 0:
+            key = torch.cat([self.cache[0], x1], dim=1)
+        else:
+            key = x1
+        val = key
+
+        if right_context > 0:
+            att_cache = key[:, -(left_context + right_context) : -right_context, :]
+        else:
+            att_cache = key[:, -left_context:, :]
+
+        x1 = self.self_att(x1, key, val, pos_enc, mask=mask, left_context=left_context)
+
+        x2 = self.norm_mlp(x2)
+        x2 = self.channel_proj1(x2)
+
+        x2, conv_cache = self.conv_mod(
+            x2, cache=self.cache[1], right_context=right_context
+        )
+
+        x2 = self.channel_proj2(x2)
+
+        x = x + self.merge_proj(torch.cat([x1, x2], dim=-1))
+
+        x = self.norm_final(x)
+        self.cache = [att_cache, conv_cache]
+
+        return x, pos_enc
diff --git a/funasr/models_transducer/encoder/blocks/conformer.py b/funasr/models_transducer/encoder/blocks/conformer.py
new file mode 100644
index 0000000..0b9bbbf
--- /dev/null
+++ b/funasr/models_transducer/encoder/blocks/conformer.py
@@ -0,0 +1,198 @@
+"""Conformer block for Transducer encoder."""
+
+from typing import Dict, Optional, Tuple
+
+import torch
+
+
+class Conformer(torch.nn.Module):
+    """Conformer module definition.
+
+    Args:
+        block_size: Input/output size.
+        self_att: Self-attention module instance.
+        feed_forward: Feed-forward module instance.
+        feed_forward_macaron: Feed-forward module instance for macaron network.
+        conv_mod: Convolution module instance.
+        norm_class: Normalization module class.
+        norm_args: Normalization module arguments.
+        dropout_rate: Dropout rate.
+
+    """
+
+    def __init__(
+        self,
+        block_size: int,
+        self_att: torch.nn.Module,
+        feed_forward: torch.nn.Module,
+        feed_forward_macaron: torch.nn.Module,
+        conv_mod: torch.nn.Module,
+        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_args: Dict = {},
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Construct a Conformer object."""
+        super().__init__()
+
+        self.self_att = self_att
+
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.feed_forward_scale = 0.5
+
+        self.conv_mod = conv_mod
+
+        self.norm_feed_forward = norm_class(block_size, **norm_args)
+        self.norm_self_att = norm_class(block_size, **norm_args)
+
+        self.norm_macaron = norm_class(block_size, **norm_args)
+        self.norm_conv = norm_class(block_size, **norm_args)
+        self.norm_final = norm_class(block_size, **norm_args)
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+        self.block_size = block_size
+        self.cache = None
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset self-attention and convolution modules cache for streaming.
+
+        Args:
+            left_context: Number of left frames during chunk-by-chunk inference.
+            device: Device to use for cache tensor.
+
+        """
+        self.cache = [
+            torch.zeros(
+                (1, left_context, self.block_size),
+                device=device,
+            ),
+            torch.zeros(
+                (
+                    1,
+                    self.block_size,
+                    self.conv_mod.kernel_size - 1,
+                ),
+                device=device,
+            ),
+        ]
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+
+        Args:
+            x: Conformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T)
+            chunk_mask: Chunk mask. (T_2, T_2)
+
+        Returns:
+            x: Conformer output sequences. (B, T, D_block)
+            mask: Source mask. (B, T)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+
+        """
+        residual = x
+
+        x = self.norm_macaron(x)
+        x = residual + self.feed_forward_scale * self.dropout(
+            self.feed_forward_macaron(x)
+        )
+
+        residual = x
+        x = self.norm_self_att(x)
+        x_q = x
+        x = residual + self.dropout(
+            self.self_att(
+                x_q,
+                x,
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=chunk_mask,
+            )
+        )
+
+        residual = x
+
+        x = self.norm_conv(x)
+        x, _ = self.conv_mod(x)
+        x = residual + self.dropout(x)
+        residual = x
+
+        x = self.norm_feed_forward(x)
+        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
+
+        x = self.norm_final(x)
+        return x, mask, pos_enc
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 0,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode chunk of input sequence.
+
+        Args:
+            x: Conformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T_2)
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+
+        Returns:
+            x: Conformer output sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+
+        """
+        residual = x
+
+        x = self.norm_macaron(x)
+        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
+
+        residual = x
+        x = self.norm_self_att(x)
+        if left_context > 0:
+            key = torch.cat([self.cache[0], x], dim=1)
+        else:
+            key = x
+        val = key
+
+        if right_context > 0:
+            att_cache = key[:, -(left_context + right_context) : -right_context, :]
+        else:
+            att_cache = key[:, -left_context:, :]
+        x = residual + self.self_att(
+            x,
+            key,
+            val,
+            pos_enc,
+            mask,
+            left_context=left_context,
+        )
+
+        residual = x
+        x = self.norm_conv(x)
+        x, conv_cache = self.conv_mod(
+            x, cache=self.cache[1], right_context=right_context
+        )
+        x = residual + x
+        residual = x
+        
+        x = self.norm_feed_forward(x)
+        x = residual + self.feed_forward_scale * self.feed_forward(x)
+
+        x = self.norm_final(x)
+        self.cache = [att_cache, conv_cache]
+       
+        return x, pos_enc
diff --git a/funasr/models_transducer/encoder/blocks/conv1d.py b/funasr/models_transducer/encoder/blocks/conv1d.py
new file mode 100644
index 0000000..f79cc37
--- /dev/null
+++ b/funasr/models_transducer/encoder/blocks/conv1d.py
@@ -0,0 +1,221 @@
+"""Conv1d block for Transducer encoder."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+
+class Conv1d(torch.nn.Module):
+    """Conv1d module definition.
+
+    Args:
+        input_size: Input dimension.
+        output_size: Output dimension.
+        kernel_size: Size of the convolving kernel.
+        stride: Stride of the convolution.
+        dilation: Spacing between the kernel points.
+        groups: Number of blocked connections from input channels to output channels.
+        bias: Whether to add a learnable bias to the output.
+        batch_norm: Whether to use batch normalization after convolution.
+        relu: Whether to use a ReLU activation after convolution.
+        causal: Whether to use causal convolution (set to True if streaming).
+        dropout_rate: Dropout rate.
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        kernel_size: Union[int, Tuple],
+        stride: Union[int, Tuple] = 1,
+        dilation: Union[int, Tuple] = 1,
+        groups: Union[int, Tuple] = 1,
+        bias: bool = True,
+        batch_norm: bool = False,
+        relu: bool = True,
+        causal: bool = False,
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Construct a Conv1d object."""
+        super().__init__()
+
+        if causal:
+            self.lorder = kernel_size - 1
+            stride = 1
+        else:
+            self.lorder = 0
+            stride = stride
+
+        self.conv = torch.nn.Conv1d(
+            input_size,
+            output_size,
+            kernel_size,
+            stride=stride,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+        )
+
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+        if relu:
+            self.relu_func = torch.nn.ReLU()
+
+        if batch_norm:
+            self.bn = torch.nn.BatchNorm1d(output_size)
+
+        self.out_pos = torch.nn.Linear(input_size, output_size)
+
+        self.input_size = input_size
+        self.output_size = output_size
+
+        self.relu = relu
+        self.batch_norm = batch_norm
+        self.causal = causal
+
+        self.kernel_size = kernel_size
+        self.padding = dilation * (kernel_size - 1)
+        self.stride = stride
+
+        self.cache = None
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset Conv1d cache for streaming.
+
+        Args:
+            left_context: Number of left frames during chunk-by-chunk inference.
+            device: Device to use for cache tensor.
+
+        """
+        self.cache = torch.zeros(
+            (1, self.input_size, self.kernel_size - 1), device=device
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+
+        Args:
+            x: Conv1d input sequences. (B, T, D_in)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in)
+            mask: Source mask. (B, T)
+            chunk_mask: Chunk mask. (T_2, T_2)
+
+        Returns:
+            x: Conv1d output sequences. (B, sub(T), D_out)
+            mask: Source mask. (B, T) or (B, sub(T))
+            pos_enc: Positional embedding sequences.
+                       (B, 2 * (T - 1), D_att) or (B, 2 * (sub(T) - 1), D_out)
+
+        """
+        x = x.transpose(1, 2)
+
+        if self.lorder > 0:
+            x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+        else:
+            mask = self.create_new_mask(mask)
+            pos_enc = self.create_new_pos_enc(pos_enc)
+
+        x = self.conv(x)
+
+        if self.batch_norm:
+            x = self.bn(x)
+
+        x = self.dropout(x)
+
+        if self.relu:
+            x = self.relu_func(x)
+
+        x = x.transpose(1, 2)
+
+        return x, mask, self.out_pos(pos_enc)
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        left_context: int = 0,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode chunk of input sequence.
+
+        Args:
+            x: Conv1d input sequences. (B, T, D_in)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_in)
+            mask: Source mask. (B, T)
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+
+        Returns:
+            x: Conv1d output sequences. (B, T, D_out)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_out)
+
+        """
+        x = torch.cat([self.cache, x.transpose(1, 2)], dim=2)
+
+        if right_context > 0:
+            self.cache = x[:, :, -(self.lorder + right_context) : -right_context]
+        else:
+            self.cache = x[:, :, -self.lorder :]
+
+        x = self.conv(x)
+
+        if self.batch_norm:
+            x = self.bn(x)
+
+        x = self.dropout(x)
+
+        if self.relu:
+            x = self.relu_func(x)
+
+        x = x.transpose(1, 2)
+
+        return x, self.out_pos(pos_enc)
+
+    def create_new_mask(self, mask: torch.Tensor) -> torch.Tensor:
+        """Create new mask for output sequences.
+
+        Args:
+            mask: Mask of input sequences. (B, T)
+
+        Returns:
+            mask: Mask of output sequences. (B, sub(T))
+
+        """
+        if self.padding != 0:
+            mask = mask[:, : -self.padding]
+
+        return mask[:, :: self.stride]
+
+    def create_new_pos_enc(self, pos_enc: torch.Tensor) -> torch.Tensor:
+        """Create new positional embedding vector.
+
+        Args:
+            pos_enc: Input sequences positional embedding.
+                     (B, 2 * (T - 1), D_in)
+
+        Returns:
+            pos_enc: Output sequences positional embedding.
+                     (B, 2 * (sub(T) - 1), D_in)
+
+        """
+        pos_enc_positive = pos_enc[:, : pos_enc.size(1) // 2 + 1, :]
+        pos_enc_negative = pos_enc[:, pos_enc.size(1) // 2 :, :]
+
+        if self.padding != 0:
+            pos_enc_positive = pos_enc_positive[:, : -self.padding, :]
+            pos_enc_negative = pos_enc_negative[:, : -self.padding, :]
+
+        pos_enc_positive = pos_enc_positive[:, :: self.stride, :]
+        pos_enc_negative = pos_enc_negative[:, :: self.stride, :]
+
+        pos_enc = torch.cat([pos_enc_positive, pos_enc_negative[:, 1:, :]], dim=1)
+
+        return pos_enc
diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models_transducer/encoder/blocks/conv_input.py
new file mode 100644
index 0000000..931d0f0
--- /dev/null
+++ b/funasr/models_transducer/encoder/blocks/conv_input.py
@@ -0,0 +1,226 @@
+"""ConvInput block for Transducer encoder."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+import math
+
+from funasr.models_transducer.utils import sub_factor_to_params, pad_to_len
+
+
+class ConvInput(torch.nn.Module):
+    """ConvInput module definition.
+
+    Args:
+        input_size: Input size.
+        conv_size: Convolution size.
+        subsampling_factor: Subsampling factor.
+        vgg_like: Whether to use a VGG-like network.
+        output_size: Block output dimension.
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        conv_size: Union[int, Tuple],
+        subsampling_factor: int = 4,
+        vgg_like: bool = True,
+        output_size: Optional[int] = None,
+    ) -> None:
+        """Construct a ConvInput object."""
+        super().__init__()
+        if vgg_like:
+            if subsampling_factor == 1:
+                conv_size1, conv_size2 = conv_size
+
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((1, 2)),
+                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((1, 2)),
+                )
+
+                output_proj = conv_size2 * ((input_size // 2) // 2)
+ 
+                self.subsampling_factor = 1
+
+                self.stride_1 = 1
+
+                self.create_new_mask = self.create_new_vgg_mask
+
+            else:
+                conv_size1, conv_size2 = conv_size
+
+                kernel_1 = int(subsampling_factor / 2)
+
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((kernel_1, 2)),
+                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((2, 2)),
+                )
+
+                output_proj = conv_size2 * ((input_size // 2) // 2)
+
+                self.subsampling_factor = subsampling_factor
+
+                self.create_new_mask = self.create_new_vgg_mask
+                
+                self.stride_1 = kernel_1
+
+        else:
+            if subsampling_factor == 1:
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+                    torch.nn.ReLU(),
+                )
+
+                output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
+
+                self.subsampling_factor = subsampling_factor
+                self.kernel_2 = 3
+                self.stride_2 = 1
+
+                self.create_new_mask = self.create_new_conv2d_mask
+
+            else:
+                kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
+                    subsampling_factor,
+                    input_size,
+                )
+
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size, 3, 2),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
+                    torch.nn.ReLU(),
+                )
+
+                output_proj = conv_size * conv_2_output_size
+
+                self.subsampling_factor = subsampling_factor
+                self.kernel_2 = kernel_2
+                self.stride_2 = stride_2
+
+                self.create_new_mask = self.create_new_conv2d_mask
+
+        self.vgg_like = vgg_like
+        self.min_frame_length = 2
+
+        if output_size is not None:
+            self.output = torch.nn.Linear(output_proj, output_size)
+            self.output_size = output_size
+        else:
+            self.output = None
+            self.output_size = output_proj
+
+    def forward(
+        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+
+        Args:
+            x: ConvInput input sequences. (B, T, D_feats)
+            mask: Mask of input sequences. (B, 1, T)
+
+        Returns:
+            x: ConvInput output sequences. (B, sub(T), D_out)
+            mask: Mask of output sequences. (B, 1, sub(T))
+
+        """
+        if mask is not None:
+            mask = self.create_new_mask(mask)
+            olens = max(mask.eq(0).sum(1))
+        
+        b, t_input, f = x.size()
+        x = x.unsqueeze(1) # (b. 1. t. f)
+        if chunk_size is not None:
+            max_input_length = int(
+                chunk_size * self.subsampling_factor * (math.ceil(float(t_input) / (chunk_size * self.subsampling_factor) ))
+            )
+            x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
+            x = list(x)
+            x = torch.stack(x, dim=0)
+            N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
+            x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
+        x = self.conv(x)
+
+        _, c, t, f = x.size()
+        
+        if chunk_size is not None:
+            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
+        else:
+            x = x.transpose(1, 2).contiguous().view(b, t, c * f)
+
+        if self.output is not None:
+            x = self.output(x)
+        
+        return x, mask[:,:olens][:,:x.size(1)]
+
+    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
+        """Create a new mask for VGG output sequences.
+
+        Args:
+            mask: Mask of input sequences. (B, T)
+
+        Returns:
+            mask: Mask of output sequences. (B, sub(T))
+
+        """
+        if self.subsampling_factor > 1:
+            vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 ))
+            mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]
+
+            vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
+            mask = mask[:, :vgg2_t_len][:, ::2]
+        else:
+            mask = mask
+
+        return mask
+
+    def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
+        """Create new conformer mask for Conv2d output sequences.
+
+        Args:
+            mask: Mask of input sequences. (B, T)
+
+        Returns:
+            mask: Mask of output sequences. (B, sub(T))
+
+        """
+        if self.subsampling_factor > 1:
+            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
+        else:
+            return mask
+
+    def get_size_before_subsampling(self, size: int) -> int:
+        """Return the original size before subsampling for a given size.
+
+        Args:
+            size: Number of frames after subsampling.
+
+        Returns:
+            : Number of frames before subsampling.
+
+        """
+        if self.subsampling_factor > 1:
+            if self.vgg_like:
+                return ((size * 2) * self.stride_1) + 1
+
+            return ((size + 2) * 2) + (self.kernel_2 - 1) * self.stride_2
+        return size
diff --git a/funasr/models_transducer/encoder/blocks/linear_input.py b/funasr/models_transducer/encoder/blocks/linear_input.py
new file mode 100644
index 0000000..9bb9698
--- /dev/null
+++ b/funasr/models_transducer/encoder/blocks/linear_input.py
@@ -0,0 +1,52 @@
+"""LinearInput block for Transducer encoder."""
+
+from typing import Optional, Tuple, Union
+
+import torch
+
+class LinearInput(torch.nn.Module):
+    """ConvInput module definition.
+
+    Args:
+        input_size: Input size.
+        conv_size: Convolution size.
+        subsampling_factor: Subsampling factor.
+        vgg_like: Whether to use a VGG-like network.
+        output_size: Block output dimension.
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: Optional[int] = None,
+        subsampling_factor: int = 1,
+    ) -> None:
+        """Construct a ConvInput object."""
+        super().__init__()
+        self.embed = torch.nn.Sequential(
+            torch.nn.Linear(input_size, output_size),
+            torch.nn.LayerNorm(output_size),
+            torch.nn.Dropout(0.1),
+        )
+        self.subsampling_factor = subsampling_factor
+        self.min_frame_length = 1
+
+    def forward(
+        self, x: torch.Tensor, mask: Optional[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        
+        x = self.embed(x)
+        return x, mask
+
+    def get_size_before_subsampling(self, size: int) -> int:
+        """Return the original size before subsampling for a given size.
+
+        Args:
+            size: Number of frames after subsampling.
+
+        Returns:
+            : Number of frames before subsampling.
+
+        """
+        return size
diff --git a/funasr/models_transducer/encoder/building.py b/funasr/models_transducer/encoder/building.py
new file mode 100644
index 0000000..a19943b
--- /dev/null
+++ b/funasr/models_transducer/encoder/building.py
@@ -0,0 +1,352 @@
+"""Set of methods to build Transducer encoder architecture."""
+
+from typing import Any, Dict, List, Optional, Union
+
+from funasr.models_transducer.activation import get_activation
+from funasr.models_transducer.encoder.blocks.branchformer import Branchformer
+from funasr.models_transducer.encoder.blocks.conformer import Conformer
+from funasr.models_transducer.encoder.blocks.conv1d import Conv1d
+from funasr.models_transducer.encoder.blocks.conv_input import ConvInput
+from funasr.models_transducer.encoder.blocks.linear_input import LinearInput
+from funasr.models_transducer.encoder.modules.attention import (  # noqa: H301
+    RelPositionMultiHeadedAttention,
+)
+from funasr.models_transducer.encoder.modules.convolution import (  # noqa: H301
+    ConformerConvolution,
+    ConvolutionalSpatialGatingUnit,
+)
+from funasr.models_transducer.encoder.modules.multi_blocks import MultiBlocks
+from funasr.models_transducer.encoder.modules.normalization import get_normalization
+from funasr.models_transducer.encoder.modules.positional_encoding import (  # noqa: H301
+    RelPositionalEncoding,
+)
+from funasr.modules.positionwise_feed_forward import (
+    PositionwiseFeedForward,
+)
+
+
+def build_main_parameters(
+    pos_wise_act_type: str = "swish",
+    conv_mod_act_type: str = "swish",
+    pos_enc_dropout_rate: float = 0.0,
+    pos_enc_max_len: int = 5000,
+    simplified_att_score: bool = False,
+    norm_type: str = "layer_norm",
+    conv_mod_norm_type: str = "layer_norm",
+    after_norm_eps: Optional[float] = None,
+    after_norm_partial: Optional[float] = None,
+    dynamic_chunk_training: bool = False,
+    short_chunk_threshold: float = 0.75,
+    short_chunk_size: int = 25,
+    left_chunk_size: int = 0,
+    time_reduction_factor: int = 1,
+    unified_model_training: bool = False,
+    default_chunk_size: int = 16,
+    jitter_range: int =4,
+    **activation_parameters,
+) -> Dict[str, Any]:
+    """Build encoder main parameters.
+
+    Args:
+        pos_wise_act_type: Conformer position-wise feed-forward activation type.
+        conv_mod_act_type: Conformer convolution module activation type.
+        pos_enc_dropout_rate: Positional encoding dropout rate.
+        pos_enc_max_len: Positional encoding maximum length.
+        simplified_att_score: Whether to use simplified attention score computation.
+        norm_type: X-former normalization module type.
+        conv_mod_norm_type: Conformer convolution module normalization type.
+        after_norm_eps: Epsilon value for the final normalization.
+        after_norm_partial: Value for the final normalization with RMSNorm.
+        dynamic_chunk_training: Whether to use dynamic chunk training.
+        short_chunk_threshold: Threshold for dynamic chunk selection.
+        short_chunk_size: Minimum number of frames during dynamic chunk training.
+        left_chunk_size: Number of frames in left context.
+        **activations_parameters: Parameters of the activation functions.
+                                    (See espnet2/asr_transducer/activation.py)
+
+    Returns:
+        : Main encoder parameters
+
+    """
+    main_params = {}
+
+    main_params["pos_wise_act"] = get_activation(
+        pos_wise_act_type, **activation_parameters
+    )
+
+    main_params["conv_mod_act"] = get_activation(
+        conv_mod_act_type, **activation_parameters
+    )
+
+    main_params["pos_enc_dropout_rate"] = pos_enc_dropout_rate
+    main_params["pos_enc_max_len"] = pos_enc_max_len
+
+    main_params["simplified_att_score"] = simplified_att_score
+
+    main_params["norm_type"] = norm_type
+    main_params["conv_mod_norm_type"] = conv_mod_norm_type
+
+    (
+        main_params["after_norm_class"],
+        main_params["after_norm_args"],
+    ) = get_normalization(norm_type, eps=after_norm_eps, partial=after_norm_partial)
+
+    main_params["dynamic_chunk_training"] = dynamic_chunk_training
+    main_params["short_chunk_threshold"] = max(0, short_chunk_threshold)
+    main_params["short_chunk_size"] = max(0, short_chunk_size)
+    main_params["left_chunk_size"] = max(0, left_chunk_size)
+    
+    main_params["unified_model_training"] = unified_model_training
+    main_params["default_chunk_size"] = max(0, default_chunk_size)
+    main_params["jitter_range"] = max(0, jitter_range)
+   
+    main_params["time_reduction_factor"] = time_reduction_factor
+
+    return main_params
+
+
+def build_positional_encoding(
+    block_size: int, configuration: Dict[str, Any]
+) -> RelPositionalEncoding:
+    """Build positional encoding block.
+
+    Args:
+        block_size: Input/output size.
+        configuration: Positional encoding configuration.
+
+    Returns:
+        : Positional encoding module.
+
+    """
+    return RelPositionalEncoding(
+        block_size,
+        configuration.get("pos_enc_dropout_rate", 0.0),
+        max_len=configuration.get("pos_enc_max_len", 5000),
+    )
+
+
+def build_input_block(
+    input_size: int,
+    configuration: Dict[str, Union[str, int]],
+) -> ConvInput:
+    """Build encoder input block.
+
+    Args:
+        input_size: Input size.
+        configuration: Input block configuration.
+
+    Returns:
+        : ConvInput block function.
+
+    """
+    if configuration["linear"]:
+        return LinearInput(
+            input_size,
+            configuration["output_size"],
+            configuration["subsampling_factor"],
+        )
+    else:
+        return ConvInput(
+            input_size,
+            configuration["conv_size"],
+            configuration["subsampling_factor"],
+            vgg_like=configuration["vgg_like"],
+            output_size=configuration["output_size"],
+        )
+
+
+def build_branchformer_block(
+    configuration: List[Dict[str, Any]],
+    main_params: Dict[str, Any],
+) -> Conformer:
+    """Build Branchformer block.
+
+    Args:
+        configuration: Branchformer block configuration.
+        main_params: Encoder main parameters.
+
+    Returns:
+        : Branchformer block function.
+
+    """
+    hidden_size = configuration["hidden_size"]
+    linear_size = configuration["linear_size"]
+
+    dropout_rate = configuration.get("dropout_rate", 0.0)
+
+    conv_mod_norm_class, conv_mod_norm_args = get_normalization(
+        main_params["conv_mod_norm_type"],
+        eps=configuration.get("conv_mod_norm_eps"),
+        partial=configuration.get("conv_mod_norm_partial"),
+    )
+
+    conv_mod_args = (
+        linear_size,
+        configuration["conv_mod_kernel_size"],
+        conv_mod_norm_class,
+        conv_mod_norm_args,
+        dropout_rate,
+        main_params["dynamic_chunk_training"],
+    )
+
+    mult_att_args = (
+        configuration.get("heads", 4),
+        hidden_size,
+        configuration.get("att_dropout_rate", 0.0),
+        main_params["simplified_att_score"],
+    )
+
+    norm_class, norm_args = get_normalization(
+        main_params["norm_type"],
+        eps=configuration.get("norm_eps"),
+        partial=configuration.get("norm_partial"),
+    )
+
+    return lambda: Branchformer(
+        hidden_size,
+        linear_size,
+        RelPositionMultiHeadedAttention(*mult_att_args),
+        ConvolutionalSpatialGatingUnit(*conv_mod_args),
+        norm_class=norm_class,
+        norm_args=norm_args,
+        dropout_rate=dropout_rate,
+    )
+
+
+def build_conformer_block(
+    configuration: List[Dict[str, Any]],
+    main_params: Dict[str, Any],
+) -> Conformer:
+    """Build Conformer block.
+
+    Args:
+        configuration: Conformer block configuration.
+        main_params: Encoder main parameters.
+
+    Returns:
+        : Conformer block function.
+
+    """
+    hidden_size = configuration["hidden_size"]
+    linear_size = configuration["linear_size"]
+
+    pos_wise_args = (
+        hidden_size,
+        linear_size,
+        configuration.get("pos_wise_dropout_rate", 0.0),
+        main_params["pos_wise_act"],
+    )
+
+    conv_mod_norm_args = {
+        "eps": configuration.get("conv_mod_norm_eps", 1e-05),
+        "momentum": configuration.get("conv_mod_norm_momentum", 0.1),
+    }
+
+    conv_mod_args = (
+        hidden_size,
+        configuration["conv_mod_kernel_size"],
+        main_params["conv_mod_act"],
+        conv_mod_norm_args,
+        main_params["dynamic_chunk_training"] or main_params["unified_model_training"],
+    )
+
+    mult_att_args = (
+        configuration.get("heads", 4),
+        hidden_size,
+        configuration.get("att_dropout_rate", 0.0),
+        main_params["simplified_att_score"],
+    )
+
+    norm_class, norm_args = get_normalization(
+        main_params["norm_type"],
+        eps=configuration.get("norm_eps"),
+        partial=configuration.get("norm_partial"),
+    )
+
+    return lambda: Conformer(
+        hidden_size,
+        RelPositionMultiHeadedAttention(*mult_att_args),
+        PositionwiseFeedForward(*pos_wise_args),
+        PositionwiseFeedForward(*pos_wise_args),
+        ConformerConvolution(*conv_mod_args),
+        norm_class=norm_class,
+        norm_args=norm_args,
+        dropout_rate=configuration.get("dropout_rate", 0.0),
+    )
+
+
+def build_conv1d_block(
+    configuration: List[Dict[str, Any]],
+    causal: bool,
+) -> Conv1d:
+    """Build Conv1d block.
+
+    Args:
+        configuration: Conv1d block configuration.
+
+    Returns:
+        : Conv1d block function.
+
+    """
+    return lambda: Conv1d(
+        configuration["input_size"],
+        configuration["output_size"],
+        configuration["kernel_size"],
+        stride=configuration.get("stride", 1),
+        dilation=configuration.get("dilation", 1),
+        groups=configuration.get("groups", 1),
+        bias=configuration.get("bias", True),
+        relu=configuration.get("relu", True),
+        batch_norm=configuration.get("batch_norm", False),
+        causal=causal,
+        dropout_rate=configuration.get("dropout_rate", 0.0),
+    )
+
+
+def build_body_blocks(
+    configuration: List[Dict[str, Any]],
+    main_params: Dict[str, Any],
+    output_size: int,
+) -> MultiBlocks:
+    """Build encoder body blocks.
+
+    Args:
+        configuration: Body blocks configuration.
+        main_params: Encoder main parameters.
+        output_size: Architecture output size.
+
+    Returns:
+        MultiBlocks function encapsulation all encoder blocks.
+
+    """
+    fn_modules = []
+    extended_conf = []
+
+    for c in configuration:
+        if c.get("num_blocks") is not None:
+            extended_conf += c["num_blocks"] * [
+                {c_i: c[c_i] for c_i in c if c_i != "num_blocks"}
+            ]
+        else:
+            extended_conf += [c]
+
+    for i, c in enumerate(extended_conf):
+        block_type = c["block_type"]
+
+        if block_type == "branchformer":
+            module = build_branchformer_block(c, main_params)
+        elif block_type == "conformer":
+            module = build_conformer_block(c, main_params)
+        elif block_type == "conv1d":
+            module = build_conv1d_block(c, main_params["dynamic_chunk_training"])
+        else:
+            raise NotImplementedError
+
+        fn_modules.append(module)
+
+    return MultiBlocks(
+        [fn() for fn in fn_modules],
+        output_size,
+        norm_class=main_params["after_norm_class"],
+        norm_args=main_params["after_norm_args"],
+    )
diff --git a/funasr/models_transducer/encoder/encoder.py b/funasr/models_transducer/encoder/encoder.py
new file mode 100644
index 0000000..45c99c1
--- /dev/null
+++ b/funasr/models_transducer/encoder/encoder.py
@@ -0,0 +1,294 @@
+"""Encoder for Transducer model."""
+
+from typing import Any, Dict, List, Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.models_transducer.encoder.building import (
+    build_body_blocks,
+    build_input_block,
+    build_main_parameters,
+    build_positional_encoding,
+)
+from funasr.models_transducer.encoder.validation import validate_architecture
+from funasr.models_transducer.utils import (
+    TooShortUttError,
+    check_short_utt,
+    make_chunk_mask,
+    make_source_mask,
+)
+
+
+class Encoder(torch.nn.Module):
+    """Encoder module definition.
+
+    Args:
+        input_size: Input size.
+        body_conf: Encoder body configuration.
+        input_conf: Encoder input configuration.
+        main_conf: Encoder main configuration.
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        body_conf: List[Dict[str, Any]],
+        input_conf: Dict[str, Any] = {},
+        main_conf: Dict[str, Any] = {},
+    ) -> None:
+        """Construct an Encoder object."""
+        super().__init__()
+
+        assert check_argument_types()
+
+        embed_size, output_size = validate_architecture(
+            input_conf, body_conf, input_size
+        )
+        main_params = build_main_parameters(**main_conf)
+
+        self.embed = build_input_block(input_size, input_conf)
+        self.pos_enc = build_positional_encoding(embed_size, main_params)
+        self.encoders = build_body_blocks(body_conf, main_params, output_size)
+
+        self.output_size = output_size
+
+        self.dynamic_chunk_training = main_params["dynamic_chunk_training"]
+        self.short_chunk_threshold = main_params["short_chunk_threshold"]
+        self.short_chunk_size = main_params["short_chunk_size"]
+        self.left_chunk_size = main_params["left_chunk_size"]
+
+        self.unified_model_training = main_params["unified_model_training"]
+        self.default_chunk_size = main_params["default_chunk_size"]
+        self.jitter_range = main_params["jitter_range"]       
+
+        self.time_reduction_factor = main_params["time_reduction_factor"] 
+
+    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
+        """Return the corresponding number of sample for a given chunk size, in frames.
+
+        Where size is the number of features frames after applying subsampling.
+
+        Args:
+            size: Number of frames after subsampling.
+            hop_length: Frontend's hop length
+
+        Returns:
+            : Number of raw samples
+
+        """
+        return self.embed.get_size_before_subsampling(size) * hop_length
+    
+    def get_encoder_input_size(self, size: int) -> int:
+        """Return the corresponding number of sample for a given chunk size, in frames.
+
+        Where size is the number of features frames after applying subsampling.
+
+        Args:
+            size: Number of frames after subsampling.
+
+        Returns:
+            : Number of raw samples
+
+        """
+        return self.embed.get_size_before_subsampling(size)
+
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset encoder streaming cache.
+
+        Args:
+            left_context: Number of frames in left context.
+            device: Device ID.
+
+        """
+        return self.encoders.reset_streaming_cache(left_context, device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+           x_len: Encoder outputs lenghts. (B,)
+
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len)
+        if self.unified_model_training:
+            x, mask = self.embed(x, mask, self.default_chunk_size)
+        else:
+            x, mask = self.embed(x, mask)
+        pos_enc = self.pos_enc(x)
+
+        if self.unified_model_training:
+            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            chunk_mask = make_chunk_mask(
+                x.size(1),
+                chunk_size,
+                left_chunk_size=self.left_chunk_size,
+                device=x.device,
+            )
+            x_utt = self.encoders(
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=None,
+            )
+            x_chunk = self.encoders(
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=chunk_mask,
+            )
+       
+            olens = mask.eq(0).sum(1)
+            if self.time_reduction_factor > 1:
+                x_utt = x_utt[:,::self.time_reduction_factor,:]
+                x_chunk = x_chunk[:,::self.time_reduction_factor,:]
+                olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+            return x_utt, x_chunk, olens
+
+        elif self.dynamic_chunk_training:
+            max_len = x.size(1)
+            chunk_size = torch.randint(1, max_len, (1,)).item()
+
+            if chunk_size > (max_len * self.short_chunk_threshold):
+                chunk_size = max_len
+            else:
+                chunk_size = (chunk_size % self.short_chunk_size) + 1
+
+            chunk_mask = make_chunk_mask(
+                x.size(1),
+                chunk_size,
+                left_chunk_size=self.left_chunk_size,
+                device=x.device,
+            )
+        else:
+            chunk_mask = None
+        x = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=chunk_mask,
+        )
+        
+        olens = mask.eq(0).sum(1)
+        if self.time_reduction_factor > 1:
+            x = x[:,::self.time_reduction_factor,:]
+            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+        return x, olens
+     
+    def simu_chunk_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len)
+
+        x, mask = self.embed(x, mask, chunk_size)
+        pos_enc = self.pos_enc(x)
+        chunk_mask = make_chunk_mask(
+            x.size(1),
+            chunk_size,
+            left_chunk_size=self.left_chunk_size,
+            device=x.device,
+        )
+
+        x = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=chunk_mask,
+        )
+        olens = mask.eq(0).sum(1)
+        if self.time_reduction_factor > 1:
+            x = x[:,::self.time_reduction_factor,:]
+
+        return x
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+        processed_frames: torch.tensor,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        """Encode input sequences as chunks.
+
+        Args:
+            x: Encoder input features. (1, T_in, F)
+            x_len: Encoder input features lengths. (1,)
+            processed_frames: Number of frames already seen.
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+
+        """
+        mask = make_source_mask(x_len)
+        x, mask = self.embed(x, mask, None)
+
+        if left_context > 0:
+            processed_mask = (
+                torch.arange(left_context, device=x.device)
+                .view(1, left_context)
+                .flip(1)
+            )
+            processed_mask = processed_mask >= processed_frames
+            mask = torch.cat([processed_mask, mask], dim=1)
+        pos_enc = self.pos_enc(x, left_context=left_context)
+        x = self.encoders.chunk_forward(
+            x,
+            pos_enc,
+            mask,
+            chunk_size=chunk_size,
+            left_context=left_context,
+            right_context=right_context,
+        )
+
+        if right_context > 0:
+            x = x[:, 0:-right_context, :]
+        
+        if self.time_reduction_factor > 1:
+            x = x[:,::self.time_reduction_factor,:]
+        return x
diff --git a/funasr/models_transducer/encoder/modules/__init__.py b/funasr/models_transducer/encoder/modules/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models_transducer/encoder/modules/__init__.py
diff --git a/funasr/models_transducer/encoder/modules/attention.py b/funasr/models_transducer/encoder/modules/attention.py
new file mode 100644
index 0000000..53e7087
--- /dev/null
+++ b/funasr/models_transducer/encoder/modules/attention.py
@@ -0,0 +1,246 @@
+"""Multi-Head attention layers with relative positional encoding."""
+
+import math
+from typing import Optional, Tuple
+
+import torch
+
+
+class RelPositionMultiHeadedAttention(torch.nn.Module):
+    """RelPositionMultiHeadedAttention definition.
+
+    Args:
+        num_heads: Number of attention heads.
+        embed_size: Embedding size.
+        dropout_rate: Dropout rate.
+
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        embed_size: int,
+        dropout_rate: float = 0.0,
+        simplified_attention_score: bool = False,
+    ) -> None:
+        """Construct an MultiHeadedAttention object."""
+        super().__init__()
+
+        self.d_k = embed_size // num_heads
+        self.num_heads = num_heads
+
+        assert self.d_k * num_heads == embed_size, (
+            "embed_size (%d) must be divisible by num_heads (%d)",
+            (embed_size, num_heads),
+        )
+
+        self.linear_q = torch.nn.Linear(embed_size, embed_size)
+        self.linear_k = torch.nn.Linear(embed_size, embed_size)
+        self.linear_v = torch.nn.Linear(embed_size, embed_size)
+
+        self.linear_out = torch.nn.Linear(embed_size, embed_size)
+
+        if simplified_attention_score:
+            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
+
+            self.compute_att_score = self.compute_simplified_attention_score
+        else:
+            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
+
+            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            torch.nn.init.xavier_uniform_(self.pos_bias_u)
+            torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+            self.compute_att_score = self.compute_attention_score
+
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.attn = None
+
+    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+        """Compute relative positional encoding.
+
+        Args:
+            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
+            left_context: Number of frames in left context.
+
+        Returns:
+            x: Output sequence. (B, H, T_1, T_2)
+
+        """
+        batch_size, n_heads, time1, n = x.shape
+        time2 = time1 + left_context
+
+        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
+
+        return x.as_strided(
+            (batch_size, n_heads, time1, time2),
+            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
+            storage_offset=(n_stride * (time1 - 1)),
+        )
+
+    def compute_simplified_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Simplified attention score computation.
+
+        Reference: https://github.com/k2-fsa/icefall/pull/458
+
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+
+        """
+        pos_enc = self.linear_pos(pos_enc)
+
+        matrix_ac = torch.matmul(query, key.transpose(2, 3))
+
+        matrix_bd = self.rel_shift(
+            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
+            left_context=left_context,
+        )
+
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def compute_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Attention score computation.
+
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+
+        """
+        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
+
+        query = query.transpose(1, 2)
+        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
+        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
+
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def forward_qkv(
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Transform query, key and value.
+
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            v: Value tensor. (B, T_2, size)
+
+        Returns:
+            q: Transformed query tensor. (B, H, T_1, d_k)
+            k: Transformed key tensor. (B, H, T_2, d_k)
+            v: Transformed value tensor. (B, H, T_2, d_k)
+
+        """
+        n_batch = query.size(0)
+
+        q = (
+            self.linear_q(query)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        k = (
+            self.linear_k(key)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        v = (
+            self.linear_v(value)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+
+        return q, k, v
+
+    def forward_attention(
+        self,
+        value: torch.Tensor,
+        scores: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute attention context vector.
+
+        Args:
+            value: Transformed value. (B, H, T_2, d_k)
+            scores: Attention score. (B, H, T_1, T_2)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+
+        Returns:
+           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
+
+        """
+        batch_size = scores.size(0)
+        mask = mask.unsqueeze(1).unsqueeze(2)
+        if chunk_mask is not None:
+            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
+        scores = scores.masked_fill(mask, float("-inf"))
+        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+
+        attn_output = self.dropout(self.attn)
+        attn_output = torch.matmul(attn_output, value)
+
+        attn_output = self.linear_out(
+            attn_output.transpose(1, 2)
+            .contiguous()
+            .view(batch_size, -1, self.num_heads * self.d_k)
+        )
+
+        return attn_output
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Compute scaled dot product attention with rel. positional encoding.
+
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+            left_context: Number of frames in left context.
+
+        Returns:
+            : Output tensor. (B, T_1, H * d_k)
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
+        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
diff --git a/funasr/models_transducer/encoder/modules/convolution.py b/funasr/models_transducer/encoder/modules/convolution.py
new file mode 100644
index 0000000..012538a
--- /dev/null
+++ b/funasr/models_transducer/encoder/modules/convolution.py
@@ -0,0 +1,196 @@
+"""Convolution modules for X-former blocks."""
+
+from typing import Dict, Optional, Tuple
+
+import torch
+
+
+class ConformerConvolution(torch.nn.Module):
+    """ConformerConvolution module definition.
+
+    Args:
+        channels: The number of channels.
+        kernel_size: Size of the convolving kernel.
+        activation: Type of activation function.
+        norm_args: Normalization module arguments.
+        causal: Whether to use causal convolution (set to True if streaming).
+
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+        activation: torch.nn.Module = torch.nn.ReLU(),
+        norm_args: Dict = {},
+        causal: bool = False,
+    ) -> None:
+        """Construct an ConformerConvolution object."""
+        super().__init__()
+
+        assert (kernel_size - 1) % 2 == 0
+
+        self.kernel_size = kernel_size
+
+        self.pointwise_conv1 = torch.nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+        if causal:
+            self.lorder = kernel_size - 1
+            padding = 0
+        else:
+            self.lorder = 0
+            padding = (kernel_size - 1) // 2
+
+        self.depthwise_conv = torch.nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+        )
+        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
+        self.pointwise_conv2 = torch.nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+        self.activation = activation
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cache: Optional[torch.Tensor] = None,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute convolution module.
+
+        Args:
+            x: ConformerConvolution input sequences. (B, T, D_hidden)
+            cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
+            right_context: Number of frames in right context.
+
+        Returns:
+            x: ConformerConvolution output sequences. (B, T, D_hidden)
+            cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
+
+        """
+        x = self.pointwise_conv1(x.transpose(1, 2))
+        x = torch.nn.functional.glu(x, dim=1)
+
+        if self.lorder > 0:
+            if cache is None:
+                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+            else:
+                x = torch.cat([cache, x], dim=2)
+
+                if right_context > 0:
+                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
+                else:
+                    cache = x[:, :, -self.lorder :]
+
+        x = self.depthwise_conv(x)
+        x = self.activation(self.norm(x))
+
+        x = self.pointwise_conv2(x).transpose(1, 2)
+
+        return x, cache
+
+
+class ConvolutionalSpatialGatingUnit(torch.nn.Module):
+    """Convolutional Spatial Gating Unit module definition.
+
+    Args:
+        size: Initial size to determine the number of channels.
+        kernel_size: Size of the convolving kernel.
+        norm_class: Normalization module class.
+        norm_args: Normalization module arguments.
+        dropout_rate: Dropout rate.
+        causal: Whether to use causal convolution (set to True if streaming).
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        kernel_size: int,
+        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_args: Dict = {},
+        dropout_rate: float = 0.0,
+        causal: bool = False,
+    ) -> None:
+        """Construct a ConvolutionalSpatialGatingUnit object."""
+        super().__init__()
+
+        channels = size // 2
+
+        self.kernel_size = kernel_size
+
+        if causal:
+            self.lorder = kernel_size - 1
+            padding = 0
+        else:
+            self.lorder = 0
+            padding = (kernel_size - 1) // 2
+
+        self.conv = torch.nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+        )
+
+        self.norm = norm_class(channels, **norm_args)
+        self.activation = torch.nn.Identity()
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cache: Optional[torch.Tensor] = None,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute convolution module.
+
+        Args:
+            x: ConvolutionalSpatialGatingUnit input sequences. (B, T, D_hidden)
+            cache: ConvolutionalSpationGatingUnit input cache.
+                   (1, conv_kernel, D_hidden)
+            right_context: Number of frames in right context.
+
+        Returns:
+            x: ConvolutionalSpatialGatingUnit output sequences. (B, T, D_hidden // 2)
+
+        """
+        x_r, x_g = x.chunk(2, dim=-1)
+
+        x_g = self.norm(x_g).transpose(1, 2)
+
+        if self.lorder > 0:
+            if cache is None:
+                x_g = torch.nn.functional.pad(x_g, (self.lorder, 0), "constant", 0.0)
+            else:
+                x_g = torch.cat([cache, x_g], dim=2)
+
+                if right_context > 0:
+                    cache = x_g[:, :, -(self.lorder + right_context) : -right_context]
+                else:
+                    cache = x_g[:, :, -self.lorder :]
+
+        x_g = self.conv(x_g).transpose(1, 2)
+
+        x = self.dropout(x_r * self.activation(x_g))
+
+        return x, cache
diff --git a/funasr/models_transducer/encoder/modules/multi_blocks.py b/funasr/models_transducer/encoder/modules/multi_blocks.py
new file mode 100644
index 0000000..14aca8b
--- /dev/null
+++ b/funasr/models_transducer/encoder/modules/multi_blocks.py
@@ -0,0 +1,105 @@
+"""MultiBlocks for encoder architecture."""
+
+from typing import Dict, List, Optional
+
+import torch
+
+
+class MultiBlocks(torch.nn.Module):
+    """MultiBlocks definition.
+
+    Args:
+        block_list: Individual blocks of the encoder architecture.
+        output_size: Architecture output size.
+        norm_class: Normalization module class.
+        norm_args: Normalization module arguments.
+
+    """
+
+    def __init__(
+        self,
+        block_list: List[torch.nn.Module],
+        output_size: int,
+        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_args: Optional[Dict] = None,
+    ) -> None:
+        """Construct a MultiBlocks object."""
+        super().__init__()
+
+        self.blocks = torch.nn.ModuleList(block_list)
+        self.norm_blocks = norm_class(output_size, **norm_args)
+
+        self.num_blocks = len(block_list)
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset encoder streaming cache.
+
+        Args:
+            left_context: Number of left frames during chunk-by-chunk inference.
+            device: Device to use for cache tensor.
+
+        """
+        for idx in range(self.num_blocks):
+            self.blocks[idx].reset_streaming_cache(left_context, device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Forward each block of the encoder architecture.
+
+        Args:
+            x: MultiBlocks input sequences. (B, T, D_block_1)
+            pos_enc: Positional embedding sequences.
+            mask: Source mask. (B, T)
+            chunk_mask: Chunk mask. (T_2, T_2)
+
+        Returns:
+            x: Output sequences. (B, T, D_block_N)
+
+        """
+        for block_index, block in enumerate(self.blocks):
+            x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
+
+        x = self.norm_blocks(x)
+
+        return x
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int = 0,
+        left_context: int = 0,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        """Forward each block of the encoder architecture.
+
+        Args:
+            x: MultiBlocks input sequences. (B, T, D_block_1)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
+            mask: Source mask. (B, T_2)
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+
+        Returns:
+            x: MultiBlocks output sequences. (B, T, D_block_N)
+
+        """
+        for block_idx, block in enumerate(self.blocks):
+            x, pos_enc = block.chunk_forward(
+                x,
+                pos_enc,
+                mask,
+                chunk_size=chunk_size,
+                left_context=left_context,
+                right_context=right_context,
+            )
+
+        x = self.norm_blocks(x)
+
+        return x
diff --git a/funasr/models_transducer/encoder/modules/normalization.py b/funasr/models_transducer/encoder/modules/normalization.py
new file mode 100644
index 0000000..ae35fd4
--- /dev/null
+++ b/funasr/models_transducer/encoder/modules/normalization.py
@@ -0,0 +1,170 @@
+"""Normalization modules for X-former blocks."""
+
+from typing import Dict, Optional, Tuple
+
+import torch
+
+
+def get_normalization(
+    normalization_type: str,
+    eps: Optional[float] = None,
+    partial: Optional[float] = None,
+) -> Tuple[torch.nn.Module, Dict]:
+    """Get normalization module and arguments given parameters.
+
+    Args:
+        normalization_type: Normalization module type.
+        eps: Value added to the denominator.
+        partial: Value defining the part of the input used for RMS stats (RMSNorm).
+
+    Return:
+        : Normalization module class
+        : Normalization module arguments
+
+    """
+    norm = {
+        "basic_norm": (
+            BasicNorm,
+            {"eps": eps if eps is not None else 0.25},
+        ),
+        "layer_norm": (torch.nn.LayerNorm, {"eps": eps if eps is not None else 1e-12}),
+        "rms_norm": (
+            RMSNorm,
+            {
+                "eps": eps if eps is not None else 1e-05,
+                "partial": partial if partial is not None else -1.0,
+            },
+        ),
+        "scale_norm": (
+            ScaleNorm,
+            {"eps": eps if eps is not None else 1e-05},
+        ),
+    }
+
+    return norm[normalization_type]
+
+
+class BasicNorm(torch.nn.Module):
+    """BasicNorm module definition.
+
+    Reference: https://github.com/k2-fsa/icefall/pull/288
+
+    Args:
+        normalized_shape: Expected size.
+        eps: Value added to the denominator for numerical stability.
+
+    """
+
+    def __init__(
+        self,
+        normalized_shape: int,
+        eps: float = 0.25,
+    ) -> None:
+        """Construct a BasicNorm object."""
+        super().__init__()
+
+        self.eps = torch.nn.Parameter(torch.tensor(eps).log().detach())
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute basic normalization.
+
+        Args:
+            x: Input sequences. (B, T, D_hidden)
+
+        Returns:
+            : Output sequences. (B, T, D_hidden)
+
+        """
+        scales = (torch.mean(x.pow(2), dim=-1, keepdim=True) + self.eps.exp()) ** -0.5
+
+        return x * scales
+
+
+class RMSNorm(torch.nn.Module):
+    """RMSNorm module definition.
+
+    Reference: https://arxiv.org/pdf/1910.07467.pdf
+
+    Args:
+        normalized_shape: Expected size.
+        eps: Value added to the denominator for numerical stability.
+        partial: Value defining the part of the input used for RMS stats.
+
+    """
+
+    def __init__(
+        self,
+        normalized_shape: int,
+        eps: float = 1e-5,
+        partial: float = 0.0,
+    ) -> None:
+        """Construct a RMSNorm object."""
+        super().__init__()
+
+        self.normalized_shape = normalized_shape
+
+        self.partial = True if 0 < partial < 1 else False
+        self.p = partial
+        self.eps = eps
+
+        self.scale = torch.nn.Parameter(torch.ones(normalized_shape))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute RMS normalization.
+
+        Args:
+            x: Input sequences. (B, T, D_hidden)
+
+        Returns:
+            x: Output sequences. (B, T, D_hidden)
+
+        """
+        if self.partial:
+            partial_size = int(self.normalized_shape * self.p)
+            partial_x, _ = torch.split(
+                x, [partial_size, self.normalized_shape - partial_size], dim=-1
+            )
+
+            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
+            d_x = partial_size
+        else:
+            norm_x = x.norm(2, dim=-1, keepdim=True)
+            d_x = self.normalized_shape
+
+        rms_x = norm_x * d_x ** (-1.0 / 2)
+        x = self.scale * (x / (rms_x + self.eps))
+
+        return x
+
+
+class ScaleNorm(torch.nn.Module):
+    """ScaleNorm module definition.
+
+    Reference: https://arxiv.org/pdf/1910.05895.pdf
+
+    Args:
+        normalized_shape: Expected size.
+        eps: Value added to the denominator for numerical stability.
+
+    """
+
+    def __init__(self, normalized_shape: int, eps: float = 1e-5) -> None:
+        """Construct a ScaleNorm object."""
+        super().__init__()
+
+        self.eps = eps
+        self.scale = torch.nn.Parameter(torch.tensor(normalized_shape**0.5))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Compute scale normalization.
+
+        Args:
+            x: Input sequences. (B, T, D_hidden)
+
+        Returns:
+            : Output sequences. (B, T, D_hidden)
+
+        """
+        norm = self.scale / torch.norm(x, dim=-1, keepdim=True).clamp(min=self.eps)
+
+        return x * norm
diff --git a/funasr/models_transducer/encoder/modules/positional_encoding.py b/funasr/models_transducer/encoder/modules/positional_encoding.py
new file mode 100644
index 0000000..5b56e26
--- /dev/null
+++ b/funasr/models_transducer/encoder/modules/positional_encoding.py
@@ -0,0 +1,91 @@
+"""Positional encoding modules."""
+
+import math
+
+import torch
+
+from funasr.modules.embedding import _pre_hook
+
+
+class RelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding.
+
+    Args:
+        size: Module size.
+        max_len: Maximum input length.
+        dropout_rate: Dropout rate.
+
+    """
+
+    def __init__(
+        self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
+    ) -> None:
+        """Construct a RelativePositionalEncoding object."""
+        super().__init__()
+
+        self.size = size
+
+        self.pe = None
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+        self._register_load_state_dict_pre_hook(_pre_hook)
+
+    def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
+        """Reset positional encoding.
+
+        Args:
+            x: Input sequences. (B, T, ?)
+            left_context: Number of frames in left context.
+
+        """
+        time1 = x.size(1) + left_context
+
+        if self.pe is not None:
+            if self.pe.size(1) >= time1 * 2 - 1:
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(device=x.device, dtype=x.dtype)
+                return
+
+        pe_positive = torch.zeros(time1, self.size)
+        pe_negative = torch.zeros(time1, self.size)
+
+        position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.size, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.size)
+        )
+
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+
+        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
+            dtype=x.dtype, device=x.device
+        )
+
+    def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+        """Compute positional encoding.
+
+        Args:
+            x: Input sequences. (B, T, ?)
+            left_context: Number of frames in left context.
+
+        Returns:
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?)
+
+        """
+        self.extend_pe(x, left_context=left_context)
+
+        time1 = x.size(1) + left_context
+
+        pos_enc = self.pe[
+            :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
+        ]
+        pos_enc = self.dropout(pos_enc)
+
+        return pos_enc
diff --git a/funasr/models_transducer/encoder/sanm_encoder.py b/funasr/models_transducer/encoder/sanm_encoder.py
new file mode 100644
index 0000000..9e74bdf
--- /dev/null
+++ b/funasr/models_transducer/encoder/sanm_encoder.py
@@ -0,0 +1,835 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
+from typeguard import check_argument_types
+import numpy as np
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.modules.embedding import SinusoidalPositionEncoder
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.multi_layer_conv import Conv1dLinear
+from funasr.modules.multi_layer_conv import MultiLayeredConv1d
+from funasr.modules.positionwise_feed_forward import (
+    PositionwiseFeedForward,  # noqa: H301
+)
+from funasr.modules.repeat import repeat
+from funasr.modules.subsampling import Conv2dSubsampling
+from funasr.modules.subsampling import Conv2dSubsampling2
+from funasr.modules.subsampling import Conv2dSubsampling6
+from funasr.modules.subsampling import Conv2dSubsampling8
+from funasr.modules.subsampling import TooShortUttError
+from funasr.modules.subsampling import check_short_utt
+from funasr.models.ctc import CTC
+from funasr.models.encoder.abs_encoder import AbsEncoder
+
+
+class EncoderLayerSANM(nn.Module):
+    def __init__(
+        self,
+        in_size,
+        size,
+        self_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        stochastic_depth_rate=0.0,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayerSANM, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(in_size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.in_size = in_size
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+        self.stochastic_depth_rate = stochastic_depth_rate
+        self.dropout_rate = dropout_rate
+
+    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        """Compute encoded features.
+        Args:
+            x_input (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+        """
+        skip_layer = False
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        stoch_layer_coeff = 1.0
+        if self.training and self.stochastic_depth_rate > 0:
+            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+        if skip_layer:
+            if cache is not None:
+                x = torch.cat([cache, x], dim=1)
+            return x, mask
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+            else:
+                x = stoch_layer_coeff * self.concat_linear(x_concat)
+        else:
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+            else:
+                x = stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+
+        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
+class SANMEncoder(AbsEncoder):
+    """
+    author: Speech Lab, Alibaba Group, China
+    San-m: Memory equipped self-attention for end-to-end speech recognition
+    https://arxiv.org/abs/2006.01713
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
+        kernel_size : int = 11,
+        sanm_shfit : int = 0,
+        tf2torch_tensor_name_prefix_torch: str = "encoder",
+        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+    ):
+        assert check_argument_types()
+        super().__init__()
+
+        self.embed = SinusoidalPositionEncoder()
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        encoder_selfattn_layer = MultiHeadedAttentionSANM
+        encoder_selfattn_layer_args0 = (
+            attention_heads,
+            input_size,
+            output_size,
+            attention_dropout_rate,
+            kernel_size,
+            sanm_shfit,
+        )
+
+        encoder_selfattn_layer_args = (
+            attention_heads,
+            output_size,
+            output_size,
+            attention_dropout_rate,
+            kernel_size,
+            sanm_shfit,
+        )
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.encoders = repeat(
+            num_blocks-1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+        self.dropout = nn.Dropout(dropout_rate)
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        xs_pad = xs_pad * self.output_size**0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling2)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        # xs_pad = self.dropout(xs_pad)
+        encoder_outs = self.encoders0(xs_pad, masks)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            encoder_outs = self.encoders(xs_pad, masks)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                encoder_outs = encoder_layer(xs_pad, masks)
+                xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens
+
+    def gen_tf2torch_map_dict(self):
+        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+        map_dict_local = {
+            ## encoder
+            # cicd
+            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (768,256),(1,256,768)
+            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (768,),(768,)
+            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 2, 0),
+                 },  # (256,1,31),(1,31,256,1)
+            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,256),(1,256,256)
+            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            # ffn
+            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (1024,256),(1,256,1024)
+            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,1024),(1,1024,256)
+            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            # out norm
+            "{}.after_norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.after_norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+        
+        }
+    
+        return map_dict_local
+
+    def convert_tf2torch(self,
+                         var_dict_tf,
+                         var_dict_torch,
+                         ):
+        
+        map_dict = self.gen_tf2torch_map_dict()
+    
+        var_dict_torch_update = dict()
+        for name in sorted(var_dict_torch.keys(), reverse=False):
+            names = name.split('.')
+            if names[0] == self.tf2torch_tensor_name_prefix_torch:
+                if names[1] == "encoders0":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                
+                    name_q = name_q.replace("encoders0", "encoders")
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+                elif names[1] == "encoders":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    layeridx_bias = 1
+                    layeridx += layeridx_bias
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "after_norm":
+                    name_tf = map_dict[name]["name"]
+                    data_tf = var_dict_tf[name_tf]
+                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                    var_dict_torch_update[name] = data_tf
+                    logging.info(
+                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+                                                                                      var_dict_tf[name_tf].shape))
+    
+        return var_dict_torch_update
+
+
+class SANMEncoderChunkOpt(AbsEncoder):
+    """
+    author: Speech Lab, Alibaba Group, China
+    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
+    https://arxiv.org/abs/2006.01713
+    """
+
+    def __init__(
+            self,
+            input_size: int,
+            output_size: int = 256,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            attention_dropout_rate: float = 0.0,
+            pos_enc_class=SinusoidalPositionEncoder,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            positionwise_layer_type: str = "linear",
+            positionwise_conv_kernel_size: int = 1,
+            padding_idx: int = -1,
+            interctc_layer_idx: List[int] = [],
+            interctc_use_conditioning: bool = False,
+            kernel_size: int = 11,
+            sanm_shfit: int = 0,
+            chunk_size: Union[int, Sequence[int]] = (16,),
+            stride: Union[int, Sequence[int]] = (10,),
+            pad_left: Union[int, Sequence[int]] = (0,),
+            time_reduction_factor: int = 1,
+            encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+            decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+            tf2torch_tensor_name_prefix_torch: str = "encoder",
+            tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self.output_size = output_size
+
+        self.embed = SinusoidalPositionEncoder()
+        
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        encoder_selfattn_layer = MultiHeadedAttentionSANM
+        encoder_selfattn_layer_args0 = (
+            attention_heads,
+            input_size,
+            output_size,
+            attention_dropout_rate,
+            kernel_size,
+            sanm_shfit,
+        )
+
+        encoder_selfattn_layer_args = (
+            attention_heads,
+            output_size,
+            output_size,
+            attention_dropout_rate,
+            kernel_size,
+            sanm_shfit,
+        )
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.encoders = repeat(
+            num_blocks - 1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+        shfit_fsmn = (kernel_size - 1) // 2
+        self.overlap_chunk_cls = overlap_chunk(
+            chunk_size=chunk_size,
+            stride=stride,
+            pad_left=pad_left,
+            shfit_fsmn=shfit_fsmn,
+            encoder_att_look_back_factor=encoder_att_look_back_factor,
+            decoder_att_look_back_factor=decoder_att_look_back_factor,
+        )
+        self.time_reduction_factor = time_reduction_factor
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+
+    def forward(
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+            prev_states: torch.Tensor = None,
+            ctc: CTC = None,
+            ind: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        xs_pad *= self.output_size ** 0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (
+                isinstance(self.embed, Conv2dSubsampling)
+                or isinstance(self.embed, Conv2dSubsampling2)
+                or isinstance(self.embed, Conv2dSubsampling6)
+                or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        mask_shfit_chunk, mask_att_chunk_encoder = None, None
+        if self.overlap_chunk_cls is not None:
+            ilens = masks.squeeze(1).sum(1)
+            chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
+            xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
+            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
+                                                                           dtype=xs_pad.dtype)
+            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
+                                                                                       xs_pad.size(0),
+                                                                                       dtype=xs_pad.dtype)
+
+        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+                xs_pad, masks = encoder_outs[0], encoder_outs[1]
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+        
+        olens = masks.squeeze(1).sum(1)
+
+        xs_pad, olens = self.overlap_chunk_cls.remove_chunk(xs_pad, olens, chunk_outs=None)
+
+        if self.time_reduction_factor > 1:
+            xs_pad = xs_pad[:,::self.time_reduction_factor,:]
+            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens
+
+    def gen_tf2torch_map_dict(self):
+        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+        map_dict_local = {
+            ## encoder
+            # cicd
+            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (768,256),(1,256,768)
+            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (768,),(768,)
+            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 2, 0),
+                 },  # (256,1,31),(1,31,256,1)
+            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,256),(1,256,256)
+            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            # ffn
+            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (1024,256),(1,256,1024)
+            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,1024),(1,1024,256)
+            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            # out norm
+            "{}.after_norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.after_norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+        
+        }
+    
+        return map_dict_local
+
+    def convert_tf2torch(self,
+                         var_dict_tf,
+                         var_dict_torch,
+                         ):
+    
+        map_dict = self.gen_tf2torch_map_dict()
+    
+        var_dict_torch_update = dict()
+        for name in sorted(var_dict_torch.keys(), reverse=False):
+            names = name.split('.')
+            if names[0] == self.tf2torch_tensor_name_prefix_torch:
+                if names[1] == "encoders0":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                
+                    name_q = name_q.replace("encoders0", "encoders")
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+                elif names[1] == "encoders":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    layeridx_bias = 1
+                    layeridx += layeridx_bias
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "after_norm":
+                    name_tf = map_dict[name]["name"]
+                    data_tf = var_dict_tf[name_tf]
+                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                    var_dict_torch_update[name] = data_tf
+                    logging.info(
+                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+                                                                                      var_dict_tf[name_tf].shape))
+    
+        return var_dict_torch_update
diff --git a/funasr/models_transducer/encoder/validation.py b/funasr/models_transducer/encoder/validation.py
new file mode 100644
index 0000000..0003536
--- /dev/null
+++ b/funasr/models_transducer/encoder/validation.py
@@ -0,0 +1,171 @@
+"""Set of methods to validate encoder architecture."""
+
+from typing import Any, Dict, List, Tuple
+
+from funasr.models_transducer.utils import sub_factor_to_params
+
+
+def validate_block_arguments(
+    configuration: Dict[str, Any],
+    block_id: int,
+    previous_block_output: int,
+) -> Tuple[int, int]:
+    """Validate block arguments.
+
+    Args:
+        configuration: Architecture configuration.
+        block_id: Block ID.
+        previous_block_output: Previous block output size.
+
+    Returns:
+        input_size: Block input size.
+        output_size: Block output size.
+
+    """
+    block_type = configuration.get("block_type")
+
+    if block_type is None:
+        raise ValueError(
+            "Block %d in encoder doesn't have a type assigned. " % block_id
+        )
+
+    if block_type in ["branchformer", "conformer"]:
+        if configuration.get("linear_size") is None:
+            raise ValueError(
+                "Missing 'linear_size' argument for X-former block (ID: %d)" % block_id
+            )
+
+        if configuration.get("conv_mod_kernel_size") is None:
+            raise ValueError(
+                "Missing 'conv_mod_kernel_size' argument for X-former block (ID: %d)"
+                % block_id
+            )
+
+        input_size = configuration.get("hidden_size")
+        output_size = configuration.get("hidden_size")
+
+    elif block_type == "conv1d":
+        output_size = configuration.get("output_size")
+
+        if output_size is None:
+            raise ValueError(
+                "Missing 'output_size' argument for Conv1d block (ID: %d)" % block_id
+            )
+
+        if configuration.get("kernel_size") is None:
+            raise ValueError(
+                "Missing 'kernel_size' argument for Conv1d block (ID: %d)" % block_id
+            )
+
+        input_size = configuration["input_size"] = previous_block_output
+    else:
+        raise ValueError("Block type: %s is not supported." % block_type)
+
+    return input_size, output_size
+
+
+def validate_input_block(
+    configuration: Dict[str, Any], body_first_conf: Dict[str, Any], input_size: int
+) -> int:
+    """Validate input block.
+
+    Args:
+        configuration: Encoder input block configuration.
+        body_first_conf: Encoder first body block configuration.
+        input_size: Encoder input block input size.
+
+    Return:
+        output_size: Encoder input block output size.
+
+    """
+    vgg_like = configuration.get("vgg_like", False)
+    linear = configuration.get("linear", False)
+    next_block_type = body_first_conf.get("block_type")
+    allowed_next_block_type = ["branchformer", "conformer", "conv1d"]
+
+    if next_block_type is None or (next_block_type not in allowed_next_block_type):
+        return -1
+
+    if configuration.get("subsampling_factor") is None:
+        configuration["subsampling_factor"] = 4
+
+    if vgg_like:
+        conv_size = configuration.get("conv_size", (64, 128))
+
+        if isinstance(conv_size, int):
+            conv_size = (conv_size, conv_size)
+    else:
+        conv_size = configuration.get("conv_size", None)
+
+        if isinstance(conv_size, tuple):
+            conv_size = conv_size[0]
+
+    if next_block_type == "conv1d":
+        if vgg_like:
+            output_size = conv_size[1] * ((input_size // 2) // 2)
+        else:
+            if conv_size is None:
+                conv_size = body_first_conf.get("output_size", 64)
+
+            sub_factor = configuration["subsampling_factor"]
+
+            _, _, conv_osize = sub_factor_to_params(sub_factor, input_size)
+            assert (
+                conv_osize > 0
+            ), "Conv2D output size is <1 with input size %d and subsampling %d" % (
+                input_size,
+                sub_factor,
+            )
+
+            output_size = conv_osize * conv_size
+
+        configuration["output_size"] = None
+    else:
+        output_size = body_first_conf.get("hidden_size")
+
+        if conv_size is None:
+            conv_size = output_size
+
+        configuration["output_size"] = output_size
+
+    configuration["conv_size"] = conv_size
+    configuration["vgg_like"] = vgg_like
+    configuration["linear"] = linear
+
+    return output_size
+
+
+def validate_architecture(
+    input_conf: Dict[str, Any], body_conf: List[Dict[str, Any]], input_size: int
+) -> Tuple[int, int]:
+    """Validate specified architecture is valid.
+
+    Args:
+        input_conf: Encoder input block configuration.
+        body_conf: Encoder body blocks configuration.
+        input_size: Encoder input size.
+
+    Returns:
+        input_block_osize: Encoder input block output size.
+        : Encoder body block output size.
+
+    """
+    input_block_osize = validate_input_block(input_conf, body_conf[0], input_size)
+
+    cmp_io = []
+
+    for i, b in enumerate(body_conf):
+        _io = validate_block_arguments(
+            b, (i + 1), input_block_osize if i == 0 else cmp_io[i - 1][1]
+        )
+
+        cmp_io.append(_io)
+
+    for i in range(1, len(cmp_io)):
+        if cmp_io[(i - 1)][1] != cmp_io[i][0]:
+            raise ValueError(
+                "Output/Input mismatch between blocks %d and %d"
+                " in the encoder body." % ((i - 1), i)
+            )
+
+    return input_block_osize, cmp_io[-1][1]
diff --git a/funasr/models_transducer/error_calculator.py b/funasr/models_transducer/error_calculator.py
new file mode 100644
index 0000000..17dbf36
--- /dev/null
+++ b/funasr/models_transducer/error_calculator.py
@@ -0,0 +1,170 @@
+"""Error Calculator module for Transducer."""
+
+from typing import List, Optional, Tuple
+
+import torch
+
+from funasr.models_transducer.beam_search_transducer import BeamSearchTransducer
+from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.models_transducer.joint_network import JointNetwork
+
+
+class ErrorCalculator:
+    """Calculate CER and WER for transducer models.
+
+    Args:
+        decoder: Decoder module.
+        joint_network: Joint Network module.
+        token_list: List of token units.
+        sym_space: Space symbol.
+        sym_blank: Blank symbol.
+        report_cer: Whether to compute CER.
+        report_wer: Whether to compute WER.
+
+    """
+
+    def __init__(
+        self,
+        decoder: AbsDecoder,
+        joint_network: JointNetwork,
+        token_list: List[int],
+        sym_space: str,
+        sym_blank: str,
+        report_cer: bool = False,
+        report_wer: bool = False,
+    ) -> None:
+        """Construct an ErrorCalculatorTransducer object."""
+        super().__init__()
+
+        self.beam_search = BeamSearchTransducer(
+            decoder=decoder,
+            joint_network=joint_network,
+            beam_size=1,
+            search_type="default",
+            score_norm=False,
+        )
+
+        self.decoder = decoder
+
+        self.token_list = token_list
+        self.space = sym_space
+        self.blank = sym_blank
+
+        self.report_cer = report_cer
+        self.report_wer = report_wer
+
+    def __call__(
+        self, encoder_out: torch.Tensor, target: torch.Tensor
+    ) -> Tuple[Optional[float], Optional[float]]:
+        """Calculate sentence-level WER or/and CER score for Transducer model.
+
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            target: Target label ID sequences. (B, L)
+
+        Returns:
+            : Sentence-level CER score.
+            : Sentence-level WER score.
+
+        """
+        cer, wer = None, None
+
+        batchsize = int(encoder_out.size(0))
+
+        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
+
+        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
+        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
+
+        char_pred, char_target = self.convert_to_char(pred, target)
+
+        if self.report_cer:
+            cer = self.calculate_cer(char_pred, char_target)
+
+        if self.report_wer:
+            wer = self.calculate_wer(char_pred, char_target)
+
+        return cer, wer
+
+    def convert_to_char(
+        self, pred: torch.Tensor, target: torch.Tensor
+    ) -> Tuple[List, List]:
+        """Convert label ID sequences to character sequences.
+
+        Args:
+            pred: Prediction label ID sequences. (B, U)
+            target: Target label ID sequences. (B, L)
+
+        Returns:
+            char_pred: Prediction character sequences. (B, ?)
+            char_target: Target character sequences. (B, ?)
+
+        """
+        char_pred, char_target = [], []
+
+        for i, pred_i in enumerate(pred):
+            char_pred_i = [self.token_list[int(h)] for h in pred_i]
+            char_target_i = [self.token_list[int(r)] for r in target[i]]
+
+            char_pred_i = "".join(char_pred_i).replace(self.space, " ")
+            char_pred_i = char_pred_i.replace(self.blank, "")
+
+            char_target_i = "".join(char_target_i).replace(self.space, " ")
+            char_target_i = char_target_i.replace(self.blank, "")
+
+            char_pred.append(char_pred_i)
+            char_target.append(char_target_i)
+
+        return char_pred, char_target
+
+    def calculate_cer(
+        self, char_pred: torch.Tensor, char_target: torch.Tensor
+    ) -> float:
+        """Calculate sentence-level CER score.
+
+        Args:
+            char_pred: Prediction character sequences. (B, ?)
+            char_target: Target character sequences. (B, ?)
+
+        Returns:
+            : Average sentence-level CER score.
+
+        """
+        import editdistance
+
+        distances, lens = [], []
+
+        for i, char_pred_i in enumerate(char_pred):
+            pred = char_pred_i.replace(" ", "")
+            target = char_target[i].replace(" ", "")
+
+            distances.append(editdistance.eval(pred, target))
+            lens.append(len(target))
+
+        return float(sum(distances)) / sum(lens)
+
+    def calculate_wer(
+        self, char_pred: torch.Tensor, char_target: torch.Tensor
+    ) -> float:
+        """Calculate sentence-level WER score.
+
+        Args:
+            char_pred: Prediction character sequences. (B, ?)
+            char_target: Target character sequences. (B, ?)
+
+        Returns:
+            : Average sentence-level WER score
+
+        """
+        import editdistance
+
+        distances, lens = [], []
+
+        for i, char_pred_i in enumerate(char_pred):
+            pred = char_pred_i.replace("鈻�", " ").split()
+            target = char_target[i].replace("鈻�", " ").split()
+
+            distances.append(editdistance.eval(pred, target))
+            lens.append(len(target))
+
+        return float(sum(distances)) / sum(lens)
diff --git a/funasr/models_transducer/espnet_transducer_model.py b/funasr/models_transducer/espnet_transducer_model.py
new file mode 100644
index 0000000..e32f6e3
--- /dev/null
+++ b/funasr/models_transducer/espnet_transducer_model.py
@@ -0,0 +1,484 @@
+"""ESPnet2 ASR Transducer model."""
+
+import logging
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from packaging.version import parse as V
+from typeguard import check_argument_types
+
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
+from funasr.models_transducer.encoder.encoder import Encoder
+from funasr.models_transducer.joint_network import JointNetwork
+from funasr.models_transducer.utils import get_transducer_task_io
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if V(torch.__version__) >= V("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+class ESPnetASRTransducerModel(AbsESPnetModel):
+    """ESPnet2ASRTransducerModel module definition.
+
+    Args:
+        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
+        token_list: List of token
+        frontend: Frontend module.
+        specaug: SpecAugment module.
+        normalize: Normalization module.
+        encoder: Encoder module.
+        decoder: Decoder module.
+        joint_network: Joint Network module.
+        transducer_weight: Weight of the Transducer loss.
+        fastemit_lambda: FastEmit lambda value.
+        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
+        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
+        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
+        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
+        ignore_id: Initial padding ID.
+        sym_space: Space symbol.
+        sym_blank: Blank Symbol
+        report_cer: Whether to report Character Error Rate during validation.
+        report_wer: Whether to report Word Error Rate during validation.
+        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
+
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        encoder: Encoder,
+        decoder: AbsDecoder,
+        att_decoder: Optional[AbsAttDecoder],
+        joint_network: JointNetwork,
+        transducer_weight: float = 1.0,
+        fastemit_lambda: float = 0.0,
+        auxiliary_ctc_weight: float = 0.0,
+        auxiliary_ctc_dropout_rate: float = 0.0,
+        auxiliary_lm_loss_weight: float = 0.0,
+        auxiliary_lm_loss_smoothing: float = 0.0,
+        ignore_id: int = -1,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        report_cer: bool = True,
+        report_wer: bool = True,
+        extract_feats_in_collect_stats: bool = True,
+    ) -> None:
+        """Construct an ESPnetASRTransducerModel object."""
+        super().__init__()
+
+        assert check_argument_types()
+
+        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
+        self.blank_id = 0
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.token_list = token_list.copy()
+
+        self.sym_space = sym_space
+        self.sym_blank = sym_blank
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joint_network = joint_network
+
+        self.criterion_transducer = None
+        self.error_calculator = None
+
+        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
+        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
+
+        if self.use_auxiliary_ctc:
+            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
+            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
+
+        if self.use_auxiliary_lm_loss:
+            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
+            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
+
+        self.transducer_weight = transducer_weight
+        self.fastemit_lambda = fastemit_lambda
+
+        self.auxiliary_ctc_weight = auxiliary_ctc_weight
+        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
+
+        self.report_cer = report_cer
+        self.report_wer = report_wer
+
+        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Forward architecture and compute loss(es).
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+
+        Return:
+            loss: Main loss value.
+            stats: Task statistics.
+            weight: Task weights.
+
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        assert (
+            speech.shape[0]
+            == speech_lengths.shape[0]
+            == text.shape[0]
+            == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+
+        batch_size = speech.shape[0]
+        text = text[:, : text_lengths.max()]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # 2. Transducer-related I/O preparation
+        decoder_in, target, t_len, u_len = get_transducer_task_io(
+            text,
+            encoder_out_lens,
+            ignore_id=self.ignore_id,
+        )
+
+        # 3. Decoder
+        self.decoder.set_device(encoder_out.device)
+        decoder_out = self.decoder(decoder_in, u_len)
+
+        # 4. Joint Network
+        joint_out = self.joint_network(
+            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
+        )
+
+        # 5. Losses
+        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
+            encoder_out,
+            joint_out,
+            target,
+            t_len,
+            u_len,
+        )
+
+        loss_ctc, loss_lm = 0.0, 0.0
+
+        if self.use_auxiliary_ctc:
+            loss_ctc = self._calc_ctc_loss(
+                encoder_out,
+                target,
+                t_len,
+                u_len,
+            )
+
+        if self.use_auxiliary_lm_loss:
+            loss_lm = self._calc_lm_loss(decoder_out, target)
+
+        loss = (
+            self.transducer_weight * loss_trans
+            + self.auxiliary_ctc_weight * loss_ctc
+            + self.auxiliary_lm_loss_weight * loss_lm
+        )
+
+        stats = dict(
+            loss=loss.detach(),
+            loss_transducer=loss_trans.detach(),
+            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
+            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
+            cer_transducer=cer_trans,
+            wer_transducer=wer_trans,
+        )
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+
+        return loss, stats, weight
+
+    def collect_feats(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        """Collect features sequences and features lengths sequences.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+
+        Return:
+            {}: "feats": Features sequences. (B, T, D_feats),
+                "feats_lengths": Features sequences lengths. (B,)
+
+        """
+        if self.extract_feats_in_collect_stats:
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        else:
+            # Generate dummy stats if extract_feats_in_collect_stats is False
+            logging.warning(
+                "Generating dummy stats for feats and feats_lengths, "
+                "because encoder_conf.extract_feats_in_collect_stats is "
+                f"{self.extract_feats_in_collect_stats}"
+            )
+
+            feats, feats_lengths = speech, speech_lengths
+
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encoder speech sequences.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+
+        Return:
+            encoder_out: Encoder outputs. (B, T, D_enc)
+            encoder_out_lens: Encoder outputs lengths. (B,)
+
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # 4. Forward encoder
+        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+
+        return encoder_out, encoder_out_lens
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Extract features sequences and features sequences lengths.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+
+        Return:
+            feats: Features sequences. (B, T, D_feats)
+            feats_lengths: Features sequences lengths. (B,)
+
+        """
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            feats, feats_lengths = speech, speech_lengths
+
+        return feats, feats_lengths
+
+    def _calc_transducer_loss(
+        self,
+        encoder_out: torch.Tensor,
+        joint_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
+        """Compute Transducer loss.
+
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            joint_out: Joint Network output sequences (B, T, U, D_joint)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+
+        Return:
+            loss_transducer: Transducer loss value.
+            cer_transducer: Character error rate for Transducer.
+            wer_transducer: Word Error Rate for Transducer.
+
+        """
+        if self.criterion_transducer is None:
+            try:
+                # from warprnnt_pytorch import RNNTLoss
+	        # self.criterion_transducer = RNNTLoss(
+                    # reduction="mean",
+                    # fastemit_lambda=self.fastemit_lambda,
+                # )
+                from warp_rnnt import rnnt_loss as RNNTLoss
+                self.criterion_transducer = RNNTLoss
+
+            except ImportError:
+                logging.error(
+                    "warp-rnnt was not installed."
+                    "Please consult the installation documentation."
+                )
+                exit(1)
+
+        # loss_transducer = self.criterion_transducer(
+        #     joint_out,
+        #     target,
+        #     t_len,
+        #     u_len,
+        # )
+        log_probs = torch.log_softmax(joint_out, dim=-1)
+
+        loss_transducer = self.criterion_transducer(
+                log_probs,
+                target,
+                t_len,
+                u_len,
+                reduction="mean",
+                blank=self.blank_id,
+                fastemit_lambda=self.fastemit_lambda,
+                gather=True,
+        )
+
+        if not self.training and (self.report_cer or self.report_wer):
+            if self.error_calculator is None:
+                from espnet2.asr_transducer.error_calculator import ErrorCalculator
+
+                self.error_calculator = ErrorCalculator(
+                    self.decoder,
+                    self.joint_network,
+                    self.token_list,
+                    self.sym_space,
+                    self.sym_blank,
+                    report_cer=self.report_cer,
+                    report_wer=self.report_wer,
+                )
+
+            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+
+            return loss_transducer, cer_transducer, wer_transducer
+
+        return loss_transducer, None, None
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute CTC loss.
+
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+
+        Return:
+            loss_ctc: CTC loss value.
+
+        """
+        ctc_in = self.ctc_lin(
+            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
+        )
+        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
+
+        target_mask = target != 0
+        ctc_target = target[target_mask].cpu()
+
+        with torch.backends.cudnn.flags(deterministic=True):
+            loss_ctc = torch.nn.functional.ctc_loss(
+                ctc_in,
+                ctc_target,
+                t_len,
+                u_len,
+                zero_infinity=True,
+                reduction="sum",
+            )
+        loss_ctc /= target.size(0)
+
+        return loss_ctc
+
+    def _calc_lm_loss(
+        self,
+        decoder_out: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute LM loss.
+
+        Args:
+            decoder_out: Decoder output sequences. (B, U, D_dec)
+            target: Target label ID sequences. (B, L)
+
+        Return:
+            loss_lm: LM loss value.
+
+        """
+        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
+        lm_target = target.view(-1).type(torch.int64)
+
+        with torch.no_grad():
+            true_dist = lm_loss_in.clone()
+            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
+
+            # Ignore blank ID (0)
+            ignore = lm_target == 0
+            lm_target = lm_target.masked_fill(ignore, 0)
+
+            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
+
+        loss_lm = torch.nn.functional.kl_div(
+            torch.log_softmax(lm_loss_in, dim=1),
+            true_dist,
+            reduction="none",
+        )
+        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
+            0
+        )
+
+        return loss_lm
diff --git a/funasr/models_transducer/espnet_transducer_model_uni_asr.py b/funasr/models_transducer/espnet_transducer_model_uni_asr.py
new file mode 100644
index 0000000..2add3fa
--- /dev/null
+++ b/funasr/models_transducer/espnet_transducer_model_uni_asr.py
@@ -0,0 +1,485 @@
+"""ESPnet2 ASR Transducer model."""
+
+import logging
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from packaging.version import parse as V
+from typeguard import check_argument_types
+
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
+from funasr.models_transducer.encoder.encoder import Encoder
+from funasr.models_transducer.joint_network import JointNetwork
+from funasr.models_transducer.utils import get_transducer_task_io
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if V(torch.__version__) >= V("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+class UniASRTransducerModel(AbsESPnetModel):
+    """ESPnet2ASRTransducerModel module definition.
+
+    Args:
+        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
+        token_list: List of token
+        frontend: Frontend module.
+        specaug: SpecAugment module.
+        normalize: Normalization module.
+        encoder: Encoder module.
+        decoder: Decoder module.
+        joint_network: Joint Network module.
+        transducer_weight: Weight of the Transducer loss.
+        fastemit_lambda: FastEmit lambda value.
+        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
+        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
+        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
+        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
+        ignore_id: Initial padding ID.
+        sym_space: Space symbol.
+        sym_blank: Blank Symbol
+        report_cer: Whether to report Character Error Rate during validation.
+        report_wer: Whether to report Word Error Rate during validation.
+        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
+
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        encoder,
+        decoder: AbsDecoder,
+        att_decoder: Optional[AbsAttDecoder],
+        joint_network: JointNetwork,
+        transducer_weight: float = 1.0,
+        fastemit_lambda: float = 0.0,
+        auxiliary_ctc_weight: float = 0.0,
+        auxiliary_ctc_dropout_rate: float = 0.0,
+        auxiliary_lm_loss_weight: float = 0.0,
+        auxiliary_lm_loss_smoothing: float = 0.0,
+        ignore_id: int = -1,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        report_cer: bool = True,
+        report_wer: bool = True,
+        extract_feats_in_collect_stats: bool = True,
+    ) -> None:
+        """Construct an ESPnetASRTransducerModel object."""
+        super().__init__()
+
+        assert check_argument_types()
+
+        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
+        self.blank_id = 0
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.token_list = token_list.copy()
+
+        self.sym_space = sym_space
+        self.sym_blank = sym_blank
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joint_network = joint_network
+
+        self.criterion_transducer = None
+        self.error_calculator = None
+
+        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
+        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
+
+        if self.use_auxiliary_ctc:
+            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
+            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
+
+        if self.use_auxiliary_lm_loss:
+            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
+            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
+
+        self.transducer_weight = transducer_weight
+        self.fastemit_lambda = fastemit_lambda
+
+        self.auxiliary_ctc_weight = auxiliary_ctc_weight
+        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
+
+        self.report_cer = report_cer
+        self.report_wer = report_wer
+
+        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        decoding_ind: int = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Forward architecture and compute loss(es).
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+
+        Return:
+            loss: Main loss value.
+            stats: Task statistics.
+            weight: Task weights.
+
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        assert (
+            speech.shape[0]
+            == speech_lengths.shape[0]
+            == text.shape[0]
+            == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+
+        batch_size = speech.shape[0]
+        text = text[:, : text_lengths.max()]
+
+        # 1. Encoder
+        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+        # 2. Transducer-related I/O preparation
+        decoder_in, target, t_len, u_len = get_transducer_task_io(
+            text,
+            encoder_out_lens,
+            ignore_id=self.ignore_id,
+        )
+
+        # 3. Decoder
+        self.decoder.set_device(encoder_out.device)
+        decoder_out = self.decoder(decoder_in, u_len)
+
+        # 4. Joint Network
+        joint_out = self.joint_network(
+            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
+        )
+
+        # 5. Losses
+        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
+            encoder_out,
+            joint_out,
+            target,
+            t_len,
+            u_len,
+        )
+
+        loss_ctc, loss_lm = 0.0, 0.0
+
+        if self.use_auxiliary_ctc:
+            loss_ctc = self._calc_ctc_loss(
+                encoder_out,
+                target,
+                t_len,
+                u_len,
+            )
+
+        if self.use_auxiliary_lm_loss:
+            loss_lm = self._calc_lm_loss(decoder_out, target)
+
+        loss = (
+            self.transducer_weight * loss_trans
+            + self.auxiliary_ctc_weight * loss_ctc
+            + self.auxiliary_lm_loss_weight * loss_lm
+        )
+
+        stats = dict(
+            loss=loss.detach(),
+            loss_transducer=loss_trans.detach(),
+            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
+            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
+            cer_transducer=cer_trans,
+            wer_transducer=wer_trans,
+        )
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+
+        return loss, stats, weight
+
+    def collect_feats(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        """Collect features sequences and features lengths sequences.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+
+        Return:
+            {}: "feats": Features sequences. (B, T, D_feats),
+                "feats_lengths": Features sequences lengths. (B,)
+
+        """
+        if self.extract_feats_in_collect_stats:
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        else:
+            # Generate dummy stats if extract_feats_in_collect_stats is False
+            logging.warning(
+                "Generating dummy stats for feats and feats_lengths, "
+                "because encoder_conf.extract_feats_in_collect_stats is "
+                f"{self.extract_feats_in_collect_stats}"
+            )
+
+            feats, feats_lengths = speech, speech_lengths
+
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        ind: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encoder speech sequences.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+
+        Return:
+            encoder_out: Encoder outputs. (B, T, D_enc)
+            encoder_out_lens: Encoder outputs lengths. (B,)
+
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # 4. Forward encoder
+        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths, ind=ind)
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+
+        return encoder_out, encoder_out_lens
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Extract features sequences and features sequences lengths.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+
+        Return:
+            feats: Features sequences. (B, T, D_feats)
+            feats_lengths: Features sequences lengths. (B,)
+
+        """
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            feats, feats_lengths = speech, speech_lengths
+
+        return feats, feats_lengths
+
+    def _calc_transducer_loss(
+        self,
+        encoder_out: torch.Tensor,
+        joint_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
+        """Compute Transducer loss.
+
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            joint_out: Joint Network output sequences (B, T, U, D_joint)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+
+        Return:
+            loss_transducer: Transducer loss value.
+            cer_transducer: Character error rate for Transducer.
+            wer_transducer: Word Error Rate for Transducer.
+
+        """
+        if self.criterion_transducer is None:
+            try:
+                # from warprnnt_pytorch import RNNTLoss
+	        # self.criterion_transducer = RNNTLoss(
+                    # reduction="mean",
+                    # fastemit_lambda=self.fastemit_lambda,
+                # )
+                from warp_rnnt import rnnt_loss as RNNTLoss
+                self.criterion_transducer = RNNTLoss
+
+            except ImportError:
+                logging.error(
+                    "warp-rnnt was not installed."
+                    "Please consult the installation documentation."
+                )
+                exit(1)
+
+        # loss_transducer = self.criterion_transducer(
+        #     joint_out,
+        #     target,
+        #     t_len,
+        #     u_len,
+        # )
+        log_probs = torch.log_softmax(joint_out, dim=-1)
+
+        loss_transducer = self.criterion_transducer(
+                log_probs,
+                target,
+                t_len,
+                u_len,
+                reduction="mean",
+                blank=self.blank_id,
+                gather=True,
+        )
+
+        if not self.training and (self.report_cer or self.report_wer):
+            if self.error_calculator is None:
+                from espnet2.asr_transducer.error_calculator import ErrorCalculator
+
+                self.error_calculator = ErrorCalculator(
+                    self.decoder,
+                    self.joint_network,
+                    self.token_list,
+                    self.sym_space,
+                    self.sym_blank,
+                    report_cer=self.report_cer,
+                    report_wer=self.report_wer,
+                )
+
+            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+
+            return loss_transducer, cer_transducer, wer_transducer
+
+        return loss_transducer, None, None
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute CTC loss.
+
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+
+        Return:
+            loss_ctc: CTC loss value.
+
+        """
+        ctc_in = self.ctc_lin(
+            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
+        )
+        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
+
+        target_mask = target != 0
+        ctc_target = target[target_mask].cpu()
+
+        with torch.backends.cudnn.flags(deterministic=True):
+            loss_ctc = torch.nn.functional.ctc_loss(
+                ctc_in,
+                ctc_target,
+                t_len,
+                u_len,
+                zero_infinity=True,
+                reduction="sum",
+            )
+        loss_ctc /= target.size(0)
+
+        return loss_ctc
+
+    def _calc_lm_loss(
+        self,
+        decoder_out: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute LM loss.
+
+        Args:
+            decoder_out: Decoder output sequences. (B, U, D_dec)
+            target: Target label ID sequences. (B, L)
+
+        Return:
+            loss_lm: LM loss value.
+
+        """
+        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
+        lm_target = target.view(-1).type(torch.int64)
+
+        with torch.no_grad():
+            true_dist = lm_loss_in.clone()
+            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
+
+            # Ignore blank ID (0)
+            ignore = lm_target == 0
+            lm_target = lm_target.masked_fill(ignore, 0)
+
+            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
+
+        loss_lm = torch.nn.functional.kl_div(
+            torch.log_softmax(lm_loss_in, dim=1),
+            true_dist,
+            reduction="none",
+        )
+        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
+            0
+        )
+
+        return loss_lm
diff --git a/funasr/models_transducer/espnet_transducer_model_unified.py b/funasr/models_transducer/espnet_transducer_model_unified.py
new file mode 100644
index 0000000..efe3f4e
--- /dev/null
+++ b/funasr/models_transducer/espnet_transducer_model_unified.py
@@ -0,0 +1,588 @@
+"""ESPnet2 ASR Transducer model."""
+
+import logging
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from packaging.version import parse as V
+from typeguard import check_argument_types
+
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.models_transducer.encoder.encoder import Encoder
+from funasr.models_transducer.joint_network import JointNetwork
+from funasr.models_transducer.utils import get_transducer_task_io
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
+from funasr.modules.nets_utils import th_accuracy
+from funasr.losses.label_smoothing_loss import (  # noqa: H301
+    LabelSmoothingLoss,
+)
+from funasr.models_transducer.error_calculator import ErrorCalculator
+if V(torch.__version__) >= V("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
+    """ESPnet2ASRTransducerModel module definition.
+
+    Args:
+        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
+        token_list: List of token
+        frontend: Frontend module.
+        specaug: SpecAugment module.
+        normalize: Normalization module.
+        encoder: Encoder module.
+        decoder: Decoder module.
+        joint_network: Joint Network module.
+        transducer_weight: Weight of the Transducer loss.
+        fastemit_lambda: FastEmit lambda value.
+        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
+        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
+        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
+        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
+        ignore_id: Initial padding ID.
+        sym_space: Space symbol.
+        sym_blank: Blank Symbol
+        report_cer: Whether to report Character Error Rate during validation.
+        report_wer: Whether to report Word Error Rate during validation.
+        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
+
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        encoder: Encoder,
+        decoder: AbsDecoder,
+        att_decoder: Optional[AbsAttDecoder],
+        joint_network: JointNetwork,
+        transducer_weight: float = 1.0,
+        fastemit_lambda: float = 0.0,
+        auxiliary_ctc_weight: float = 0.0,
+        auxiliary_att_weight: float = 0.0,
+        auxiliary_ctc_dropout_rate: float = 0.0,
+        auxiliary_lm_loss_weight: float = 0.0,
+        auxiliary_lm_loss_smoothing: float = 0.0,
+        ignore_id: int = -1,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_sos: str = "<sos/eos>",
+        sym_eos: str = "<sos/eos>",
+        extract_feats_in_collect_stats: bool = True,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+    ) -> None:
+        """Construct an ESPnetASRTransducerModel object."""
+        super().__init__()
+
+        assert check_argument_types()
+
+        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
+        self.blank_id = 0
+
+        if sym_sos in token_list:
+            self.sos = token_list.index(sym_sos)
+        else:
+            self.sos = vocab_size - 1
+        if sym_eos in token_list:
+            self.eos = token_list.index(sym_eos)
+        else:
+            self.eos = vocab_size - 1
+
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.token_list = token_list.copy()
+
+        self.sym_space = sym_space
+        self.sym_blank = sym_blank
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joint_network = joint_network
+
+        self.criterion_transducer = None
+        self.error_calculator = None
+
+        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
+        self.use_auxiliary_att = auxiliary_att_weight > 0
+        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
+
+        if self.use_auxiliary_ctc:
+            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
+            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
+
+        if self.use_auxiliary_att:
+            self.att_decoder = att_decoder
+
+            self.criterion_att = LabelSmoothingLoss(
+                size=vocab_size,
+                padding_idx=ignore_id,
+                smoothing=lsm_weight,
+                normalize_length=length_normalized_loss,
+            )
+
+        if self.use_auxiliary_lm_loss:
+            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
+            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
+
+        self.transducer_weight = transducer_weight
+        self.fastemit_lambda = fastemit_lambda
+
+        self.auxiliary_ctc_weight = auxiliary_ctc_weight
+        self.auxiliary_att_weight = auxiliary_att_weight
+        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
+
+        self.report_cer = report_cer
+        self.report_wer = report_wer
+
+        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Forward architecture and compute loss(es).
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+
+        Return:
+            loss: Main loss value.
+            stats: Task statistics.
+            weight: Task weights.
+
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        assert (
+            speech.shape[0]
+            == speech_lengths.shape[0]
+            == text.shape[0]
+            == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+
+        batch_size = speech.shape[0]
+        text = text[:, : text_lengths.max()]
+        #print(speech.shape)
+        # 1. Encoder
+        encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        loss_att, loss_att_chunk = 0.0, 0.0
+
+        if self.use_auxiliary_att:
+            loss_att, _ = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+            loss_att_chunk, _ = self._calc_att_loss(
+                encoder_out_chunk, encoder_out_lens, text, text_lengths
+            )
+
+        # 2. Transducer-related I/O preparation
+        decoder_in, target, t_len, u_len = get_transducer_task_io(
+            text,
+            encoder_out_lens,
+            ignore_id=self.ignore_id,
+        )
+        
+        # 3. Decoder
+        self.decoder.set_device(encoder_out.device)
+        decoder_out = self.decoder(decoder_in, u_len)
+
+        # 4. Joint Network
+        joint_out = self.joint_network(
+            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
+        )
+        
+        joint_out_chunk = self.joint_network(
+            encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1)
+        )
+
+        # 5. Losses
+        loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss(
+            encoder_out,
+            joint_out,
+            target,
+            t_len,
+            u_len,
+        )
+        
+        loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss(
+            encoder_out_chunk,
+            joint_out_chunk,
+            target,
+            t_len,
+            u_len,
+        )
+
+        loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0
+
+        if self.use_auxiliary_ctc:
+            loss_ctc = self._calc_ctc_loss(
+                encoder_out,
+                target,
+                t_len,
+                u_len,
+            )
+            loss_ctc_chunk = self._calc_ctc_loss(
+                encoder_out_chunk,
+                target,
+                t_len,
+                u_len,
+            )
+
+        if self.use_auxiliary_lm_loss:
+            loss_lm = self._calc_lm_loss(decoder_out, target)
+
+        loss_trans = loss_trans_utt + loss_trans_chunk
+        loss_ctc = loss_ctc + loss_ctc_chunk 
+        loss_ctc = loss_att + loss_att_chunk
+
+        loss = (
+            self.transducer_weight * loss_trans
+            + self.auxiliary_ctc_weight * loss_ctc
+            + self.auxiliary_att_weight * loss_att
+            + self.auxiliary_lm_loss_weight * loss_lm
+        )
+
+        stats = dict(
+            loss=loss.detach(),
+            loss_transducer=loss_trans_utt.detach(),
+            loss_transducer_chunk=loss_trans_chunk.detach(),
+            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
+            aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None,
+            aux_att_loss=loss_att.detach() if loss_att > 0.0 else None,
+            aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None,
+            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
+            cer_transducer=cer_trans,
+            wer_transducer=wer_trans,
+            cer_transducer_chunk=cer_trans_chunk,
+            wer_transducer_chunk=wer_trans_chunk,
+        )
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+
+        return loss, stats, weight
+
+    def collect_feats(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        """Collect features sequences and features lengths sequences.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+
+        Return:
+            {}: "feats": Features sequences. (B, T, D_feats),
+                "feats_lengths": Features sequences lengths. (B,)
+
+        """
+        if self.extract_feats_in_collect_stats:
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        else:
+            # Generate dummy stats if extract_feats_in_collect_stats is False
+            logging.warning(
+                "Generating dummy stats for feats and feats_lengths, "
+                "because encoder_conf.extract_feats_in_collect_stats is "
+                f"{self.extract_feats_in_collect_stats}"
+            )
+
+            feats, feats_lengths = speech, speech_lengths
+
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encoder speech sequences.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+
+        Return:
+            encoder_out: Encoder outputs. (B, T, D_enc)
+            encoder_out_lens: Encoder outputs lengths. (B,)
+
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # 4. Forward encoder
+        encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths)
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+
+        return encoder_out, encoder_out_chunk, encoder_out_lens
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Extract features sequences and features sequences lengths.
+
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+
+        Return:
+            feats: Features sequences. (B, T, D_feats)
+            feats_lengths: Features sequences lengths. (B,)
+
+        """
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            feats, feats_lengths = speech, speech_lengths
+
+        return feats, feats_lengths
+
+    def _calc_transducer_loss(
+        self,
+        encoder_out: torch.Tensor,
+        joint_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
+        """Compute Transducer loss.
+
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            joint_out: Joint Network output sequences (B, T, U, D_joint)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+
+        Return:
+            loss_transducer: Transducer loss value.
+            cer_transducer: Character error rate for Transducer.
+            wer_transducer: Word Error Rate for Transducer.
+
+        """
+        if self.criterion_transducer is None:
+            try:
+                # from warprnnt_pytorch import RNNTLoss
+	        # self.criterion_transducer = RNNTLoss(
+                    # reduction="mean",
+                    # fastemit_lambda=self.fastemit_lambda,
+                # )
+                from warp_rnnt import rnnt_loss as RNNTLoss
+                self.criterion_transducer = RNNTLoss
+
+            except ImportError:
+                logging.error(
+                    "warp-rnnt was not installed."
+                    "Please consult the installation documentation."
+                )
+                exit(1)
+
+        # loss_transducer = self.criterion_transducer(
+        #     joint_out,
+        #     target,
+        #     t_len,
+        #     u_len,
+        # )
+        log_probs = torch.log_softmax(joint_out, dim=-1)
+
+        loss_transducer = self.criterion_transducer(
+                log_probs,
+                target,
+                t_len,
+                u_len,
+                reduction="mean",
+                blank=self.blank_id,
+                fastemit_lambda=self.fastemit_lambda,
+                gather=True,
+        )
+
+        if not self.training and (self.report_cer or self.report_wer):
+            if self.error_calculator is None:
+                self.error_calculator = ErrorCalculator(
+                    self.decoder,
+                    self.joint_network,
+                    self.token_list,
+                    self.sym_space,
+                    self.sym_blank,
+                    report_cer=self.report_cer,
+                    report_wer=self.report_wer,
+                )
+
+            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+
+            return loss_transducer, cer_transducer, wer_transducer
+
+        return loss_transducer, None, None
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute CTC loss.
+
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+
+        Return:
+            loss_ctc: CTC loss value.
+
+        """
+        ctc_in = self.ctc_lin(
+            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
+        )
+        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
+
+        target_mask = target != 0
+        ctc_target = target[target_mask].cpu()
+
+        with torch.backends.cudnn.flags(deterministic=True):
+            loss_ctc = torch.nn.functional.ctc_loss(
+                ctc_in,
+                ctc_target,
+                t_len,
+                u_len,
+                zero_infinity=True,
+                reduction="sum",
+            )
+        loss_ctc /= target.size(0)
+
+        return loss_ctc
+
+    def _calc_lm_loss(
+        self,
+        decoder_out: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute LM loss.
+
+        Args:
+            decoder_out: Decoder output sequences. (B, U, D_dec)
+            target: Target label ID sequences. (B, L)
+
+        Return:
+            loss_lm: LM loss value.
+
+        """
+        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
+        lm_target = target.view(-1).type(torch.int64)
+
+        with torch.no_grad():
+            true_dist = lm_loss_in.clone()
+            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
+
+            # Ignore blank ID (0)
+            ignore = lm_target == 0
+            lm_target = lm_target.masked_fill(ignore, 0)
+
+            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
+
+        loss_lm = torch.nn.functional.kl_div(
+            torch.log_softmax(lm_loss_in, dim=1),
+            true_dist,
+            reduction="none",
+        )
+        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
+            0
+        )
+
+        return loss_lm
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
+            ys_pad = torch.cat(
+                [
+                    self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
+                    ys_pad,
+                ],
+                dim=1,
+            )
+            ys_pad_lens += 1
+
+        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        ys_in_lens = ys_pad_lens + 1
+
+        # 1. Forward decoder
+        decoder_out, _ = self.att_decoder(
+            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+        )
+
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        acc_att = th_accuracy(
+            decoder_out.view(-1, self.vocab_size),
+            ys_out_pad,
+            ignore_label=self.ignore_id,
+        )
+
+        return loss_att, acc_att
diff --git a/funasr/models_transducer/joint_network.py b/funasr/models_transducer/joint_network.py
new file mode 100644
index 0000000..119dd84
--- /dev/null
+++ b/funasr/models_transducer/joint_network.py
@@ -0,0 +1,62 @@
+"""Transducer joint network implementation."""
+
+import torch
+
+from funasr.models_transducer.activation import get_activation
+
+
+class JointNetwork(torch.nn.Module):
+    """Transducer joint network module.
+
+    Args:
+        output_size: Output size.
+        encoder_size: Encoder output size.
+        decoder_size: Decoder output size..
+        joint_space_size: Joint space size.
+        joint_act_type: Type of activation for joint network.
+        **activation_parameters: Parameters for the activation function.
+
+    """
+
+    def __init__(
+        self,
+        output_size: int,
+        encoder_size: int,
+        decoder_size: int,
+        joint_space_size: int = 256,
+        joint_activation_type: str = "tanh",
+        **activation_parameters,
+    ) -> None:
+        """Construct a JointNetwork object."""
+        super().__init__()
+
+        self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size)
+        self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False)
+
+        self.lin_out = torch.nn.Linear(joint_space_size, output_size)
+
+        self.joint_activation = get_activation(
+            joint_activation_type, **activation_parameters
+        )
+
+    def forward(
+        self,
+        enc_out: torch.Tensor,
+        dec_out: torch.Tensor,
+        project_input: bool = True,
+    ) -> torch.Tensor:
+        """Joint computation of encoder and decoder hidden state sequences.
+
+        Args:
+            enc_out: Expanded encoder output state sequences (B, T, 1, D_enc)
+            dec_out: Expanded decoder output state sequences (B, 1, U, D_dec)
+
+        Returns:
+            joint_out: Joint output state sequences. (B, T, U, D_out)
+
+        """
+        if project_input:
+            joint_out = self.joint_activation(self.lin_enc(enc_out) + self.lin_dec(dec_out))
+        else:
+            joint_out = self.joint_activation(enc_out + dec_out)
+        return self.lin_out(joint_out)
diff --git a/funasr/models_transducer/utils.py b/funasr/models_transducer/utils.py
new file mode 100644
index 0000000..fd3c531
--- /dev/null
+++ b/funasr/models_transducer/utils.py
@@ -0,0 +1,200 @@
+"""Utility functions for Transducer models."""
+
+from typing import List, Tuple
+
+import torch
+
+
+class TooShortUttError(Exception):
+    """Raised when the utt is too short for subsampling.
+
+    Args:
+        message: Error message to display.
+        actual_size: The size that cannot pass the subsampling.
+        limit: The size limit for subsampling.
+
+    """
+
+    def __init__(self, message: str, actual_size: int, limit: int) -> None:
+        """Construct a TooShortUttError module."""
+        super().__init__(message)
+
+        self.actual_size = actual_size
+        self.limit = limit
+
+
+def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
+    """Check if the input is too short for subsampling.
+
+    Args:
+        sub_factor: Subsampling factor for Conv2DSubsampling.
+        size: Input size.
+
+    Returns:
+        : Whether an error should be sent.
+        : Size limit for specified subsampling factor.
+
+    """
+    if sub_factor == 2 and size < 3:
+        return True, 7
+    elif sub_factor == 4 and size < 7:
+        return True, 7
+    elif sub_factor == 6 and size < 11:
+        return True, 11
+
+    return False, -1
+
+
+def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
+    """Get conv2D second layer parameters for given subsampling factor.
+
+    Args:
+        sub_factor: Subsampling factor (1/X).
+        input_size: Input size.
+
+    Returns:
+        : Kernel size for second convolution.
+        : Stride for second convolution.
+        : Conv2DSubsampling output size.
+
+    """
+    if sub_factor == 2:
+        return 3, 1, (((input_size - 1) // 2 - 2))
+    elif sub_factor == 4:
+        return 3, 2, (((input_size - 1) // 2 - 1) // 2)
+    elif sub_factor == 6:
+        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
+    else:
+        raise ValueError(
+            "subsampling_factor parameter should be set to either 2, 4 or 6."
+        )
+
+
+def make_chunk_mask(
+    size: int,
+    chunk_size: int,
+    left_chunk_size: int = 0,
+    device: torch.device = None,
+) -> torch.Tensor:
+    """Create chunk mask for the subsequent steps (size, size).
+
+    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+    Args:
+        size: Size of the source mask.
+        chunk_size: Number of frames in chunk.
+        left_chunk_size: Size of the left context in chunks (0 means full context).
+        device: Device for the mask tensor.
+
+    Returns:
+        mask: Chunk mask. (size, size)
+
+    """
+    mask = torch.zeros(size, size, device=device, dtype=torch.bool)
+
+    for i in range(size):
+        if left_chunk_size <= 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
+
+        end = min((i // chunk_size + 1) * chunk_size, size)
+        mask[i, start:end] = True
+
+    return ~mask
+
+
+def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
+    """Create source mask for given lengths.
+
+    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+    Args:
+        lengths: Sequence lengths. (B,)
+
+    Returns:
+        : Mask for the sequence lengths. (B, max_len)
+
+    """
+    max_len = lengths.max()
+    batch_size = lengths.size(0)
+
+    expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
+
+    return expanded_lengths >= lengths.unsqueeze(1)
+
+
+def get_transducer_task_io(
+    labels: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    ignore_id: int = -1,
+    blank_id: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Get Transducer loss I/O.
+
+    Args:
+        labels: Label ID sequences. (B, L)
+        encoder_out_lens: Encoder output lengths. (B,)
+        ignore_id: Padding symbol ID.
+        blank_id: Blank symbol ID.
+
+    Returns:
+        decoder_in: Decoder inputs. (B, U)
+        target: Target label ID sequences. (B, U)
+        t_len: Time lengths. (B,)
+        u_len: Label lengths. (B,)
+
+    """
+
+    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
+        """Create padded batch of labels from a list of labels sequences.
+
+        Args:
+            labels: Labels sequences. [B x (?)]
+            padding_value: Padding value.
+
+        Returns:
+            labels: Batch of padded labels sequences. (B,)
+
+        """
+        batch_size = len(labels)
+
+        padded = (
+            labels[0]
+            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
+            .fill_(padding_value)
+        )
+
+        for i in range(batch_size):
+            padded[i, : labels[i].size(0)] = labels[i]
+
+        return padded
+
+    device = labels.device
+
+    labels_unpad = [y[y != ignore_id] for y in labels]
+    blank = labels[0].new([blank_id])
+
+    decoder_in = pad_list(
+        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
+    ).to(device)
+
+    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
+
+    encoder_out_lens = list(map(int, encoder_out_lens))
+    t_len = torch.IntTensor(encoder_out_lens).to(device)
+
+    u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
+
+    return decoder_in, target, t_len, u_len
+
+def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
+    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
+    if t.size(dim) == pad_len:
+        return t
+    else:
+        pad_size = list(t.shape)
+        pad_size[dim] = pad_len - t.size(dim)
+        return torch.cat(
+            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
+        )
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
new file mode 100644
index 0000000..3c7a782
--- /dev/null
+++ b/funasr/tasks/asr_transducer.py
@@ -0,0 +1,487 @@
+"""ASR Transducer Task."""
+
+import argparse
+import logging
+from typing import Callable, Collection, Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+from typeguard import check_argument_types, check_return_type
+
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
+from funasr.models.decoder.transformer_decoder import (
+    DynamicConvolution2DTransformerDecoder,
+    DynamicConvolutionTransformerDecoder,
+    LightweightConvolution2DTransformerDecoder,
+    LightweightConvolutionTransformerDecoder,
+    TransformerDecoder,
+)
+from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
+from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
+from funasr.models_transducer.encoder.encoder import Encoder
+from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
+from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
+from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
+from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
+from funasr.models_transducer.joint_network import JointNetwork
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.tasks.abs_task import AbsTask
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.train.class_choices import ClassChoices
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.train.trainer import Trainer
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.types import float_or_none, int_or_none, str2bool, str_or_none
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(
+        default=DefaultFrontend,
+        sliding_window=SlidingWindow,
+    ),
+    type_check=AbsFrontend,
+    default="default",
+)
+specaug_choices = ClassChoices(
+    "specaug",
+    classes=dict(
+        specaug=SpecAug,
+    ),
+    type_check=AbsSpecAug,
+    default=None,
+    optional=True,
+)
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    type_check=AbsNormalize,
+    default="utterance_mvn",
+    optional=True,
+)
+encoder_choices = ClassChoices(
+        "encoder",
+        classes=dict(
+                encoder=Encoder,
+                sanm_chunk_opt=SANMEncoderChunkOpt,
+        ),
+        default="encoder",
+)
+
+decoder_choices = ClassChoices(
+    "decoder",
+    classes=dict(
+        rnn=RNNDecoder,
+        stateless=StatelessDecoder,
+    ),
+    type_check=AbsDecoder,
+    default="rnn",
+)
+
+att_decoder_choices = ClassChoices(
+    "att_decoder",
+    classes=dict(
+        transformer=TransformerDecoder,
+        lightweight_conv=LightweightConvolutionTransformerDecoder,
+        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+        dynamic_conv=DynamicConvolutionTransformerDecoder,
+        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+    ),
+    type_check=AbsAttDecoder,
+    default=None,
+    optional=True,
+)
+class ASRTransducerTask(AbsTask):
+    """ASR Transducer Task definition."""
+
+    num_optimizers: int = 1
+
+    class_choices_list = [
+        frontend_choices,
+        specaug_choices,
+        normalize_choices,
+        encoder_choices,
+        decoder_choices,
+        att_decoder_choices,
+    ]
+
+    trainer = Trainer
+
+    @classmethod
+    def add_task_arguments(cls, parser: argparse.ArgumentParser):
+        """Add Transducer task arguments.
+        Args:
+            cls: ASRTransducerTask object.
+            parser: Transducer arguments parser.
+        """
+        group = parser.add_argument_group(description="Task related.")
+
+        # required = parser.get_default("required")
+        # required += ["token_list"]
+
+        group.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="Integer-string mapper for tokens.",
+        )
+        group.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of dimensions for input features.",
+        )
+        group.add_argument(
+            "--init",
+            type=str_or_none,
+            default=None,
+            help="Type of model initialization to use.",
+        )
+        group.add_argument(
+            "--model_conf",
+            action=NestedDictAction,
+            default=get_default_kwargs(ESPnetASRTransducerModel),
+            help="The keyword arguments for the model class.",
+        )
+        # group.add_argument(
+        #     "--encoder_conf",
+        #     action=NestedDictAction,
+        #     default={},
+        #     help="The keyword arguments for the encoder class.",
+        # )
+        group.add_argument(
+            "--joint_network_conf",
+            action=NestedDictAction,
+            default={},
+            help="The keyword arguments for the joint network class.",
+        )
+        group = parser.add_argument_group(description="Preprocess related.")
+        group.add_argument(
+            "--use_preprocessor",
+            type=str2bool,
+            default=True,
+            help="Whether to apply preprocessing to input data.",
+        )
+        group.add_argument(
+            "--token_type",
+            type=str,
+            default="bpe",
+            choices=["bpe", "char", "word", "phn"],
+            help="The type of tokens to use during tokenization.",
+        )
+        group.add_argument(
+            "--bpemodel",
+            type=str_or_none,
+            default=None,
+            help="The path of the sentencepiece model.",
+        )
+        parser.add_argument(
+            "--non_linguistic_symbols",
+            type=str_or_none,
+            help="The 'non_linguistic_symbols' file path.",
+        )
+        parser.add_argument(
+            "--cleaner",
+            type=str_or_none,
+            choices=[None, "tacotron", "jaconv", "vietnamese"],
+            default=None,
+            help="Text cleaner to use.",
+        )
+        parser.add_argument(
+            "--g2p",
+            type=str_or_none,
+            choices=g2p_choices,
+            default=None,
+            help="g2p method to use if --token_type=phn.",
+        )
+        parser.add_argument(
+            "--speech_volume_normalize",
+            type=float_or_none,
+            default=None,
+            help="Normalization value for maximum amplitude scaling.",
+        )
+        parser.add_argument(
+            "--rir_scp",
+            type=str_or_none,
+            default=None,
+            help="The RIR SCP file path.",
+        )
+        parser.add_argument(
+            "--rir_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of the applied RIR convolution.",
+        )
+        parser.add_argument(
+            "--noise_scp",
+            type=str_or_none,
+            default=None,
+            help="The path of noise SCP file.",
+        )
+        parser.add_argument(
+            "--noise_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of the applied noise addition.",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of the noise decibel level.",
+        )
+        for class_choices in cls.class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --decoder and --decoder_conf
+            class_choices.add_arguments(group)
+
+    @classmethod
+    def build_collate_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Callable[
+        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+        Tuple[List[str], Dict[str, torch.Tensor]],
+    ]:
+        """Build collate function.
+        Args:
+            cls: ASRTransducerTask object.
+            args: Task arguments.
+            train: Training mode.
+        Return:
+            : Callable collate function.
+        """
+        assert check_argument_types()
+
+        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+    @classmethod
+    def build_preprocess_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+        """Build pre-processing function.
+        Args:
+            cls: ASRTransducerTask object.
+            args: Task arguments.
+            train: Training mode.
+        Return:
+            : Callable pre-processing function.
+        """
+        assert check_argument_types()
+
+        if args.use_preprocessor:
+            retval = CommonPreprocessor(
+                train=train,
+                token_type=args.token_type,
+                token_list=args.token_list,
+                bpemodel=args.bpemodel,
+                non_linguistic_symbols=args.non_linguistic_symbols,
+                text_cleaner=args.cleaner,
+                g2p_type=args.g2p,
+                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+                rir_apply_prob=args.rir_apply_prob
+                if hasattr(args, "rir_apply_prob")
+                else 1.0,
+                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+                noise_apply_prob=args.noise_apply_prob
+                if hasattr(args, "noise_apply_prob")
+                else 1.0,
+                noise_db_range=args.noise_db_range
+                if hasattr(args, "noise_db_range")
+                else "13_15",
+                speech_volume_normalize=args.speech_volume_normalize
+                if hasattr(args, "rir_scp")
+                else None,
+            )
+        else:
+            retval = None
+
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def required_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        """Required data depending on task mode.
+        Args:
+            cls: ASRTransducerTask object.
+            train: Training mode.
+            inference: Inference mode.
+        Return:
+            retval: Required task data.
+        """
+        if not inference:
+            retval = ("speech", "text")
+        else:
+            retval = ("speech",)
+
+        return retval
+
+    @classmethod
+    def optional_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        """Optional data depending on task mode.
+        Args:
+            cls: ASRTransducerTask object.
+            train: Training mode.
+            inference: Inference mode.
+        Return:
+            retval: Optional task data.
+        """
+        retval = ()
+        assert check_return_type(retval)
+
+        return retval
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
+        """Required data depending on task mode.
+        Args:
+            cls: ASRTransducerTask object.
+            args: Task arguments.
+        Return:
+            model: ASR Transducer model.
+        """
+        assert check_argument_types()
+
+        if isinstance(args.token_list, str):
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+
+            # Overwriting token_list to keep it as "portable".
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size }")
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Encoder
+        
+        if getattr(args, "encoder", None) is not None:
+            encoder_class = encoder_choices.get_class(args.encoder)
+            encoder = encoder_class(input_size, **args.encoder_conf)
+        else:
+            encoder = Encoder(input_size, **args.encoder_conf)
+        encoder_output_size = encoder.output_size
+
+        # 5. Decoder
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(
+            vocab_size,
+            **args.decoder_conf,
+        )
+        decoder_output_size = decoder.output_size
+
+        if getattr(args, "att_decoder", None) is not None:
+            att_decoder_class = att_decoder_choices.get_class(args.att_decoder)
+
+            att_decoder = att_decoder_class(
+                vocab_size=vocab_size,
+                encoder_output_size=encoder_output_size,
+                **args.att_decoder_conf,
+            )
+        else:
+            att_decoder = None
+
+        # 6. Joint Network
+        joint_network = JointNetwork(
+            vocab_size,
+            encoder_output_size,
+            decoder_output_size,
+            **args.joint_network_conf,
+        )
+
+        # 7. Build model
+
+        if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt':
+            model = UniASRTransducerModel(
+                vocab_size=vocab_size,
+                token_list=token_list,
+                frontend=frontend,
+                specaug=specaug,
+                normalize=normalize,
+                encoder=encoder,
+                decoder=decoder,
+                att_decoder=att_decoder,
+                joint_network=joint_network,
+                **args.model_conf,
+            )
+
+        elif encoder.unified_model_training:
+            model = ESPnetASRUnifiedTransducerModel(
+                vocab_size=vocab_size,
+                token_list=token_list,
+                frontend=frontend,
+                specaug=specaug,
+                normalize=normalize,
+                encoder=encoder,
+                decoder=decoder,
+                att_decoder=att_decoder,
+                joint_network=joint_network,
+                **args.model_conf,
+            )
+
+        else:
+            model = ESPnetASRTransducerModel(
+                vocab_size=vocab_size,
+                token_list=token_list,
+                frontend=frontend,
+                specaug=specaug,
+                normalize=normalize,
+                encoder=encoder,
+                decoder=decoder,
+                att_decoder=att_decoder,
+                joint_network=joint_network,
+                **args.model_conf,
+            )
+
+        # 8. Initialize model
+        if args.init is not None:
+            raise NotImplementedError(
+                "Currently not supported.",
+                "Initialization part will be reworked in a short future.",
+            )
+
+        #assert check_return_type(model)
+
+        return model

--
Gitblit v1.9.1