From 8672352ecde80a86609fe01195b398ebe77f0ed1 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Mon, 17 Apr 2023 16:09:23 +0800
Subject: [PATCH] Merge the unified Transducer model and task into the main ASR modules

---
 /dev/null                                               |  477 -------------------
 funasr/bin/asr_train_transducer.py                      |    2 
 funasr/models/encoder/conformer_encoder.py              |    7 
 funasr/modules/e2e_asr_common.py                        |    3 
 funasr/modules/beam_search/beam_search_transducer.py    |    3 
 funasr/bin/asr_inference_rnnt.py                        |    2 
 funasr/tasks/asr.py                                     |  391 +++++++++++++++
 funasr/models/decoder/rnnt_decoder.py                   |    3 
 egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml |   30 
 funasr/models/e2e_asr_transducer.py                     |  535 +++++++++++++++++++++
 10 files changed, 944 insertions(+), 509 deletions(-)

diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
index 60f796c..8a1c40c 100644
--- a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
+++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -1,32 +1,26 @@
+encoder: chunk_conformer
 encoder_conf:
-    main_conf:
-      pos_wise_act_type: swish
-      pos_enc_dropout_rate: 0.5
-      conv_mod_act_type: swish
+      activation_type: swish
+      positional_dropout_rate: 0.5
       time_reduction_factor: 2
       unified_model_training: true
       default_chunk_size: 16
       jitter_range: 4
       left_chunk_size: 0
-    input_conf:
-      block_type: conv2d
-      conv_size: 512
+      embed_vgg_like: false
       subsampling_factor: 4
-      num_frame: 1
-    body_conf:
-    - block_type: conformer
-      linear_size: 2048
-      hidden_size: 512
-      heads: 8
+      linear_units: 2048
+      output_size: 512
+      attention_heads: 8
       dropout_rate: 0.5
-      pos_wise_dropout_rate: 0.5
-      att_dropout_rate: 0.5
-      conv_mod_kernel_size: 15
+      positional_dropout_rate: 0.5
+      attention_dropout_rate: 0.5
+      cnn_module_kernel: 15
       num_blocks: 12    
 
 # decoder related
-decoder: rnn
-decoder_conf:
+rnnt_decoder: rnnt
+rnnt_decoder_conf:
     embed_size: 512
     hidden_size: 512
     embed_dropout_rate: 0.5
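
For reference, the encoder options move from the nested main_conf/input_conf/body_conf
layout to a single flat encoder_conf dict, and the decoder section is renamed to
rnnt_decoder/rnnt_decoder_conf. A minimal sketch of reading the updated file
(assumes PyYAML; the path is the one in this diff):

    import yaml

    with open("egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml") as f:
        conf = yaml.safe_load(f)

    assert conf["encoder"] == "chunk_conformer"    # resolved via encoder_choices
    print(conf["encoder_conf"]["output_size"])     # 512, flat layout after this patch
    print(conf["rnnt_decoder"], conf["rnnt_decoder_conf"]["embed_size"])
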
diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index 465f882..bff8702 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -22,7 +22,7 @@
 )
 from funasr.modules.nets_utils import TooShortUttError
 from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.tasks.asr_transducer import ASRTransducerTask
+from funasr.tasks.asr import ASRTransducerTask
 from funasr.tasks.lm import LMTask
 from funasr.text.build_tokenizer import build_tokenizer
 from funasr.text.token_id_converter import TokenIDConverter
diff --git a/funasr/bin/asr_train_transducer.py b/funasr/bin/asr_train_transducer.py
index 9b6d287..fe418db 100755
--- a/funasr/bin/asr_train_transducer.py
+++ b/funasr/bin/asr_train_transducer.py
@@ -2,7 +2,7 @@
 
 import os
 
-from funasr.tasks.asr_transducer import ASRTransducerTask
+from funasr.tasks.asr import ASRTransducerTask
 
 
 # for ASR Training
diff --git a/funasr/models/rnnt_predictor/rnn_decoder.py b/funasr/models/decoder/rnnt_decoder.py
similarity index 98%
rename from funasr/models/rnnt_predictor/rnn_decoder.py
rename to funasr/models/decoder/rnnt_decoder.py
index 0df6fc7..5401ab2 100644
--- a/funasr/models/rnnt_predictor/rnn_decoder.py
+++ b/funasr/models/decoder/rnnt_decoder.py
@@ -6,10 +6,9 @@
 from typeguard import check_argument_types
 
 from funasr.modules.beam_search.beam_search_transducer import Hypothesis
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
 from funasr.models.specaug.specaug import SpecAug
 
-class RNNDecoder(AbsDecoder):
+class RNNTDecoder(torch.nn.Module):
     """RNN decoder module.
 
     Args:
diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py
index 6eb0023..0cae306 100644
--- a/funasr/models/e2e_asr_transducer.py
+++ b/funasr/models/e2e_asr_transducer.py
@@ -10,7 +10,7 @@
 
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
 from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
 from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
 from funasr.models.joint_net.joint_network import JointNetwork
@@ -63,9 +63,9 @@
         specaug: Optional[AbsSpecAug],
         normalize: Optional[AbsNormalize],
         encoder: Encoder,
-        decoder: AbsDecoder,
-        att_decoder: Optional[AbsAttDecoder],
+        decoder: RNNTDecoder,
         joint_network: JointNetwork,
+        att_decoder: Optional[AbsAttDecoder] = None,
         transducer_weight: float = 1.0,
         fastemit_lambda: float = 0.0,
         auxiliary_ctc_weight: float = 0.0,
@@ -482,3 +482,532 @@
         )
 
         return loss_lm
+
+class UnifiedTransducerModel(AbsESPnetModel):
+    """ESPnet2ASRTransducerModel module definition.
+    Args:
+        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
+        token_list: List of tokens.
+        frontend: Frontend module.
+        specaug: SpecAugment module.
+        normalize: Normalization module.
+        encoder: Encoder module.
+        decoder: Decoder module.
+        joint_network: Joint Network module.
+        transducer_weight: Weight of the Transducer loss.
+        fastemit_lambda: FastEmit lambda value.
+        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
+        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
+        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
+        auxiliary_lm_loss_smoothing: Smoothing rate for the LM loss label smoothing.
+        ignore_id: Initial padding ID.
+        sym_space: Space symbol.
+        sym_blank: Blank symbol.
+        report_cer: Whether to report Character Error Rate during validation.
+        report_wer: Whether to report Word Error Rate during validation.
+        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        token_list: Union[Tuple[str, ...], List[str]],
+        frontend: Optional[AbsFrontend],
+        specaug: Optional[AbsSpecAug],
+        normalize: Optional[AbsNormalize],
+        encoder: Encoder,
+        decoder: RNNTDecoder,
+        joint_network: JointNetwork,
+        att_decoder: Optional[AbsAttDecoder] = None,
+        transducer_weight: float = 1.0,
+        fastemit_lambda: float = 0.0,
+        auxiliary_ctc_weight: float = 0.0,
+        auxiliary_att_weight: float = 0.0,
+        auxiliary_ctc_dropout_rate: float = 0.0,
+        auxiliary_lm_loss_weight: float = 0.0,
+        auxiliary_lm_loss_smoothing: float = 0.0,
+        ignore_id: int = -1,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_sos: str = "<sos/eos>",
+        sym_eos: str = "<sos/eos>",
+        extract_feats_in_collect_stats: bool = True,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+    ) -> None:
+        """Construct an ESPnetASRTransducerModel object."""
+        super().__init__()
+
+        assert check_argument_types()
+
+        # The following label IDs are reserved: 0 (blank) and vocab_size - 1 (sos/eos).
+        self.blank_id = 0
+
+        if sym_sos in token_list:
+            self.sos = token_list.index(sym_sos)
+        else:
+            self.sos = vocab_size - 1
+        if sym_eos in token_list:
+            self.eos = token_list.index(sym_eos)
+        else:
+            self.eos = vocab_size - 1
+
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.token_list = token_list.copy()
+
+        self.sym_space = sym_space
+        self.sym_blank = sym_blank
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+
+        self.encoder = encoder
+        self.decoder = decoder
+        self.joint_network = joint_network
+
+        self.criterion_transducer = None
+        self.error_calculator = None
+
+        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
+        self.use_auxiliary_att = auxiliary_att_weight > 0
+        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
+
+        if self.use_auxiliary_ctc:
+            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
+            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
+
+        if self.use_auxiliary_att:
+            self.att_decoder = att_decoder
+
+            self.criterion_att = LabelSmoothingLoss(
+                size=vocab_size,
+                padding_idx=ignore_id,
+                smoothing=lsm_weight,
+                normalize_length=length_normalized_loss,
+            )
+
+        if self.use_auxiliary_lm_loss:
+            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
+            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
+
+        self.transducer_weight = transducer_weight
+        self.fastemit_lambda = fastemit_lambda
+
+        self.auxiliary_ctc_weight = auxiliary_ctc_weight
+        self.auxiliary_att_weight = auxiliary_att_weight
+        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
+
+        self.report_cer = report_cer
+        self.report_wer = report_wer
+
+        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Forward architecture and compute loss(es).
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+        Return:
+            loss: Main loss value.
+            stats: Task statistics.
+            weight: Task weights.
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        assert (
+            speech.shape[0]
+            == speech_lengths.shape[0]
+            == text.shape[0]
+            == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+
+        batch_size = speech.shape[0]
+        text = text[:, : text_lengths.max()]
+        # 1. Encoder
+        encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        loss_att, loss_att_chunk = 0.0, 0.0
+
+        if self.use_auxiliary_att:
+            loss_att, _ = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+            loss_att_chunk, _ = self._calc_att_loss(
+                encoder_out_chunk, encoder_out_lens, text, text_lengths
+            )
+
+        # 2. Transducer-related I/O preparation
+        decoder_in, target, t_len, u_len = get_transducer_task_io(
+            text,
+            encoder_out_lens,
+            ignore_id=self.ignore_id,
+        )
+
+        # 3. Decoder
+        self.decoder.set_device(encoder_out.device)
+        decoder_out = self.decoder(decoder_in, u_len)
+
+        # 4. Joint Network
+        joint_out = self.joint_network(
+            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
+        )
+
+        joint_out_chunk = self.joint_network(
+            encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1)
+        )
+
+        # 5. Losses
+        loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss(
+            encoder_out,
+            joint_out,
+            target,
+            t_len,
+            u_len,
+        )
+
+        loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss(
+            encoder_out_chunk,
+            joint_out_chunk,
+            target,
+            t_len,
+            u_len,
+        )
+
+        loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0
+
+        if self.use_auxiliary_ctc:
+            loss_ctc = self._calc_ctc_loss(
+                encoder_out,
+                target,
+                t_len,
+                u_len,
+            )
+            loss_ctc_chunk = self._calc_ctc_loss(
+                encoder_out_chunk,
+                target,
+                t_len,
+                u_len,
+            )
+
+        if self.use_auxiliary_lm_loss:
+            loss_lm = self._calc_lm_loss(decoder_out, target)
+
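+        # Unified training: each loss below sums its full-context and chunk-wise variants.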
+        loss_trans = loss_trans_utt + loss_trans_chunk
+        loss_ctc = loss_ctc + loss_ctc_chunk
+        loss_att = loss_att + loss_att_chunk
+
+        loss = (
+            self.transducer_weight * loss_trans
+            + self.auxiliary_ctc_weight * loss_ctc
+            + self.auxiliary_att_weight * loss_att
+            + self.auxiliary_lm_loss_weight * loss_lm
+        )
+
+        stats = dict(
+            loss=loss.detach(),
+            loss_transducer=loss_trans_utt.detach(),
+            loss_transducer_chunk=loss_trans_chunk.detach(),
+            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
+            aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None,
+            aux_att_loss=loss_att.detach() if loss_att > 0.0 else None,
+            aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None,
+            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
+            cer_transducer=cer_trans,
+            wer_transducer=wer_trans,
+            cer_transducer_chunk=cer_trans_chunk,
+            wer_transducer_chunk=wer_trans_chunk,
+        )
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def collect_feats(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        """Collect features sequences and features lengths sequences.
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+            text: Label ID sequences. (B, L)
+            text_lengths: Label ID sequences lengths. (B,)
+            kwargs: Contains "utts_id".
+        Return:
+            {}: "feats": Features sequences. (B, T, D_feats),
+                "feats_lengths": Features sequences lengths. (B,)
+        """
+        if self.extract_feats_in_collect_stats:
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        else:
+            # Generate dummy stats if extract_feats_in_collect_stats is False
+            logging.warning(
+                "Generating dummy stats for feats and feats_lengths, "
+                "because encoder_conf.extract_feats_in_collect_stats is "
+                f"{self.extract_feats_in_collect_stats}"
+            )
+
+            feats, feats_lengths = speech, speech_lengths
+
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encode speech sequences.
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+        Return:
+            encoder_out: Full-context encoder outputs. (B, T, D_enc)
+            encoder_out_chunk: Chunk-wise encoder outputs. (B, T, D_enc)
+            encoder_out_lens: Encoder outputs lengths. (B,)
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # 4. Forward encoder
+        encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths)
+
+        assert encoder_out.size(0) == speech.size(0), (
+            encoder_out.size(),
+            speech.size(0),
+        )
+        assert encoder_out.size(1) <= encoder_out_lens.max(), (
+            encoder_out.size(),
+            encoder_out_lens.max(),
+        )
+
+        return encoder_out, encoder_out_chunk, encoder_out_lens
+
+    def _extract_feats(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Extract features sequences and features sequences lengths.
+        Args:
+            speech: Speech sequences. (B, S)
+            speech_lengths: Speech sequences lengths. (B,)
+        Return:
+            feats: Features sequences. (B, T, D_feats)
+            feats_lengths: Features sequences lengths. (B,)
+        """
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            feats, feats_lengths = speech, speech_lengths
+
+        return feats, feats_lengths
+
+    def _calc_transducer_loss(
+        self,
+        encoder_out: torch.Tensor,
+        joint_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
+        """Compute Transducer loss.
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            joint_out: Joint Network output sequences (B, T, U, D_joint)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+        Return:
+            loss_transducer: Transducer loss value.
+            cer_transducer: Character error rate for Transducer.
+            wer_transducer: Word Error Rate for Transducer.
+        """
+        if self.criterion_transducer is None:
+            try:
+                # Alternative backend, kept for reference:
+                # from warprnnt_pytorch import RNNTLoss
+                # self.criterion_transducer = RNNTLoss(
+                #     reduction="mean",
+                #     fastemit_lambda=self.fastemit_lambda,
+                # )
+                from warp_rnnt import rnnt_loss as RNNTLoss
+                self.criterion_transducer = RNNTLoss
+
+            except ImportError:
+                logging.error(
+                    "warp-rnnt was not installed."
+                    "Please consult the installation documentation."
+                )
+                exit(1)
+
+        log_probs = torch.log_softmax(joint_out, dim=-1)
+
+        loss_transducer = self.criterion_transducer(
+            log_probs,
+            target,
+            t_len,
+            u_len,
+            reduction="mean",
+            blank=self.blank_id,
+            fastemit_lambda=self.fastemit_lambda,
+            gather=True,
+        )
+
+        if not self.training and (self.report_cer or self.report_wer):
+            if self.error_calculator is None:
+                self.error_calculator = ErrorCalculator(
+                    self.decoder,
+                    self.joint_network,
+                    self.token_list,
+                    self.sym_space,
+                    self.sym_blank,
+                    report_cer=self.report_cer,
+                    report_wer=self.report_wer,
+                )
+
+            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+            return loss_transducer, cer_transducer, wer_transducer
+
+        return loss_transducer, None, None
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        target: torch.Tensor,
+        t_len: torch.Tensor,
+        u_len: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute CTC loss.
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            target: Target label ID sequences. (B, L)
+            t_len: Encoder output sequences lengths. (B,)
+            u_len: Target label ID sequences lengths. (B,)
+        Return:
+            loss_ctc: CTC loss value.
+        """
+        ctc_in = self.ctc_lin(
+            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
+        )
+        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
+
+        target_mask = target != 0
+        ctc_target = target[target_mask].cpu()
+
+        with torch.backends.cudnn.flags(deterministic=True):
+            loss_ctc = torch.nn.functional.ctc_loss(
+                ctc_in,
+                ctc_target,
+                t_len,
+                u_len,
+                zero_infinity=True,
+                reduction="sum",
+            )
+        loss_ctc /= target.size(0)
+
+        return loss_ctc
+
+    def _calc_lm_loss(
+        self,
+        decoder_out: torch.Tensor,
+        target: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute LM loss.
+        Args:
+            decoder_out: Decoder output sequences. (B, U, D_dec)
+            target: Target label ID sequences. (B, L)
+        Return:
+            loss_lm: LM loss value.
+        """
+        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
+        lm_target = target.view(-1).type(torch.int64)
+
+        with torch.no_grad():
+            true_dist = lm_loss_in.clone()
+            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
+
+            # Ignore blank ID (0)
+            ignore = lm_target == 0
+            lm_target = lm_target.masked_fill(ignore, 0)
+
+            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
+
+        loss_lm = torch.nn.functional.kl_div(
+            torch.log_softmax(lm_loss_in, dim=1),
+            true_dist,
+            reduction="none",
+        )
+        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
+            0
+        )
+
+        return loss_lm
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
+            ys_pad = torch.cat(
+                [
+                    self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
+                    ys_pad,
+                ],
+                dim=1,
+            )
+            ys_pad_lens += 1
+
+        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        ys_in_lens = ys_pad_lens + 1
+
+        # 1. Forward decoder
+        decoder_out, _ = self.att_decoder(
+            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+        )
+
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        acc_att = th_accuracy(
+            decoder_out.view(-1, self.vocab_size),
+            ys_out_pad,
+            ignore_label=self.ignore_id,
+        )
+
+        return loss_att, acc_att
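
For context on the warp_rnnt call in _calc_transducer_loss above, a self-contained
sketch with toy shapes (assumes the warp_rnnt package, which is CUDA-only, so a GPU
is required to actually run it; blank ID 0 matches the model above):

    import torch
    from warp_rnnt import rnnt_loss

    B, T, U, V = 2, 10, 5, 30                  # batch, frames, target length, vocab
    joint_out = torch.randn(B, T, U + 1, V, device="cuda", requires_grad=True)
    target = torch.randint(1, V, (B, U), dtype=torch.int32, device="cuda")
    t_len = torch.full((B,), T, dtype=torch.int32, device="cuda")
    u_len = torch.full((B,), U, dtype=torch.int32, device="cuda")

    loss = rnnt_loss(
        torch.log_softmax(joint_out, dim=-1),  # warp_rnnt expects log-probabilities
        target, t_len, u_len,
        reduction="mean", blank=0,
        fastemit_lambda=0.0, gather=True,      # gather=True reduces memory use
    )
    loss.backward()
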
diff --git a/funasr/models/e2e_asr_transducer_unified.py b/funasr/models/e2e_asr_transducer_unified.py
deleted file mode 100644
index ad61d12..0000000
--- a/funasr/models/e2e_asr_transducer_unified.py
+++ /dev/null
@@ -1,586 +0,0 @@
-"""ESPnet2 ASR Transducer model."""
-
-import logging
-from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-from packaging.version import parse as V
-from typeguard import check_argument_types
-
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
-from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
-from funasr.models.joint_net.joint_network import JointNetwork
-from funasr.modules.nets_utils import get_transducer_task_io
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
-from funasr.modules.add_sos_eos import add_sos_eos
-from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.modules.nets_utils import th_accuracy
-from funasr.losses.label_smoothing_loss import (  # noqa: H301
-    LabelSmoothingLoss,
-)
-from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
-if V(torch.__version__) >= V("1.6.0"):
-    from torch.cuda.amp import autocast
-else:
-
-    @contextmanager
-    def autocast(enabled=True):
-        yield
-
-
-class UnifiedTransducerModel(AbsESPnetModel):
-    """ESPnet2ASRTransducerModel module definition.
-
-    Args:
-        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
-        token_list: List of token
-        frontend: Frontend module.
-        specaug: SpecAugment module.
-        normalize: Normalization module.
-        encoder: Encoder module.
-        decoder: Decoder module.
-        joint_network: Joint Network module.
-        transducer_weight: Weight of the Transducer loss.
-        fastemit_lambda: FastEmit lambda value.
-        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
-        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
-        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
-        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
-        ignore_id: Initial padding ID.
-        sym_space: Space symbol.
-        sym_blank: Blank Symbol
-        report_cer: Whether to report Character Error Rate during validation.
-        report_wer: Whether to report Word Error Rate during validation.
-        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
-
-    """
-
-    def __init__(
-        self,
-        vocab_size: int,
-        token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
-        encoder: Encoder,
-        decoder: AbsDecoder,
-        att_decoder: Optional[AbsAttDecoder],
-        joint_network: JointNetwork,
-        transducer_weight: float = 1.0,
-        fastemit_lambda: float = 0.0,
-        auxiliary_ctc_weight: float = 0.0,
-        auxiliary_att_weight: float = 0.0,
-        auxiliary_ctc_dropout_rate: float = 0.0,
-        auxiliary_lm_loss_weight: float = 0.0,
-        auxiliary_lm_loss_smoothing: float = 0.0,
-        ignore_id: int = -1,
-        sym_space: str = "<space>",
-        sym_blank: str = "<blank>",
-        report_cer: bool = True,
-        report_wer: bool = True,
-        sym_sos: str = "<sos/eos>",
-        sym_eos: str = "<sos/eos>",
-        extract_feats_in_collect_stats: bool = True,
-        lsm_weight: float = 0.0,
-        length_normalized_loss: bool = False,
-    ) -> None:
-        """Construct an ESPnetASRTransducerModel object."""
-        super().__init__()
-
-        assert check_argument_types()
-
-        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
-        self.blank_id = 0
-
-        if sym_sos in token_list:
-            self.sos = token_list.index(sym_sos)
-        else:
-            self.sos = vocab_size - 1
-        if sym_eos in token_list:
-            self.eos = token_list.index(sym_eos)
-        else:
-            self.eos = vocab_size - 1
-
-        self.vocab_size = vocab_size
-        self.ignore_id = ignore_id
-        self.token_list = token_list.copy()
-
-        self.sym_space = sym_space
-        self.sym_blank = sym_blank
-
-        self.frontend = frontend
-        self.specaug = specaug
-        self.normalize = normalize
-
-        self.encoder = encoder
-        self.decoder = decoder
-        self.joint_network = joint_network
-
-        self.criterion_transducer = None
-        self.error_calculator = None
-
-        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
-        self.use_auxiliary_att = auxiliary_att_weight > 0
-        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
-
-        if self.use_auxiliary_ctc:
-            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
-            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
-
-        if self.use_auxiliary_att:
-            self.att_decoder = att_decoder
-
-            self.criterion_att = LabelSmoothingLoss(
-                size=vocab_size,
-                padding_idx=ignore_id,
-                smoothing=lsm_weight,
-                normalize_length=length_normalized_loss,
-            )
-
-        if self.use_auxiliary_lm_loss:
-            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
-            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
-
-        self.transducer_weight = transducer_weight
-        self.fastemit_lambda = fastemit_lambda
-
-        self.auxiliary_ctc_weight = auxiliary_ctc_weight
-        self.auxiliary_att_weight = auxiliary_att_weight
-        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
-
-        self.report_cer = report_cer
-        self.report_wer = report_wer
-
-        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
-
-    def forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        """Forward architecture and compute loss(es).
-
-        Args:
-            speech: Speech sequences. (B, S)
-            speech_lengths: Speech sequences lengths. (B,)
-            text: Label ID sequences. (B, L)
-            text_lengths: Label ID sequences lengths. (B,)
-            kwargs: Contains "utts_id".
-
-        Return:
-            loss: Main loss value.
-            stats: Task statistics.
-            weight: Task weights.
-
-        """
-        assert text_lengths.dim() == 1, text_lengths.shape
-        assert (
-            speech.shape[0]
-            == speech_lengths.shape[0]
-            == text.shape[0]
-            == text_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-
-        batch_size = speech.shape[0]
-        text = text[:, : text_lengths.max()]
-        #print(speech.shape)
-        # 1. Encoder
-        encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)
-
-        loss_att, loss_att_chunk = 0.0, 0.0
-
-        if self.use_auxiliary_att:
-            loss_att, _ = self._calc_att_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
-            loss_att_chunk, _ = self._calc_att_loss(
-                encoder_out_chunk, encoder_out_lens, text, text_lengths
-            )
-
-        # 2. Transducer-related I/O preparation
-        decoder_in, target, t_len, u_len = get_transducer_task_io(
-            text,
-            encoder_out_lens,
-            ignore_id=self.ignore_id,
-        )
-        
-        # 3. Decoder
-        self.decoder.set_device(encoder_out.device)
-        decoder_out = self.decoder(decoder_in, u_len)
-
-        # 4. Joint Network
-        joint_out = self.joint_network(
-            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
-        )
-        
-        joint_out_chunk = self.joint_network(
-            encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1)
-        )
-
-        # 5. Losses
-        loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss(
-            encoder_out,
-            joint_out,
-            target,
-            t_len,
-            u_len,
-        )
-        
-        loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss(
-            encoder_out_chunk,
-            joint_out_chunk,
-            target,
-            t_len,
-            u_len,
-        )
-
-        loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0
-
-        if self.use_auxiliary_ctc:
-            loss_ctc = self._calc_ctc_loss(
-                encoder_out,
-                target,
-                t_len,
-                u_len,
-            )
-            loss_ctc_chunk = self._calc_ctc_loss(
-                encoder_out_chunk,
-                target,
-                t_len,
-                u_len,
-            )
-
-        if self.use_auxiliary_lm_loss:
-            loss_lm = self._calc_lm_loss(decoder_out, target)
-
-        loss_trans = loss_trans_utt + loss_trans_chunk
-        loss_ctc = loss_ctc + loss_ctc_chunk 
-        loss_ctc = loss_att + loss_att_chunk
-
-        loss = (
-            self.transducer_weight * loss_trans
-            + self.auxiliary_ctc_weight * loss_ctc
-            + self.auxiliary_att_weight * loss_att
-            + self.auxiliary_lm_loss_weight * loss_lm
-        )
-
-        stats = dict(
-            loss=loss.detach(),
-            loss_transducer=loss_trans_utt.detach(),
-            loss_transducer_chunk=loss_trans_chunk.detach(),
-            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
-            aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None,
-            aux_att_loss=loss_att.detach() if loss_att > 0.0 else None,
-            aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None,
-            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
-            cer_transducer=cer_trans,
-            wer_transducer=wer_trans,
-            cer_transducer_chunk=cer_trans_chunk,
-            wer_transducer_chunk=wer_trans_chunk,
-        )
-
-        # force_gatherable: to-device and to-tensor if scalar for DataParallel
-        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-        return loss, stats, weight
-
-    def collect_feats(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
-        **kwargs,
-    ) -> Dict[str, torch.Tensor]:
-        """Collect features sequences and features lengths sequences.
-
-        Args:
-            speech: Speech sequences. (B, S)
-            speech_lengths: Speech sequences lengths. (B,)
-            text: Label ID sequences. (B, L)
-            text_lengths: Label ID sequences lengths. (B,)
-            kwargs: Contains "utts_id".
-
-        Return:
-            {}: "feats": Features sequences. (B, T, D_feats),
-                "feats_lengths": Features sequences lengths. (B,)
-
-        """
-        if self.extract_feats_in_collect_stats:
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-        else:
-            # Generate dummy stats if extract_feats_in_collect_stats is False
-            logging.warning(
-                "Generating dummy stats for feats and feats_lengths, "
-                "because encoder_conf.extract_feats_in_collect_stats is "
-                f"{self.extract_feats_in_collect_stats}"
-            )
-
-            feats, feats_lengths = speech, speech_lengths
-
-        return {"feats": feats, "feats_lengths": feats_lengths}
-
-    def encode(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Encoder speech sequences.
-
-        Args:
-            speech: Speech sequences. (B, S)
-            speech_lengths: Speech sequences lengths. (B,)
-
-        Return:
-            encoder_out: Encoder outputs. (B, T, D_enc)
-            encoder_out_lens: Encoder outputs lengths. (B,)
-
-        """
-        with autocast(False):
-            # 1. Extract feats
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
-            # 2. Data augmentation
-            if self.specaug is not None and self.training:
-                feats, feats_lengths = self.specaug(feats, feats_lengths)
-
-            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-            if self.normalize is not None:
-                feats, feats_lengths = self.normalize(feats, feats_lengths)
-
-        # 4. Forward encoder
-        encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths)
-
-        assert encoder_out.size(0) == speech.size(0), (
-            encoder_out.size(),
-            speech.size(0),
-        )
-        assert encoder_out.size(1) <= encoder_out_lens.max(), (
-            encoder_out.size(),
-            encoder_out_lens.max(),
-        )
-
-        return encoder_out, encoder_out_chunk, encoder_out_lens
-
-    def _extract_feats(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Extract features sequences and features sequences lengths.
-
-        Args:
-            speech: Speech sequences. (B, S)
-            speech_lengths: Speech sequences lengths. (B,)
-
-        Return:
-            feats: Features sequences. (B, T, D_feats)
-            feats_lengths: Features sequences lengths. (B,)
-
-        """
-        assert speech_lengths.dim() == 1, speech_lengths.shape
-
-        # for data-parallel
-        speech = speech[:, : speech_lengths.max()]
-
-        if self.frontend is not None:
-            feats, feats_lengths = self.frontend(speech, speech_lengths)
-        else:
-            feats, feats_lengths = speech, speech_lengths
-
-        return feats, feats_lengths
-
-    def _calc_transducer_loss(
-        self,
-        encoder_out: torch.Tensor,
-        joint_out: torch.Tensor,
-        target: torch.Tensor,
-        t_len: torch.Tensor,
-        u_len: torch.Tensor,
-    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
-        """Compute Transducer loss.
-
-        Args:
-            encoder_out: Encoder output sequences. (B, T, D_enc)
-            joint_out: Joint Network output sequences (B, T, U, D_joint)
-            target: Target label ID sequences. (B, L)
-            t_len: Encoder output sequences lengths. (B,)
-            u_len: Target label ID sequences lengths. (B,)
-
-        Return:
-            loss_transducer: Transducer loss value.
-            cer_transducer: Character error rate for Transducer.
-            wer_transducer: Word Error Rate for Transducer.
-
-        """
-        if self.criterion_transducer is None:
-            try:
-                # from warprnnt_pytorch import RNNTLoss
-	        # self.criterion_transducer = RNNTLoss(
-                    # reduction="mean",
-                    # fastemit_lambda=self.fastemit_lambda,
-                # )
-                from warp_rnnt import rnnt_loss as RNNTLoss
-                self.criterion_transducer = RNNTLoss
-
-            except ImportError:
-                logging.error(
-                    "warp-rnnt was not installed."
-                    "Please consult the installation documentation."
-                )
-                exit(1)
-
-        # loss_transducer = self.criterion_transducer(
-        #     joint_out,
-        #     target,
-        #     t_len,
-        #     u_len,
-        # )
-        log_probs = torch.log_softmax(joint_out, dim=-1)
-
-        loss_transducer = self.criterion_transducer(
-                log_probs,
-                target,
-                t_len,
-                u_len,
-                reduction="mean",
-                blank=self.blank_id,
-                fastemit_lambda=self.fastemit_lambda,
-                gather=True,
-        )
-
-        if not self.training and (self.report_cer or self.report_wer):
-            if self.error_calculator is None:
-                self.error_calculator = ErrorCalculator(
-                    self.decoder,
-                    self.joint_network,
-                    self.token_list,
-                    self.sym_space,
-                    self.sym_blank,
-                    report_cer=self.report_cer,
-                    report_wer=self.report_wer,
-                )
-
-            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
-            return loss_transducer, cer_transducer, wer_transducer
-
-        return loss_transducer, None, None
-
-    def _calc_ctc_loss(
-        self,
-        encoder_out: torch.Tensor,
-        target: torch.Tensor,
-        t_len: torch.Tensor,
-        u_len: torch.Tensor,
-    ) -> torch.Tensor:
-        """Compute CTC loss.
-
-        Args:
-            encoder_out: Encoder output sequences. (B, T, D_enc)
-            target: Target label ID sequences. (B, L)
-            t_len: Encoder output sequences lengths. (B,)
-            u_len: Target label ID sequences lengths. (B,)
-
-        Return:
-            loss_ctc: CTC loss value.
-
-        """
-        ctc_in = self.ctc_lin(
-            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
-        )
-        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
-
-        target_mask = target != 0
-        ctc_target = target[target_mask].cpu()
-
-        with torch.backends.cudnn.flags(deterministic=True):
-            loss_ctc = torch.nn.functional.ctc_loss(
-                ctc_in,
-                ctc_target,
-                t_len,
-                u_len,
-                zero_infinity=True,
-                reduction="sum",
-            )
-        loss_ctc /= target.size(0)
-
-        return loss_ctc
-
-    def _calc_lm_loss(
-        self,
-        decoder_out: torch.Tensor,
-        target: torch.Tensor,
-    ) -> torch.Tensor:
-        """Compute LM loss.
-
-        Args:
-            decoder_out: Decoder output sequences. (B, U, D_dec)
-            target: Target label ID sequences. (B, L)
-
-        Return:
-            loss_lm: LM loss value.
-
-        """
-        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
-        lm_target = target.view(-1).type(torch.int64)
-
-        with torch.no_grad():
-            true_dist = lm_loss_in.clone()
-            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
-
-            # Ignore blank ID (0)
-            ignore = lm_target == 0
-            lm_target = lm_target.masked_fill(ignore, 0)
-
-            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
-
-        loss_lm = torch.nn.functional.kl_div(
-            torch.log_softmax(lm_loss_in, dim=1),
-            true_dist,
-            reduction="none",
-        )
-        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
-            0
-        )
-
-        return loss_lm
-
-    def _calc_att_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
-            ys_pad = torch.cat(
-                [
-                    self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
-                    ys_pad,
-                ],
-                dim=1,
-            )
-            ys_pad_lens += 1
-
-        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-        ys_in_lens = ys_pad_lens + 1
-
-        # 1. Forward decoder
-        decoder_out, _ = self.att_decoder(
-            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
-        )
-
-        # 2. Compute attention loss
-        loss_att = self.criterion_att(decoder_out, ys_out_pad)
-        acc_att = th_accuracy(
-            decoder_out.view(-1, self.vocab_size),
-            ys_out_pad,
-            ignore_label=self.ignore_id,
-        )
-
-        return loss_att, acc_att
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index b7b552c..9777cee 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -894,7 +894,7 @@
 
         return x, cache
 
-class ConformerChunkEncoder(torch.nn.Module):
+class ConformerChunkEncoder(AbsEncoder):
     """Encoder module definition.
     Args:
         input_size: Input size.
@@ -1007,7 +1007,7 @@
             output_size,
         )
 
-        self.output_size = output_size
+        self._output_size = output_size
 
         self.dynamic_chunk_training = dynamic_chunk_training
         self.short_chunk_threshold = short_chunk_threshold
@@ -1020,6 +1020,9 @@
 
         self.time_reduction_factor = time_reduction_factor
 
+    def output_size(self) -> int:
+        return self._output_size
+
     def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
         """Return the corresponding number of sample for a given chunk size, in frames.
         Where size is the number of features frames after applying subsampling.
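
A note on the output_size change above: AbsEncoder exposes the output size as a
method, and the old `self.output_size = output_size` attribute shadowed it, so
AbsEncoder-style callers broke. A minimal stand-in illustrating the fix (toy
classes for illustration, not the real encoder):

    import torch

    class Shadowing(torch.nn.Module):       # old behavior
        def __init__(self, output_size: int):
            super().__init__()
            self.output_size = output_size  # attribute shadows any method

    class Fixed(torch.nn.Module):           # new behavior
        def __init__(self, output_size: int):
            super().__init__()
            self._output_size = output_size

        def output_size(self) -> int:
            return self._output_size

    print(Fixed(512).output_size())         # 512
    # Shadowing(512).output_size()          # TypeError: 'int' object is not callable
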
diff --git a/funasr/models/rnnt_predictor/__init__.py b/funasr/models/rnnt_predictor/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models/rnnt_predictor/__init__.py
+++ /dev/null
diff --git a/funasr/models/rnnt_predictor/abs_decoder.py b/funasr/models/rnnt_predictor/abs_decoder.py
deleted file mode 100644
index 5b4a335..0000000
--- a/funasr/models/rnnt_predictor/abs_decoder.py
+++ /dev/null
@@ -1,110 +0,0 @@
-"""Abstract decoder definition for Transducer models."""
-
-from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Tuple
-
-import torch
-
-
-class AbsDecoder(torch.nn.Module, ABC):
-    """Abstract decoder module."""
-
-    @abstractmethod
-    def forward(self, labels: torch.Tensor) -> torch.Tensor:
-        """Encode source label sequences.
-
-        Args:
-            labels: Label ID sequences. (B, L)
-
-        Returns:
-            dec_out: Decoder output sequences. (B, T, D_dec)
-
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def score(
-        self,
-        label: torch.Tensor,
-        label_sequence: List[int],
-        dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]],
-    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
-        """One-step forward hypothesis.
-
-        Args:
-            label: Previous label. (1, 1)
-            label_sequence: Current label sequence.
-            dec_state: Previous decoder hidden states.
-                         ((N, 1, D_dec), (N, 1, D_dec) or None) or None
-
-        Returns:
-            dec_out: Decoder output sequence. (1, D_dec) or (1, D_emb)
-            dec_state: Decoder hidden states.
-                         ((N, 1, D_dec), (N, 1, D_dec) or None) or None
-
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def batch_score(
-        self,
-        hyps: List[Any],
-    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]]:
-        """One-step forward hypotheses.
-
-        Args:
-            hyps: Hypotheses.
-
-        Returns:
-            dec_out: Decoder output sequences. (B, D_dec) or (B, D_emb)
-            states: Decoder hidden states.
-                      ((N, B, D_dec), (N, B, D_dec) or None) or None
-
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def set_device(self, device: torch.Tensor) -> None:
-        """Set GPU device to use.
-
-        Args:
-            device: Device ID.
-
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def init_state(
-        self, batch_size: int
-    ) -> Optional[Tuple[torch.Tensor, Optional[torch.tensor]]]:
-        """Initialize decoder states.
-
-        Args:
-            batch_size: Batch size.
-
-        Returns:
-            : Initial decoder hidden states.
-                ((N, B, D_dec), (N, B, D_dec) or None) or None
-
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def select_state(
-        self,
-        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
-        idx: int = 0,
-    ) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
-        """Get specified ID state from batch of states, if provided.
-
-        Args:
-            states: Decoder hidden states.
-                      ((N, B, D_dec), (N, B, D_dec) or None) or None
-            idx: State ID to extract.
-
-        Returns:
-            : Decoder hidden state for given ID.
-                ((N, 1, D_dec), (N, 1, D_dec) or None) or None
-
-        """
-        raise NotImplementedError
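
With the abstract class gone, BeamSearchTransducer and ErrorCalculatorTransducer
(see the hunks below) take an untyped decoder and rely on duck typing. A sketch of
the implicit contract, using typing.Protocol purely as illustration (not part of
this patch; method names are those of the deleted AbsDecoder above):

    from typing import Any, List, Protocol
    import torch

    class TransducerDecoder(Protocol):
        """Methods beam search and scoring still expect from a decoder."""
        def forward(self, labels: torch.Tensor) -> torch.Tensor: ...
        def score(self, label, label_sequence, dec_state): ...
        def batch_score(self, hyps: List[Any]): ...
        def set_device(self, device: torch.device) -> None: ...
        def init_state(self, batch_size: int): ...
        def select_state(self, states=None, idx: int = 0): ...
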
diff --git a/funasr/models/rnnt_predictor/stateless_decoder.py b/funasr/models/rnnt_predictor/stateless_decoder.py
deleted file mode 100644
index 70cd877..0000000
--- a/funasr/models/rnnt_predictor/stateless_decoder.py
+++ /dev/null
@@ -1,145 +0,0 @@
-"""Stateless decoder definition for Transducer models."""
-
-from typing import List, Optional, Tuple
-
-import torch
-from typeguard import check_argument_types
-
-from funasr.modules.beam_search.beam_search_transducer import Hypothesis
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
-from funasr.models.specaug.specaug import SpecAug
-
-class StatelessDecoder(AbsDecoder):
-    """Stateless Transducer decoder module.
-
-    Args:
-        vocab_size: Output size.
-        embed_size: Embedding size.
-        embed_dropout_rate: Dropout rate for embedding layer.
-        embed_pad: Embed/Blank symbol ID.
-
-    """
-
-    def __init__(
-        self,
-        vocab_size: int,
-        embed_size: int = 256,
-        embed_dropout_rate: float = 0.0,
-        embed_pad: int = 0,
-    ) -> None:
-        """Construct a StatelessDecoder object."""
-        super().__init__()
-
-        assert check_argument_types()
-
-        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
-        self.embed_dropout_rate = torch.nn.Dropout(p=embed_dropout_rate)
-
-        self.output_size = embed_size
-        self.vocab_size = vocab_size
-
-        self.device = next(self.parameters()).device
-        self.score_cache = {}
-
-
-
-    def forward(
-        self,
-        labels: torch.Tensor,
-        label_lens: torch.Tensor,
-        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
-    ) -> torch.Tensor:
-        """Encode source label sequences.
-
-        Args:
-            labels: Label ID sequences. (B, L)
-            states: Decoder hidden states. None
-
-        Returns:
-            dec_embed: Decoder output sequences. (B, U, D_emb)
-
-        """
-        dec_embed = self.embed_dropout_rate(self.embed(labels))
-        return dec_embed
-
-    def score(
-        self,
-        label: torch.Tensor,
-        label_sequence: List[int],
-        state: None,
-    ) -> Tuple[torch.Tensor, None]:
-        """One-step forward hypothesis.
-
-        Args:
-            label: Previous label. (1, 1)
-            label_sequence: Current label sequence.
-            state: Previous decoder hidden states. None
-
-        Returns:
-            dec_out: Decoder output sequence. (1, D_emb)
-            state: Decoder hidden states. None
-
-        """
-        str_labels = "_".join(map(str, label_sequence))
-
-        if str_labels in self.score_cache:
-            dec_embed = self.score_cache[str_labels]
-        else:
-            dec_embed = self.embed(label)
-
-            self.score_cache[str_labels] = dec_embed
-
-        return dec_embed[0], None
-
-    def batch_score(
-        self,
-        hyps: List[Hypothesis],
-    ) -> Tuple[torch.Tensor, None]:
-        """One-step forward hypotheses.
-
-        Args:
-            hyps: Hypotheses.
-
-        Returns:
-            dec_out: Decoder output sequences. (B, D_dec)
-            states: Decoder hidden states. None
-
-        """
-        labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
-        dec_embed = self.embed(labels)
-
-        return dec_embed.squeeze(1), None
-
-    def set_device(self, device: torch.device) -> None:
-        """Set GPU device to use.
-
-        Args:
-            device: Device ID.
-
-        """
-        self.device = device
-
-    def init_state(self, batch_size: int) -> None:
-        """Initialize decoder states.
-
-        Args:
-            batch_size: Batch size.
-
-        Returns:
-            : Initial decoder hidden states. None
-
-        """
-        return None
-
-    def select_state(self, states: Optional[torch.Tensor], idx: int) -> None:
-        """Get specified ID state from decoder hidden states.
-
-        Args:
-            states: Decoder hidden states. None
-            idx: State ID to extract.
-
-        Returns:
-            : Decoder hidden state for given ID. None
-
-        """
-        return None
diff --git a/funasr/modules/beam_search/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py
index 8b7e613..3eb8e08 100644
--- a/funasr/modules/beam_search/beam_search_transducer.py
+++ b/funasr/modules/beam_search/beam_search_transducer.py
@@ -6,7 +6,6 @@
 import numpy as np
 import torch
 
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
 from funasr.models.joint_net.joint_network import JointNetwork
 
 
@@ -68,7 +67,7 @@
 
     def __init__(
         self,
-        decoder: AbsDecoder,
+        decoder,
         joint_network: JointNetwork,
         beam_size: int,
         lm: Optional[torch.nn.Module] = None,
diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py
index a01cd5e..f430fcb 100644
--- a/funasr/modules/e2e_asr_common.py
+++ b/funasr/modules/e2e_asr_common.py
@@ -18,7 +18,6 @@
 import torch
 
 from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
 from funasr.models.joint_net.joint_network import JointNetwork
 
 def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
@@ -268,7 +267,7 @@
 
     def __init__(
         self,
-        decoder: AbsDecoder,
+        decoder,
         joint_network: JointNetwork,
         token_list: List[int],
         sym_space: str,
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index e151473..87db05c 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -38,13 +38,16 @@
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.joint_net.joint_network import JointNetwork
 from funasr.models.e2e_asr import ESPnetASRModel
 from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
 from funasr.models.e2e_tp import TimestampPredictor
 from funasr.models.e2e_asr_mfcca import MFCCA
 from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
 from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
 from funasr.models.encoder.rnn_encoder import RNNEncoder
 from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
@@ -150,6 +153,7 @@
         sanm_chunk_opt=SANMEncoderChunkOpt,
         data2vec_encoder=Data2VecEncoder,
         mfcca_enc=MFCCAEncoder,
+        chunk_conformer=ConformerChunkEncoder,
     ),
     type_check=AbsEncoder,
     default="rnn",
@@ -207,6 +211,16 @@
     type_check=AbsDecoder,
     default="rnn",
 )
+
+rnnt_decoder_choices = ClassChoices(
+    "rnnt_decoder",
+    classes=dict(
+        rnnt=RNNTDecoder,
+    ),
+    type_check=RNNTDecoder,
+    default="rnnt",
+)
+
 predictor_choices = ClassChoices(
     name="predictor",
     classes=dict(
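Each ClassChoices registry is surfaced on the command line as --<name> and
--<name>_conf (see the loop over cls.class_choices_list below), so the new
entry yields --rnnt_decoder and --rnnt_decoder_conf. An illustrative sketch of
the resolution path, using the same calls the task code uses (vocab_size is a
placeholder):

    import argparse

    parser = argparse.ArgumentParser()
    rnnt_decoder_choices.add_arguments(parser)  # adds --rnnt_decoder / --rnnt_decoder_conf
    args = parser.parse_args(["--rnnt_decoder", "rnnt"])

    vocab_size = 500  # placeholder vocabulary size
    decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)  # -> RNNTDecoder
    decoder = decoder_class(vocab_size, **args.rnnt_decoder_conf)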
@@ -1331,3 +1345,378 @@
     ) -> Tuple[str, ...]:
         retval = ("speech", "text")
         return retval
+
+
+class ASRTransducerTask(AbsTask):
+    """ASR Transducer Task definition."""
+
+    num_optimizers: int = 1
+
+    class_choices_list = [
+        frontend_choices,
+        specaug_choices,
+        normalize_choices,
+        encoder_choices,
+        rnnt_decoder_choices,
+    ]
+
+    trainer = Trainer
+
+    @classmethod
+    def add_task_arguments(cls, parser: argparse.ArgumentParser):
+        """Add Transducer task arguments.
+        Args:
+            cls: ASRTransducerTask object.
+            parser: Transducer arguments parser.
+        """
+        group = parser.add_argument_group(description="Task related.")
+
+        # required = parser.get_default("required")
+        # required += ["token_list"]
+
+        group.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="Integer-string mapper for tokens.",
+        )
+        group.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using <space>",
+        )
+        group.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of dimensions for input features.",
+        )
+        group.add_argument(
+            "--init",
+            type=str_or_none,
+            default=None,
+            help="Type of model initialization to use.",
+        )
+        group.add_argument(
+            "--model_conf",
+            action=NestedDictAction,
+            default=get_default_kwargs(TransducerModel),
+            help="The keyword arguments for the model class.",
+        )
+        # group.add_argument(
+        #     "--encoder_conf",
+        #     action=NestedDictAction,
+        #     default={},
+        #     help="The keyword arguments for the encoder class.",
+        # )
+        group.add_argument(
+            "--joint_network_conf",
+            action=NestedDictAction,
+            default={},
+            help="The keyword arguments for the joint network class.",
+        )
+        group = parser.add_argument_group(description="Preprocess related.")
+        group.add_argument(
+            "--use_preprocessor",
+            type=str2bool,
+            default=True,
+            help="Whether to apply preprocessing to input data.",
+        )
+        group.add_argument(
+            "--token_type",
+            type=str,
+            default="bpe",
+            choices=["bpe", "char", "word", "phn"],
+            help="The type of tokens to use during tokenization.",
+        )
+        group.add_argument(
+            "--bpemodel",
+            type=str_or_none,
+            default=None,
+            help="The path of the sentencepiece model.",
+        )
+        parser.add_argument(
+            "--non_linguistic_symbols",
+            type=str_or_none,
+            help="The 'non_linguistic_symbols' file path.",
+        )
+        parser.add_argument(
+            "--cleaner",
+            type=str_or_none,
+            choices=[None, "tacotron", "jaconv", "vietnamese"],
+            default=None,
+            help="Text cleaner to use.",
+        )
+        parser.add_argument(
+            "--g2p",
+            type=str_or_none,
+            choices=g2p_choices,
+            default=None,
+            help="g2p method to use if --token_type=phn.",
+        )
+        parser.add_argument(
+            "--speech_volume_normalize",
+            type=float_or_none,
+            default=None,
+            help="Normalization value for maximum amplitude scaling.",
+        )
+        parser.add_argument(
+            "--rir_scp",
+            type=str_or_none,
+            default=None,
+            help="The RIR SCP file path.",
+        )
+        parser.add_argument(
+            "--rir_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of the applied RIR convolution.",
+        )
+        parser.add_argument(
+            "--noise_scp",
+            type=str_or_none,
+            default=None,
+            help="The path of noise SCP file.",
+        )
+        parser.add_argument(
+            "--noise_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of the applied noise addition.",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of the noise decibel level.",
+        )
+        for class_choices in cls.class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --decoder and --decoder_conf
+            class_choices.add_arguments(group)
+
+    @classmethod
+    def build_collate_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Callable[
+        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+        Tuple[List[str], Dict[str, torch.Tensor]],
+    ]:
+        """Build collate function.
+        Args:
+            cls: ASRTransducerTask object.
+            args: Task arguments.
+            train: Training mode.
+        Return:
+            : Callable collate function.
+        """
+        assert check_argument_types()
+
+        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+    @classmethod
+    def build_preprocess_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Optional[Callable[[str, Dict[str, np.ndarray]], Dict[str, np.ndarray]]]:
+        """Build pre-processing function.
+        Args:
+            cls: ASRTransducerTask object.
+            args: Task arguments.
+            train: Training mode.
+        Return:
+            : Callable pre-processing function.
+        """
+        assert check_argument_types()
+
+        if args.use_preprocessor:
+            retval = CommonPreprocessor(
+                train=train,
+                token_type=args.token_type,
+                token_list=args.token_list,
+                bpemodel=args.bpemodel,
+                non_linguistic_symbols=args.non_linguistic_symbols,
+                text_cleaner=args.cleaner,
+                g2p_type=args.g2p,
+                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+                rir_apply_prob=args.rir_apply_prob
+                if hasattr(args, "rir_apply_prob")
+                else 1.0,
+                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+                noise_apply_prob=args.noise_apply_prob
+                if hasattr(args, "noise_apply_prob")
+                else 1.0,
+                noise_db_range=args.noise_db_range
+                if hasattr(args, "noise_db_range")
+                else "13_15",
+                speech_volume_normalize=args.speech_volume_normalize
+                if hasattr(args, "speech_volume_normalize")
+                else None,
+            )
+        else:
+            retval = None
+
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def required_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        """Required data depending on task mode.
+        Args:
+            cls: ASRTransducerTask object.
+            train: Training mode.
+            inference: Inference mode.
+        Return:
+            retval: Required task data.
+        """
+        if not inference:
+            retval = ("speech", "text")
+        else:
+            retval = ("speech",)
+
+        return retval
+
+    @classmethod
+    def optional_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        """Optional data depending on task mode.
+        Args:
+            cls: ASRTransducerTask object.
+            train: Training mode.
+            inference: Inference mode.
+        Return:
+            retval: Optional task data.
+        """
+        retval = ()
+        assert check_return_type(retval)
+
+        return retval
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
+        """Required data depending on task mode.
+        Args:
+            cls: ASRTransducerTask object.
+            args: Task arguments.
+        Return:
+            model: ASR Transducer model.
+        """
+        assert check_argument_types()
+
+        if isinstance(args.token_list, str):
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+
+            # Overwrite token_list so the dumped config stays portable.
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size }")
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Encoder
+        if getattr(args, "encoder", None) is not None:
+            encoder_class = encoder_choices.get_class(args.encoder)
+        else:
+            # Fall back to the chunk conformer when --encoder is unset.
+            encoder_class = ConformerChunkEncoder
+        encoder = encoder_class(input_size, **args.encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        # 5. Decoder
+        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
+        decoder = rnnt_decoder_class(
+            vocab_size,
+            **args.rnnt_decoder_conf,
+        )
+        decoder_output_size = decoder.output_size
+
+        if getattr(args, "decoder", None) is not None:
+            att_decoder_class = decoder_choices.get_class(args.att_decoder)
+
+            att_decoder = att_decoder_class(
+                vocab_size=vocab_size,
+                encoder_output_size=encoder_output_size,
+                **args.decoder_conf,
+            )
+        else:
+            att_decoder = None
+        # 6. Joint Network
+        joint_network = JointNetwork(
+            vocab_size,
+            encoder_output_size,
+            decoder_output_size,
+            **args.joint_network_conf,
+        )
+
+        # 7. Build model
+
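+        # The encoder is expected to expose unified_model_training (the chunk
+        # conformer sets it from its config); when enabled, the unified
+        # streaming/non-streaming variant is built instead.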
+        if encoder.unified_model_training:
+            model = UnifiedTransducerModel(
+                vocab_size=vocab_size,
+                token_list=token_list,
+                frontend=frontend,
+                specaug=specaug,
+                normalize=normalize,
+                encoder=encoder,
+                decoder=decoder,
+                att_decoder=att_decoder,
+                joint_network=joint_network,
+                **args.model_conf,
+            )
+
+        else:
+            model = TransducerModel(
+                vocab_size=vocab_size,
+                token_list=token_list,
+                frontend=frontend,
+                specaug=specaug,
+                normalize=normalize,
+                encoder=encoder,
+                decoder=decoder,
+                att_decoder=att_decoder,
+                joint_network=joint_network,
+                **args.model_conf,
+            )
+
+        # 8. Initialize model
+        if args.init is not None:
+            raise NotImplementedError(
+                "Model initialization via --init is not supported yet; "
+                "it will be reworked in the near future."
+            )
+
+        #assert check_return_type(model)
+
+        return model
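Taken together, build_model wires frontend -> specaug -> normalize -> encoder
-> RNNT decoder -> joint network, and selects UnifiedTransducerModel when the
encoder reports unified_model_training. A hedged end-to-end sketch of calling
it directly with precomputed features (all values are placeholders; encoder
and model defaults are assumed to exist):

    import argparse

    args = argparse.Namespace(
        token_list=["<blank>", "<unk>", "a", "b"],  # placeholder vocabulary
        input_size=80,             # features come from the data loader; no frontend
        specaug=None,
        normalize=None,
        encoder="chunk_conformer",
        encoder_conf={},           # rely on the encoder's defaults
        rnnt_decoder="rnnt",
        rnnt_decoder_conf=dict(embed_size=512, hidden_size=512),
        joint_network_conf={},
        model_conf={},
        init=None,
    )
    model = ASRTransducerTask.build_model(args)  # TransducerModel or UnifiedTransducerModel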
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
deleted file mode 100644
index d4136d0..0000000
--- a/funasr/tasks/asr_transducer.py
+++ /dev/null
@@ -1,477 +0,0 @@
-"""ASR Transducer Task."""
-
-import argparse
-import logging
-from typing import Callable, Collection, Dict, List, Optional, Tuple
-
-import numpy as np
-import torch
-from typeguard import check_argument_types, check_return_type
-
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.specaug.specaug import SpecAug
-from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models.decoder.transformer_decoder import (
-    DynamicConvolution2DTransformerDecoder,
-    DynamicConvolutionTransformerDecoder,
-    LightweightConvolution2DTransformerDecoder,
-    LightweightConvolutionTransformerDecoder,
-    TransformerDecoder,
-)
-from funasr.models.rnnt_predictor.abs_decoder import AbsDecoder
-from funasr.models.rnnt_predictor.rnn_decoder import RNNDecoder
-from funasr.models.rnnt_predictor.stateless_decoder import StatelessDecoder
-from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder
-from funasr.models.e2e_asr_transducer import TransducerModel
-from funasr.models.e2e_asr_transducer_unified import UnifiedTransducerModel
-from funasr.models.joint_net.joint_network import JointNetwork
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
-from funasr.train.class_choices import ClassChoices
-from funasr.datasets.collate_fn import CommonCollateFn
-from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.train.trainer import Trainer
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import float_or_none, int_or_none, str2bool, str_or_none
-
-frontend_choices = ClassChoices(
-    name="frontend",
-    classes=dict(
-        default=DefaultFrontend,
-        sliding_window=SlidingWindow,
-    ),
-    type_check=AbsFrontend,
-    default="default",
-)
-specaug_choices = ClassChoices(
-    "specaug",
-    classes=dict(
-        specaug=SpecAug,
-    ),
-    type_check=AbsSpecAug,
-    default=None,
-    optional=True,
-)
-normalize_choices = ClassChoices(
-    "normalize",
-    classes=dict(
-        global_mvn=GlobalMVN,
-        utterance_mvn=UtteranceMVN,
-    ),
-    type_check=AbsNormalize,
-    default="utterance_mvn",
-    optional=True,
-)
-encoder_choices = ClassChoices(
-        "encoder",
-        classes=dict(
-                chunk_conformer=ConformerChunkEncoder,
-        ),
-        default="chunk_conformer",
-)
-
-decoder_choices = ClassChoices(
-    "decoder",
-    classes=dict(
-        rnn=RNNDecoder,
-        stateless=StatelessDecoder,
-    ),
-    type_check=AbsDecoder,
-    default="rnn",
-)
-
-att_decoder_choices = ClassChoices(
-    "att_decoder",
-    classes=dict(
-        transformer=TransformerDecoder,
-        lightweight_conv=LightweightConvolutionTransformerDecoder,
-        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
-        dynamic_conv=DynamicConvolutionTransformerDecoder,
-        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
-    ),
-    type_check=AbsAttDecoder,
-    default=None,
-    optional=True,
-)
-class ASRTransducerTask(AbsTask):
-    """ASR Transducer Task definition."""
-
-    num_optimizers: int = 1
-
-    class_choices_list = [
-        frontend_choices,
-        specaug_choices,
-        normalize_choices,
-        encoder_choices,
-        decoder_choices,
-        att_decoder_choices,
-    ]
-
-    trainer = Trainer
-
-    @classmethod
-    def add_task_arguments(cls, parser: argparse.ArgumentParser):
-        """Add Transducer task arguments.
-        Args:
-            cls: ASRTransducerTask object.
-            parser: Transducer arguments parser.
-        """
-        group = parser.add_argument_group(description="Task related.")
-
-        # required = parser.get_default("required")
-        # required += ["token_list"]
-
-        group.add_argument(
-            "--token_list",
-            type=str_or_none,
-            default=None,
-            help="Integer-string mapper for tokens.",
-        )
-        group.add_argument(
-            "--split_with_space",
-            type=str2bool,
-            default=True,
-            help="whether to split text using <space>",
-        )
-        group.add_argument(
-            "--input_size",
-            type=int_or_none,
-            default=None,
-            help="The number of dimensions for input features.",
-        )
-        group.add_argument(
-            "--init",
-            type=str_or_none,
-            default=None,
-            help="Type of model initialization to use.",
-        )
-        group.add_argument(
-            "--model_conf",
-            action=NestedDictAction,
-            default=get_default_kwargs(TransducerModel),
-            help="The keyword arguments for the model class.",
-        )
-        # group.add_argument(
-        #     "--encoder_conf",
-        #     action=NestedDictAction,
-        #     default={},
-        #     help="The keyword arguments for the encoder class.",
-        # )
-        group.add_argument(
-            "--joint_network_conf",
-            action=NestedDictAction,
-            default={},
-            help="The keyword arguments for the joint network class.",
-        )
-        group = parser.add_argument_group(description="Preprocess related.")
-        group.add_argument(
-            "--use_preprocessor",
-            type=str2bool,
-            default=True,
-            help="Whether to apply preprocessing to input data.",
-        )
-        group.add_argument(
-            "--token_type",
-            type=str,
-            default="bpe",
-            choices=["bpe", "char", "word", "phn"],
-            help="The type of tokens to use during tokenization.",
-        )
-        group.add_argument(
-            "--bpemodel",
-            type=str_or_none,
-            default=None,
-            help="The path of the sentencepiece model.",
-        )
-        parser.add_argument(
-            "--non_linguistic_symbols",
-            type=str_or_none,
-            help="The 'non_linguistic_symbols' file path.",
-        )
-        parser.add_argument(
-            "--cleaner",
-            type=str_or_none,
-            choices=[None, "tacotron", "jaconv", "vietnamese"],
-            default=None,
-            help="Text cleaner to use.",
-        )
-        parser.add_argument(
-            "--g2p",
-            type=str_or_none,
-            choices=g2p_choices,
-            default=None,
-            help="g2p method to use if --token_type=phn.",
-        )
-        parser.add_argument(
-            "--speech_volume_normalize",
-            type=float_or_none,
-            default=None,
-            help="Normalization value for maximum amplitude scaling.",
-        )
-        parser.add_argument(
-            "--rir_scp",
-            type=str_or_none,
-            default=None,
-            help="The RIR SCP file path.",
-        )
-        parser.add_argument(
-            "--rir_apply_prob",
-            type=float,
-            default=1.0,
-            help="The probability of the applied RIR convolution.",
-        )
-        parser.add_argument(
-            "--noise_scp",
-            type=str_or_none,
-            default=None,
-            help="The path of noise SCP file.",
-        )
-        parser.add_argument(
-            "--noise_apply_prob",
-            type=float,
-            default=1.0,
-            help="The probability of the applied noise addition.",
-        )
-        parser.add_argument(
-            "--noise_db_range",
-            type=str,
-            default="13_15",
-            help="The range of the noise decibel level.",
-        )
-        for class_choices in cls.class_choices_list:
-            # Append --<name> and --<name>_conf.
-            # e.g. --decoder and --decoder_conf
-            class_choices.add_arguments(group)
-
-    @classmethod
-    def build_collate_fn(
-        cls, args: argparse.Namespace, train: bool
-    ) -> Callable[
-        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
-        Tuple[List[str], Dict[str, torch.Tensor]],
-    ]:
-        """Build collate function.
-        Args:
-            cls: ASRTransducerTask object.
-            args: Task arguments.
-            train: Training mode.
-        Return:
-            : Callable collate function.
-        """
-        assert check_argument_types()
-
-        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
-
-    @classmethod
-    def build_preprocess_fn(
-        cls, args: argparse.Namespace, train: bool
-    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
-        """Build pre-processing function.
-        Args:
-            cls: ASRTransducerTask object.
-            args: Task arguments.
-            train: Training mode.
-        Return:
-            : Callable pre-processing function.
-        """
-        assert check_argument_types()
-
-        if args.use_preprocessor:
-            retval = CommonPreprocessor(
-                train=train,
-                token_type=args.token_type,
-                token_list=args.token_list,
-                bpemodel=args.bpemodel,
-                non_linguistic_symbols=args.non_linguistic_symbols,
-                text_cleaner=args.cleaner,
-                g2p_type=args.g2p,
-                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
-                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
-                rir_apply_prob=args.rir_apply_prob
-                if hasattr(args, "rir_apply_prob")
-                else 1.0,
-                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
-                noise_apply_prob=args.noise_apply_prob
-                if hasattr(args, "noise_apply_prob")
-                else 1.0,
-                noise_db_range=args.noise_db_range
-                if hasattr(args, "noise_db_range")
-                else "13_15",
-                speech_volume_normalize=args.speech_volume_normalize
-                if hasattr(args, "rir_scp")
-                else None,
-            )
-        else:
-            retval = None
-
-        assert check_return_type(retval)
-        return retval
-
-    @classmethod
-    def required_data_names(
-        cls, train: bool = True, inference: bool = False
-    ) -> Tuple[str, ...]:
-        """Required data depending on task mode.
-        Args:
-            cls: ASRTransducerTask object.
-            train: Training mode.
-            inference: Inference mode.
-        Return:
-            retval: Required task data.
-        """
-        if not inference:
-            retval = ("speech", "text")
-        else:
-            retval = ("speech",)
-
-        return retval
-
-    @classmethod
-    def optional_data_names(
-        cls, train: bool = True, inference: bool = False
-    ) -> Tuple[str, ...]:
-        """Optional data depending on task mode.
-        Args:
-            cls: ASRTransducerTask object.
-            train: Training mode.
-            inference: Inference mode.
-        Return:
-            retval: Optional task data.
-        """
-        retval = ()
-        assert check_return_type(retval)
-
-        return retval
-
-    @classmethod
-    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
-        """Required data depending on task mode.
-        Args:
-            cls: ASRTransducerTask object.
-            args: Task arguments.
-        Return:
-            model: ASR Transducer model.
-        """
-        assert check_argument_types()
-
-        if isinstance(args.token_list, str):
-            with open(args.token_list, encoding="utf-8") as f:
-                token_list = [line.rstrip() for line in f]
-
-            # Overwriting token_list to keep it as "portable".
-            args.token_list = list(token_list)
-        elif isinstance(args.token_list, (tuple, list)):
-            token_list = list(args.token_list)
-        else:
-            raise RuntimeError("token_list must be str or list")
-        vocab_size = len(token_list)
-        logging.info(f"Vocabulary size: {vocab_size }")
-
-        # 1. frontend
-        if args.input_size is None:
-            # Extract features in the model
-            frontend_class = frontend_choices.get_class(args.frontend)
-            frontend = frontend_class(**args.frontend_conf)
-            input_size = frontend.output_size()
-        else:
-            # Give features from data-loader
-            frontend = None
-            input_size = args.input_size
-
-        # 2. Data augmentation for spectrogram
-        if args.specaug is not None:
-            specaug_class = specaug_choices.get_class(args.specaug)
-            specaug = specaug_class(**args.specaug_conf)
-        else:
-            specaug = None
-
-        # 3. Normalization layer
-        if args.normalize is not None:
-            normalize_class = normalize_choices.get_class(args.normalize)
-            normalize = normalize_class(**args.normalize_conf)
-        else:
-            normalize = None
-
-        # 4. Encoder
-        
-        if getattr(args, "encoder", None) is not None:
-            encoder_class = encoder_choices.get_class(args.encoder)
-            encoder = encoder_class(input_size, **args.encoder_conf)
-        else:
-            encoder = Encoder(input_size, **args.encoder_conf)
-        encoder_output_size = encoder.output_size
-
-        # 5. Decoder
-        decoder_class = decoder_choices.get_class(args.decoder)
-        decoder = decoder_class(
-            vocab_size,
-            **args.decoder_conf,
-        )
-        decoder_output_size = decoder.output_size
-
-        if getattr(args, "att_decoder", None) is not None:
-            att_decoder_class = att_decoder_choices.get_class(args.att_decoder)
-
-            att_decoder = att_decoder_class(
-                vocab_size=vocab_size,
-                encoder_output_size=encoder_output_size,
-                **args.att_decoder_conf,
-            )
-        else:
-            att_decoder = None
-
-        # 6. Joint Network
-        joint_network = JointNetwork(
-            vocab_size,
-            encoder_output_size,
-            decoder_output_size,
-            **args.joint_network_conf,
-        )
-
-        # 7. Build model
-
-        if encoder.unified_model_training:
-            model = UnifiedTransducerModel(
-                vocab_size=vocab_size,
-                token_list=token_list,
-                frontend=frontend,
-                specaug=specaug,
-                normalize=normalize,
-                encoder=encoder,
-                decoder=decoder,
-                att_decoder=att_decoder,
-                joint_network=joint_network,
-                **args.model_conf,
-            )
-
-        else:
-            model = TransducerModel(
-                vocab_size=vocab_size,
-                token_list=token_list,
-                frontend=frontend,
-                specaug=specaug,
-                normalize=normalize,
-                encoder=encoder,
-                decoder=decoder,
-                att_decoder=att_decoder,
-                joint_network=joint_network,
-                **args.model_conf,
-            )
-
-        # 8. Initialize model
-        if args.init is not None:
-            raise NotImplementedError(
-                "Currently not supported.",
-                "Initialization part will be reworked in a short future.",
-            )
-
-        #assert check_return_type(model)
-
-        return model

--
Gitblit v1.9.1