From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 16:49:56 +0800
Subject: [PATCH] rnnt reorg

---
 funasr/modules/nets_utils.py                                       |  195 ++++++++++++++
 funasr/bin/asr_inference_rnnt.py                                   |   58 ----
 funasr/models/rnnt_decoder/__init__.py                             |    0 
 funasr/models/encoder/chunk_encoder_utils/validation.py            |    2 
 funasr/models/encoder/chunk_encoder_blocks/conv1d.py               |    0 
 egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml            |   32 +-
 funasr/models/encoder/chunk_encoder_blocks/conformer.py            |    0 
 funasr/modules/activation.py                                       |    0 
 funasr/models/e2e_transducer.py                                    |   10 
 funasr/models/joint_network.py                                     |    2 
 funasr/modules/e2e_asr_common.py                                   |  151 +++++++++++
 funasr/models/encoder/chunk_encoder_blocks/linear_input.py         |    0 
 funasr/models/encoder/chunk_encoder_modules/positional_encoding.py |    0 
 funasr/models/encoder/chunk_encoder_blocks/conv_input.py           |    2 
 funasr/models/rnnt_decoder/abs_decoder.py                          |    0 
 funasr/models/encoder/chunk_encoder.py                             |   26 -
 funasr/models/encoder/chunk_encoder_modules/__init__.py            |    0 
 funasr/models/e2e_transducer_unified.py                            |   13 
 funasr/models/encoder/chunk_encoder_modules/attention.py           |    0 
 funasr/models/encoder/chunk_encoder_modules/normalization.py       |    0 
 funasr/models/rnnt_decoder/rnn_decoder.py                          |    4 
 funasr/models/encoder/chunk_encoder_modules/multi_blocks.py        |    0 
 funasr/models/encoder/chunk_encoder_blocks/branchformer.py         |    0 
 funasr/models/encoder/chunk_encoder_blocks/__init__.py             |    0 
 funasr/models/encoder/chunk_encoder_modules/convolution.py         |    0 
 /dev/null                                                          |  200 ---------------
 funasr/models/rnnt_decoder/stateless_decoder.py                    |   16 -
 funasr/modules/beam_search/beam_search_transducer.py               |    4 
 funasr/tasks/asr_transducer.py                                     |   41 --
 funasr/models/encoder/chunk_encoder_utils/building.py              |   22 
 30 files changed, 418 insertions(+), 360 deletions(-)

diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
index ef37b97..60f796c 100644
--- a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
+++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -1,13 +1,13 @@
 encoder_conf:
     main_conf:
       pos_wise_act_type: swish
-      pos_enc_dropout_rate: 0.3
+      pos_enc_dropout_rate: 0.5
       conv_mod_act_type: swish
       time_reduction_factor: 2
       unified_model_training: true
       default_chunk_size: 16
       jitter_range: 4
-      left_chunk_size: 1
+      left_chunk_size: 0
     input_conf:
       block_type: conv2d
       conv_size: 512
@@ -18,9 +18,9 @@
       linear_size: 2048
       hidden_size: 512
       heads: 8
-      dropout_rate: 0.3
-      pos_wise_dropout_rate: 0.3
-      att_dropout_rate: 0.3
+      dropout_rate: 0.5
+      pos_wise_dropout_rate: 0.5
+      att_dropout_rate: 0.5
       conv_mod_kernel_size: 15
       num_blocks: 12    
 
@@ -29,8 +29,8 @@
 decoder_conf:
     embed_size: 512
     hidden_size: 512
-    embed_dropout_rate: 0.2
-    dropout_rate: 0.1
+    embed_dropout_rate: 0.5
+    dropout_rate: 0.5
 
 joint_network_conf:
     joint_space_size: 512
@@ -41,14 +41,14 @@
 
 # minibatch related
 use_amp: true
-batch_type: numel
-batch_bins: 1600000
+batch_type: unsorted
+batch_size: 16
 num_workers: 16
 
 # optimization related
 accum_grad: 1
 grad_clip: 5
-max_epoch: 80
+max_epoch: 200
 val_scheduler_criterion:
     - valid
     - loss
@@ -56,11 +56,11 @@
 -   - valid
     - cer_transducer_chunk
     - min
-keep_nbest_models: 5
+keep_nbest_models: 10
 
 optim: adam
 optim_conf:
-   lr: 0.0003
+   lr: 0.001
 scheduler: warmuplr
 scheduler_conf:
    warmup_steps: 25000
@@ -75,10 +75,12 @@
     apply_freq_mask: true
     freq_mask_width_range:
     - 0
-    - 30
+    - 40
     num_freq_mask: 2
     apply_time_mask: true
     time_mask_width_range:
     - 0
-    - 40
-    num_time_mask: 2
+    - 50
+    num_time_mask: 5
+
+log_interval: 50
diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index 768bf72..465f882 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -16,11 +16,11 @@
 from packaging.version import parse as V
 from typeguard import check_argument_types, check_return_type
 
-from funasr.models_transducer.beam_search_transducer import (
+from funasr.modules.beam_search.beam_search_transducer import (
     BeamSearchTransducer,
     Hypothesis,
 )
-from funasr.models_transducer.utils import TooShortUttError
+from funasr.modules.nets_utils import TooShortUttError
 from funasr.fileio.datadir_writer import DatadirWriter
 from funasr.tasks.asr_transducer import ASRTransducerTask
 from funasr.tasks.lm import LMTask
@@ -500,7 +500,6 @@
 
             _bs = len(next(iter(batch.values())))
             assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-<<<<<<< HEAD
             batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
             assert len(batch.keys()) == 1
 
@@ -541,59 +540,6 @@
 
                 if text is not None:
                     ibest_writer["text"][key] = text
-=======
-            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
-
-            logging.info("decoding, utt_id: {}".format(keys))
-            # N-best list of (text, token, token_int, hyp_object)
-
-            time_beg = time.time()
-            results = speech2text(cache=cache, **batch)
-            if len(results) < 1:
-                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-                results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
-            time_end = time.time()
-            forward_time = time_end - time_beg
-            lfr_factor = results[0][-1]
-            length = results[0][-2]
-            forward_time_total += forward_time
-            length_total += length
-            rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor))
-            logging.info(rtf_cur)
-
-            for batch_id in range(_bs):
-                result = [results[batch_id][:-2]]
-
-                key = keys[batch_id]
-                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result):
-                    # Create a directory: outdir/{n}best_recog
-                    if writer is not None:
-                        ibest_writer = writer[f"{n}best_recog"]
-
-                        # Write the result to each file
-                        ibest_writer["token"][key] = " ".join(token)
-                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                        ibest_writer["score"][key] = str(hyp.score)
-                        ibest_writer["rtf"][key] = rtf_cur
-
-                    if text is not None:
-                        text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
-                        item = {'key': key, 'value': text_postprocessed}
-                        asr_result_list.append(item)
-                        finish_count += 1
-                        # asr_utils.print_progress(finish_count / file_count)
-                        if writer is not None:
-                            ibest_writer["text"][key] = " ".join(word_lists)
-
-                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
-        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
-        logging.info(rtf_avg)
-        if writer is not None:
-            ibest_writer["rtf"]["rtf_avf"] = rtf_avg
-        return asr_result_list
-
-    return _forward
->>>>>>> main
 
 
 def get_parser():
diff --git a/funasr/models_transducer/espnet_transducer_model.py b/funasr/models/e2e_transducer.py
similarity index 97%
rename from funasr/models_transducer/espnet_transducer_model.py
rename to funasr/models/e2e_transducer.py
index e32f6e3..b669c9d 100644
--- a/funasr/models_transducer/espnet_transducer_model.py
+++ b/funasr/models/e2e_transducer.py
@@ -10,11 +10,11 @@
 
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
 from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.joint_network import JointNetwork
-from funasr.models_transducer.utils import get_transducer_task_io
+from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.joint_network import JointNetwork
+from funasr.modules.nets_utils import get_transducer_task_io
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -28,7 +28,7 @@
         yield
 
 
-class ESPnetASRTransducerModel(AbsESPnetModel):
+class TransducerModel(AbsESPnetModel):
     """ESPnet2ASRTransducerModel module definition.
 
     Args:
diff --git a/funasr/models_transducer/espnet_transducer_model_unified.py b/funasr/models/e2e_transducer_unified.py
similarity index 97%
rename from funasr/models_transducer/espnet_transducer_model_unified.py
rename to funasr/models/e2e_transducer_unified.py
index be61e83..6003542 100644
--- a/funasr/models_transducer/espnet_transducer_model_unified.py
+++ b/funasr/models/e2e_transducer_unified.py
@@ -10,10 +10,10 @@
 
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.joint_network import JointNetwork
-from funasr.models_transducer.utils import get_transducer_task_io
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.joint_network import JointNetwork
+from funasr.modules.nets_utils import get_transducer_task_io
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -23,7 +23,7 @@
 from funasr.losses.label_smoothing_loss import (  # noqa: H301
     LabelSmoothingLoss,
 )
-from funasr.models_transducer.error_calculator import ErrorCalculator
+from funasr.modules.e2e_asr_common import ErrorCalculatorTransducer as ErrorCalculator
 if V(torch.__version__) >= V("1.6.0"):
     from torch.cuda.amp import autocast
 else:
@@ -33,7 +33,7 @@
         yield
 
 
-class ESPnetASRUnifiedTransducerModel(AbsESPnetModel):
+class UnifiedTransducerModel(AbsESPnetModel):
     """ESPnet2ASRTransducerModel module definition.
 
     Args:
@@ -289,7 +289,6 @@
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-
         return loss, stats, weight
 
     def collect_feats(
diff --git a/funasr/models_transducer/encoder/encoder.py b/funasr/models/encoder/chunk_encoder.py
similarity index 95%
rename from funasr/models_transducer/encoder/encoder.py
rename to funasr/models/encoder/chunk_encoder.py
index b486a11..c6fc292 100644
--- a/funasr/models_transducer/encoder/encoder.py
+++ b/funasr/models/encoder/chunk_encoder.py
@@ -1,26 +1,23 @@
-"""Encoder for Transducer model."""
-
 from typing import Any, Dict, List, Tuple
 
 import torch
 from typeguard import check_argument_types
 
-from funasr.models_transducer.encoder.building import (
+from funasr.models.encoder.chunk_encoder_utils.building import (
     build_body_blocks,
     build_input_block,
     build_main_parameters,
     build_positional_encoding,
 )
-from funasr.models_transducer.encoder.validation import validate_architecture
-from funasr.models_transducer.utils import (
+from funasr.models.encoder.chunk_encoder_utils.validation import validate_architecture
+from funasr.modules.nets_utils import (
     TooShortUttError,
     check_short_utt,
     make_chunk_mask,
     make_source_mask,
 )
 
-
-class Encoder(torch.nn.Module):
+class ChunkEncoder(torch.nn.Module):
     """Encoder module definition.
 
     Args:
@@ -61,10 +58,9 @@
 
         self.unified_model_training = main_params["unified_model_training"]
         self.default_chunk_size = main_params["default_chunk_size"]
-        self.jitter_range = main_params["jitter_range"]       
+        self.jitter_range = main_params["jitter_range"]
 
-        self.time_reduction_factor = main_params["time_reduction_factor"] 
-
+        self.time_reduction_factor = main_params["time_reduction_factor"]
     def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
         """Return the corresponding number of sample for a given chunk size, in frames.
 
@@ -79,7 +75,7 @@
 
         """
         return self.embed.get_size_before_subsampling(size) * hop_length
-    
+
     def get_encoder_input_size(self, size: int) -> int:
         """Return the corresponding number of sample for a given chunk size, in frames.
 
@@ -157,7 +153,7 @@
                 mask,
                 chunk_mask=chunk_mask,
             )
-       
+
             olens = mask.eq(0).sum(1)
             if self.time_reduction_factor > 1:
                 x_utt = x_utt[:,::self.time_reduction_factor,:]
@@ -194,14 +190,14 @@
             mask,
             chunk_mask=chunk_mask,
         )
-        
+
         olens = mask.eq(0).sum(1)
         if self.time_reduction_factor > 1:
             x = x[:,::self.time_reduction_factor,:]
             olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
 
         return x, olens
-     
+
     def simu_chunk_forward(
         self,
         x: torch.Tensor,
@@ -290,7 +286,7 @@
 
         if right_context > 0:
             x = x[:, 0:-right_context, :]
-        
+
         if self.time_reduction_factor > 1:
             x = x[:,::self.time_reduction_factor,:]
         return x
diff --git a/funasr/models_transducer/encoder/blocks/__init__.py b/funasr/models/encoder/chunk_encoder_blocks/__init__.py
similarity index 100%
rename from funasr/models_transducer/encoder/blocks/__init__.py
rename to funasr/models/encoder/chunk_encoder_blocks/__init__.py
diff --git a/funasr/models_transducer/encoder/blocks/branchformer.py b/funasr/models/encoder/chunk_encoder_blocks/branchformer.py
similarity index 100%
rename from funasr/models_transducer/encoder/blocks/branchformer.py
rename to funasr/models/encoder/chunk_encoder_blocks/branchformer.py
diff --git a/funasr/models_transducer/encoder/blocks/conformer.py b/funasr/models/encoder/chunk_encoder_blocks/conformer.py
similarity index 100%
rename from funasr/models_transducer/encoder/blocks/conformer.py
rename to funasr/models/encoder/chunk_encoder_blocks/conformer.py
diff --git a/funasr/models_transducer/encoder/blocks/conv1d.py b/funasr/models/encoder/chunk_encoder_blocks/conv1d.py
similarity index 100%
rename from funasr/models_transducer/encoder/blocks/conv1d.py
rename to funasr/models/encoder/chunk_encoder_blocks/conv1d.py
diff --git a/funasr/models_transducer/encoder/blocks/conv_input.py b/funasr/models/encoder/chunk_encoder_blocks/conv_input.py
similarity index 98%
rename from funasr/models_transducer/encoder/blocks/conv_input.py
rename to funasr/models/encoder/chunk_encoder_blocks/conv_input.py
index ffec93e..b9bd2fd 100644
--- a/funasr/models_transducer/encoder/blocks/conv_input.py
+++ b/funasr/models/encoder/chunk_encoder_blocks/conv_input.py
@@ -5,7 +5,7 @@
 import torch
 import math
 
-from funasr.models_transducer.utils import sub_factor_to_params, pad_to_len
+from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
 
 
 class ConvInput(torch.nn.Module):
diff --git a/funasr/models_transducer/encoder/blocks/linear_input.py b/funasr/models/encoder/chunk_encoder_blocks/linear_input.py
similarity index 100%
rename from funasr/models_transducer/encoder/blocks/linear_input.py
rename to funasr/models/encoder/chunk_encoder_blocks/linear_input.py
diff --git a/funasr/models_transducer/encoder/modules/__init__.py b/funasr/models/encoder/chunk_encoder_modules/__init__.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/__init__.py
rename to funasr/models/encoder/chunk_encoder_modules/__init__.py
diff --git a/funasr/models_transducer/encoder/modules/attention.py b/funasr/models/encoder/chunk_encoder_modules/attention.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/attention.py
rename to funasr/models/encoder/chunk_encoder_modules/attention.py
diff --git a/funasr/models_transducer/encoder/modules/convolution.py b/funasr/models/encoder/chunk_encoder_modules/convolution.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/convolution.py
rename to funasr/models/encoder/chunk_encoder_modules/convolution.py
diff --git a/funasr/models_transducer/encoder/modules/multi_blocks.py b/funasr/models/encoder/chunk_encoder_modules/multi_blocks.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/multi_blocks.py
rename to funasr/models/encoder/chunk_encoder_modules/multi_blocks.py
diff --git a/funasr/models_transducer/encoder/modules/normalization.py b/funasr/models/encoder/chunk_encoder_modules/normalization.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/normalization.py
rename to funasr/models/encoder/chunk_encoder_modules/normalization.py
diff --git a/funasr/models_transducer/encoder/modules/positional_encoding.py b/funasr/models/encoder/chunk_encoder_modules/positional_encoding.py
similarity index 100%
rename from funasr/models_transducer/encoder/modules/positional_encoding.py
rename to funasr/models/encoder/chunk_encoder_modules/positional_encoding.py
diff --git a/funasr/models_transducer/encoder/building.py b/funasr/models/encoder/chunk_encoder_utils/building.py
similarity index 91%
rename from funasr/models_transducer/encoder/building.py
rename to funasr/models/encoder/chunk_encoder_utils/building.py
index a19943b..21611aa 100644
--- a/funasr/models_transducer/encoder/building.py
+++ b/funasr/models/encoder/chunk_encoder_utils/building.py
@@ -2,22 +2,22 @@
 
 from typing import Any, Dict, List, Optional, Union
 
-from funasr.models_transducer.activation import get_activation
-from funasr.models_transducer.encoder.blocks.branchformer import Branchformer
-from funasr.models_transducer.encoder.blocks.conformer import Conformer
-from funasr.models_transducer.encoder.blocks.conv1d import Conv1d
-from funasr.models_transducer.encoder.blocks.conv_input import ConvInput
-from funasr.models_transducer.encoder.blocks.linear_input import LinearInput
-from funasr.models_transducer.encoder.modules.attention import (  # noqa: H301
+from funasr.modules.activation import get_activation
+from funasr.models.encoder.chunk_encoder_blocks.branchformer import Branchformer
+from funasr.models.encoder.chunk_encoder_blocks.conformer import Conformer
+from funasr.models.encoder.chunk_encoder_blocks.conv1d import Conv1d
+from funasr.models.encoder.chunk_encoder_blocks.conv_input import ConvInput
+from funasr.models.encoder.chunk_encoder_blocks.linear_input import LinearInput
+from funasr.models.encoder.chunk_encoder_modules.attention import (  # noqa: H301
     RelPositionMultiHeadedAttention,
 )
-from funasr.models_transducer.encoder.modules.convolution import (  # noqa: H301
+from funasr.models.encoder.chunk_encoder_modules.convolution import (  # noqa: H301
     ConformerConvolution,
     ConvolutionalSpatialGatingUnit,
 )
-from funasr.models_transducer.encoder.modules.multi_blocks import MultiBlocks
-from funasr.models_transducer.encoder.modules.normalization import get_normalization
-from funasr.models_transducer.encoder.modules.positional_encoding import (  # noqa: H301
+from funasr.models.encoder.chunk_encoder_modules.multi_blocks import MultiBlocks
+from funasr.models.encoder.chunk_encoder_modules.normalization import get_normalization
+from funasr.models.encoder.chunk_encoder_modules.positional_encoding import (  # noqa: H301
     RelPositionalEncoding,
 )
 from funasr.modules.positionwise_feed_forward import (
diff --git a/funasr/models_transducer/encoder/validation.py b/funasr/models/encoder/chunk_encoder_utils/validation.py
similarity index 98%
rename from funasr/models_transducer/encoder/validation.py
rename to funasr/models/encoder/chunk_encoder_utils/validation.py
index 0003536..1103cb9 100644
--- a/funasr/models_transducer/encoder/validation.py
+++ b/funasr/models/encoder/chunk_encoder_utils/validation.py
@@ -2,7 +2,7 @@
 
 from typing import Any, Dict, List, Tuple
 
-from funasr.models_transducer.utils import sub_factor_to_params
+from funasr.modules.nets_utils import sub_factor_to_params
 
 
 def validate_block_arguments(
diff --git a/funasr/models_transducer/joint_network.py b/funasr/models/joint_network.py
similarity index 96%
rename from funasr/models_transducer/joint_network.py
rename to funasr/models/joint_network.py
index 119dd84..5cabdb4 100644
--- a/funasr/models_transducer/joint_network.py
+++ b/funasr/models/joint_network.py
@@ -2,7 +2,7 @@
 
 import torch
 
-from funasr.models_transducer.activation import get_activation
+from funasr.modules.activation import get_activation
 
 
 class JointNetwork(torch.nn.Module):
diff --git a/funasr/models_transducer/decoder/__init__.py b/funasr/models/rnnt_decoder/__init__.py
similarity index 100%
rename from funasr/models_transducer/decoder/__init__.py
rename to funasr/models/rnnt_decoder/__init__.py
diff --git a/funasr/models_transducer/decoder/abs_decoder.py b/funasr/models/rnnt_decoder/abs_decoder.py
similarity index 100%
rename from funasr/models_transducer/decoder/abs_decoder.py
rename to funasr/models/rnnt_decoder/abs_decoder.py
diff --git a/funasr/models_transducer/decoder/rnn_decoder.py b/funasr/models/rnnt_decoder/rnn_decoder.py
similarity index 97%
rename from funasr/models_transducer/decoder/rnn_decoder.py
rename to funasr/models/rnnt_decoder/rnn_decoder.py
index 04c3228..c4e7951 100644
--- a/funasr/models_transducer/decoder/rnn_decoder.py
+++ b/funasr/models/rnnt_decoder/rnn_decoder.py
@@ -5,8 +5,8 @@
 import torch
 from typeguard import check_argument_types
 
-from funasr.models_transducer.beam_search_transducer import Hypothesis
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.modules.beam_search.beam_search_transducer import Hypothesis
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
 from funasr.models.specaug.specaug import SpecAug
 
 class RNNDecoder(AbsDecoder):
diff --git a/funasr/models_transducer/decoder/stateless_decoder.py b/funasr/models/rnnt_decoder/stateless_decoder.py
similarity index 86%
rename from funasr/models_transducer/decoder/stateless_decoder.py
rename to funasr/models/rnnt_decoder/stateless_decoder.py
index 07c8f51..a2e1fc1 100644
--- a/funasr/models_transducer/decoder/stateless_decoder.py
+++ b/funasr/models/rnnt_decoder/stateless_decoder.py
@@ -5,8 +5,8 @@
 import torch
 from typeguard import check_argument_types
 
-from funasr.models_transducer.beam_search_transducer import Hypothesis
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.modules.beam_search.beam_search_transducer import Hypothesis
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
 from funasr.models.specaug.specaug import SpecAug
 
 class StatelessDecoder(AbsDecoder):
@@ -26,7 +26,6 @@
         embed_size: int = 256,
         embed_dropout_rate: float = 0.0,
         embed_pad: int = 0,
-        use_embed_mask: bool = False,
     ) -> None:
         """Construct a StatelessDecoder object."""
         super().__init__()
@@ -42,14 +41,6 @@
         self.device = next(self.parameters()).device
         self.score_cache = {}
 
-        self.use_embed_mask = use_embed_mask
-        if self.use_embed_mask:
-            self._embed_mask = SpecAug(
-                time_mask_width_range=3,
-                num_time_mask=1,
-                apply_freq_mask=False,
-                apply_time_warp=False
-            )
 
 
     def forward(
@@ -69,9 +60,6 @@
 
         """
         dec_embed = self.embed_dropout_rate(self.embed(labels))
-        if self.use_embed_mask and self.training:
-            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
-
         return dec_embed
 
     def score(
diff --git a/funasr/models_transducer/__init__.py b/funasr/models_transducer/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models_transducer/__init__.py
+++ /dev/null
diff --git a/funasr/models_transducer/encoder/__init__.py b/funasr/models_transducer/encoder/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models_transducer/encoder/__init__.py
+++ /dev/null
diff --git a/funasr/models_transducer/encoder/sanm_encoder.py b/funasr/models_transducer/encoder/sanm_encoder.py
deleted file mode 100644
index 9e74bdf..0000000
--- a/funasr/models_transducer/encoder/sanm_encoder.py
+++ /dev/null
@@ -1,835 +0,0 @@
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-import logging
-import torch
-import torch.nn as nn
-from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
-import numpy as np
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
-from funasr.modules.embedding import SinusoidalPositionEncoder
-from funasr.modules.layer_norm import LayerNorm
-from funasr.modules.multi_layer_conv import Conv1dLinear
-from funasr.modules.multi_layer_conv import MultiLayeredConv1d
-from funasr.modules.positionwise_feed_forward import (
-    PositionwiseFeedForward,  # noqa: H301
-)
-from funasr.modules.repeat import repeat
-from funasr.modules.subsampling import Conv2dSubsampling
-from funasr.modules.subsampling import Conv2dSubsampling2
-from funasr.modules.subsampling import Conv2dSubsampling6
-from funasr.modules.subsampling import Conv2dSubsampling8
-from funasr.modules.subsampling import TooShortUttError
-from funasr.modules.subsampling import check_short_utt
-from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
-
-
-class EncoderLayerSANM(nn.Module):
-    def __init__(
-        self,
-        in_size,
-        size,
-        self_attn,
-        feed_forward,
-        dropout_rate,
-        normalize_before=True,
-        concat_after=False,
-        stochastic_depth_rate=0.0,
-    ):
-        """Construct an EncoderLayer object."""
-        super(EncoderLayerSANM, self).__init__()
-        self.self_attn = self_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(in_size)
-        self.norm2 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.in_size = in_size
-        self.size = size
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        if self.concat_after:
-            self.concat_linear = nn.Linear(size + size, size)
-        self.stochastic_depth_rate = stochastic_depth_rate
-        self.dropout_rate = dropout_rate
-
-    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
-        """Compute encoded features.
-        Args:
-            x_input (torch.Tensor): Input tensor (#batch, time, size).
-            mask (torch.Tensor): Mask tensor for the input (#batch, time).
-            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, size).
-            torch.Tensor: Mask tensor (#batch, time).
-        """
-        skip_layer = False
-        # with stochastic depth, residual connection `x + f(x)` becomes
-        # `x <- x + 1 / (1 - p) * f(x)` at training time.
-        stoch_layer_coeff = 1.0
-        if self.training and self.stochastic_depth_rate > 0:
-            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
-            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
-        if skip_layer:
-            if cache is not None:
-                x = torch.cat([cache, x], dim=1)
-            return x, mask
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm1(x)
-
-        if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
-            if self.in_size == self.size:
-                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
-            else:
-                x = stoch_layer_coeff * self.concat_linear(x_concat)
-        else:
-            if self.in_size == self.size:
-                x = residual + stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
-                )
-            else:
-                x = stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
-                )
-        if not self.normalize_before:
-            x = self.norm1(x)
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm2(x)
-        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm2(x)
-
-
-        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
-
-class SANMEncoder(AbsEncoder):
-    """
-    author: Speech Lab, Alibaba Group, China
-    San-m: Memory equipped self-attention for end-to-end speech recognition
-    https://arxiv.org/abs/2006.01713
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        pos_enc_class=SinusoidalPositionEncoder,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        positionwise_layer_type: str = "linear",
-        positionwise_conv_kernel_size: int = 1,
-        padding_idx: int = -1,
-        interctc_layer_idx: List[int] = [],
-        interctc_use_conditioning: bool = False,
-        kernel_size : int = 11,
-        sanm_shfit : int = 0,
-        tf2torch_tensor_name_prefix_torch: str = "encoder",
-        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
-    ):
-        assert check_argument_types()
-        super().__init__()
-
-        self.embed = SinusoidalPositionEncoder()
-        self.normalize_before = normalize_before
-        if positionwise_layer_type == "linear":
-            positionwise_layer = PositionwiseFeedForward
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d":
-            positionwise_layer = MultiLayeredConv1d
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d-linear":
-            positionwise_layer = Conv1dLinear
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        else:
-            raise NotImplementedError("Support only linear or conv1d.")
-
-        encoder_selfattn_layer = MultiHeadedAttentionSANM
-        encoder_selfattn_layer_args0 = (
-            attention_heads,
-            input_size,
-            output_size,
-            attention_dropout_rate,
-            kernel_size,
-            sanm_shfit,
-        )
-
-        encoder_selfattn_layer_args = (
-            attention_heads,
-            output_size,
-            output_size,
-            attention_dropout_rate,
-            kernel_size,
-            sanm_shfit,
-        )
-        self.encoders0 = repeat(
-            1,
-            lambda lnum: EncoderLayerSANM(
-                input_size,
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-
-        self.encoders = repeat(
-            num_blocks-1,
-            lambda lnum: EncoderLayerSANM(
-                output_size,
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-
-        self.interctc_layer_idx = interctc_layer_idx
-        if len(interctc_layer_idx) > 0:
-            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
-        self.interctc_use_conditioning = interctc_use_conditioning
-        self.conditioning_layer = None
-        self.dropout = nn.Dropout(dropout_rate)
-        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
-        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-
-    def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        prev_states: torch.Tensor = None,
-        ctc: CTC = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Embed positions in tensor.
-        Args:
-            xs_pad: input tensor (B, L, D)
-            ilens: input length (B)
-            prev_states: Not to be used now.
-        Returns:
-            position embedded tensor and mask
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad = xs_pad * self.output_size**0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        elif (
-            isinstance(self.embed, Conv2dSubsampling)
-            or isinstance(self.embed, Conv2dSubsampling2)
-            or isinstance(self.embed, Conv2dSubsampling6)
-            or isinstance(self.embed, Conv2dSubsampling8)
-        ):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
-                    + f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-
-        # xs_pad = self.dropout(xs_pad)
-        encoder_outs = self.encoders0(xs_pad, masks)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        intermediate_outs = []
-        if len(self.interctc_layer_idx) == 0:
-            encoder_outs = self.encoders(xs_pad, masks)
-            xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        else:
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                encoder_outs = encoder_layer(xs_pad, masks)
-                xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
-                if layer_idx + 1 in self.interctc_layer_idx:
-                    encoder_out = xs_pad
-
-                    # intermediate outputs are also normalized
-                    if self.normalize_before:
-                        encoder_out = self.after_norm(encoder_out)
-
-                    intermediate_outs.append((layer_idx + 1, encoder_out))
-
-                    if self.interctc_use_conditioning:
-                        ctc_out = ctc.softmax(encoder_out)
-                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        olens = masks.squeeze(1).sum(1)
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), olens, None
-        return xs_pad, olens
-
-    def gen_tf2torch_map_dict(self):
-        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
-        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
-        map_dict_local = {
-            ## encoder
-            # cicd
-            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (768,256),(1,256,768)
-            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (768,),(768,)
-            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 2, 0),
-                 },  # (256,1,31),(1,31,256,1)
-            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # ffn
-            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # out norm
-            "{}.after_norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.after_norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-        
-        }
-    
-        return map_dict_local
-
-    def convert_tf2torch(self,
-                         var_dict_tf,
-                         var_dict_torch,
-                         ):
-        
-        map_dict = self.gen_tf2torch_map_dict()
-    
-        var_dict_torch_update = dict()
-        for name in sorted(var_dict_torch.keys(), reverse=False):
-            names = name.split('.')
-            if names[0] == self.tf2torch_tensor_name_prefix_torch:
-                if names[1] == "encoders0":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                
-                    name_q = name_q.replace("encoders0", "encoders")
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-                elif names[1] == "encoders":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                    layeridx_bias = 1
-                    layeridx += layeridx_bias
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "after_norm":
-                    name_tf = map_dict[name]["name"]
-                    data_tf = var_dict_tf[name_tf]
-                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                    var_dict_torch_update[name] = data_tf
-                    logging.info(
-                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                      var_dict_tf[name_tf].shape))
-    
-        return var_dict_torch_update
-
-
-class SANMEncoderChunkOpt(AbsEncoder):
-    """
-    author: Speech Lab, Alibaba Group, China
-    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
-    https://arxiv.org/abs/2006.01713
-    """
-
-    def __init__(
-            self,
-            input_size: int,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.0,
-            pos_enc_class=SinusoidalPositionEncoder,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            positionwise_layer_type: str = "linear",
-            positionwise_conv_kernel_size: int = 1,
-            padding_idx: int = -1,
-            interctc_layer_idx: List[int] = [],
-            interctc_use_conditioning: bool = False,
-            kernel_size: int = 11,
-            sanm_shfit: int = 0,
-            chunk_size: Union[int, Sequence[int]] = (16,),
-            stride: Union[int, Sequence[int]] = (10,),
-            pad_left: Union[int, Sequence[int]] = (0,),
-            time_reduction_factor: int = 1,
-            encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
-            decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
-            tf2torch_tensor_name_prefix_torch: str = "encoder",
-            tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
-    ):
-        assert check_argument_types()
-        super().__init__()
-        self.output_size = output_size
-
-        self.embed = SinusoidalPositionEncoder()
-        
-        self.normalize_before = normalize_before
-        if positionwise_layer_type == "linear":
-            positionwise_layer = PositionwiseFeedForward
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d":
-            positionwise_layer = MultiLayeredConv1d
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d-linear":
-            positionwise_layer = Conv1dLinear
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        else:
-            raise NotImplementedError("Support only linear or conv1d.")
-
-        encoder_selfattn_layer = MultiHeadedAttentionSANM
-        encoder_selfattn_layer_args0 = (
-            attention_heads,
-            input_size,
-            output_size,
-            attention_dropout_rate,
-            kernel_size,
-            sanm_shfit,
-        )
-
-        encoder_selfattn_layer_args = (
-            attention_heads,
-            output_size,
-            output_size,
-            attention_dropout_rate,
-            kernel_size,
-            sanm_shfit,
-        )
-        self.encoders0 = repeat(
-            1,
-            lambda lnum: EncoderLayerSANM(
-                input_size,
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-
-        self.encoders = repeat(
-            num_blocks - 1,
-            lambda lnum: EncoderLayerSANM(
-                output_size,
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-
-        self.interctc_layer_idx = interctc_layer_idx
-        if len(interctc_layer_idx) > 0:
-            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
-        self.interctc_use_conditioning = interctc_use_conditioning
-        self.conditioning_layer = None
-        shfit_fsmn = (kernel_size - 1) // 2
-        self.overlap_chunk_cls = overlap_chunk(
-            chunk_size=chunk_size,
-            stride=stride,
-            pad_left=pad_left,
-            shfit_fsmn=shfit_fsmn,
-            encoder_att_look_back_factor=encoder_att_look_back_factor,
-            decoder_att_look_back_factor=decoder_att_look_back_factor,
-        )
-        self.time_reduction_factor = time_reduction_factor
-        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
-        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-
-    def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-            prev_states: torch.Tensor = None,
-            ctc: CTC = None,
-            ind: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Embed positions in tensor.
-        Args:
-            xs_pad: input tensor (B, L, D)
-            ilens: input length (B)
-            prev_states: Not to be used now.
-        Returns:
-            position embedded tensor and mask
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad *= self.output_size ** 0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        elif (
-                isinstance(self.embed, Conv2dSubsampling)
-                or isinstance(self.embed, Conv2dSubsampling2)
-                or isinstance(self.embed, Conv2dSubsampling6)
-                or isinstance(self.embed, Conv2dSubsampling8)
-        ):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
-                    + f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-
-        mask_shfit_chunk, mask_att_chunk_encoder = None, None
-        if self.overlap_chunk_cls is not None:
-            ilens = masks.squeeze(1).sum(1)
-            chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
-            xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
-            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
-                                                                           dtype=xs_pad.dtype)
-            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
-                                                                                       xs_pad.size(0),
-                                                                                       dtype=xs_pad.dtype)
-
-        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        intermediate_outs = []
-        if len(self.interctc_layer_idx) == 0:
-            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
-            xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        else:
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
-                xs_pad, masks = encoder_outs[0], encoder_outs[1]
-                if layer_idx + 1 in self.interctc_layer_idx:
-                    encoder_out = xs_pad
-
-                    # intermediate outputs are also normalized
-                    if self.normalize_before:
-                        encoder_out = self.after_norm(encoder_out)
-
-                    intermediate_outs.append((layer_idx + 1, encoder_out))
-
-                    if self.interctc_use_conditioning:
-                        ctc_out = ctc.softmax(encoder_out)
-                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-        
-        olens = masks.squeeze(1).sum(1)
-
-        xs_pad, olens = self.overlap_chunk_cls.remove_chunk(xs_pad, olens, chunk_outs=None)
-
-        if self.time_reduction_factor > 1:
-            xs_pad = xs_pad[:,::self.time_reduction_factor,:]
-            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
-
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), olens, None
-        return xs_pad, olens
-
-    def gen_tf2torch_map_dict(self):
-        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
-        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
-        map_dict_local = {
-            ## encoder
-            # cicd
-            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (768,256),(1,256,768)
-            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (768,),(768,)
-            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 2, 0),
-                 },  # (256,1,31),(1,31,256,1)
-            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # ffn
-            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # out norm
-            "{}.after_norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.after_norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-        
-        }
-    
-        return map_dict_local
-
-    def convert_tf2torch(self,
-                         var_dict_tf,
-                         var_dict_torch,
-                         ):
-    
-        map_dict = self.gen_tf2torch_map_dict()
-    
-        var_dict_torch_update = dict()
-        for name in sorted(var_dict_torch.keys(), reverse=False):
-            names = name.split('.')
-            if names[0] == self.tf2torch_tensor_name_prefix_torch:
-                if names[1] == "encoders0":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                
-                    name_q = name_q.replace("encoders0", "encoders")
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-                elif names[1] == "encoders":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                    layeridx_bias = 1
-                    layeridx += layeridx_bias
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "after_norm":
-                    name_tf = map_dict[name]["name"]
-                    data_tf = var_dict_tf[name_tf]
-                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                    var_dict_torch_update[name] = data_tf
-                    logging.info(
-                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                      var_dict_tf[name_tf].shape))
-    
-        return var_dict_torch_update
diff --git a/funasr/models_transducer/error_calculator.py b/funasr/models_transducer/error_calculator.py
deleted file mode 100644
index 34b1dc7..0000000
--- a/funasr/models_transducer/error_calculator.py
+++ /dev/null
@@ -1,169 +0,0 @@
-"""Error Calculator module for Transducer."""
-
-from typing import List, Optional, Tuple
-
-import torch
-
-from funasr.models_transducer.beam_search_transducer import BeamSearchTransducer
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.joint_network import JointNetwork
-
-
-class ErrorCalculator:
-    """Calculate CER and WER for transducer models.
-
-    Args:
-        decoder: Decoder module.
-        joint_network: Joint Network module.
-        token_list: List of token units.
-        sym_space: Space symbol.
-        sym_blank: Blank symbol.
-        report_cer: Whether to compute CER.
-        report_wer: Whether to compute WER.
-
-    """
-
-    def __init__(
-        self,
-        decoder: AbsDecoder,
-        joint_network: JointNetwork,
-        token_list: List[int],
-        sym_space: str,
-        sym_blank: str,
-        report_cer: bool = False,
-        report_wer: bool = False,
-    ) -> None:
-        """Construct an ErrorCalculatorTransducer object."""
-        super().__init__()
-
-        self.beam_search = BeamSearchTransducer(
-            decoder=decoder,
-            joint_network=joint_network,
-            beam_size=1,
-            search_type="default",
-            score_norm=False,
-        )
-
-        self.decoder = decoder
-
-        self.token_list = token_list
-        self.space = sym_space
-        self.blank = sym_blank
-
-        self.report_cer = report_cer
-        self.report_wer = report_wer
-
-    def __call__(
-        self, encoder_out: torch.Tensor, target: torch.Tensor
-    ) -> Tuple[Optional[float], Optional[float]]:
-        """Calculate sentence-level WER or/and CER score for Transducer model.
-
-        Args:
-            encoder_out: Encoder output sequences. (B, T, D_enc)
-            target: Target label ID sequences. (B, L)
-
-        Returns:
-            : Sentence-level CER score.
-            : Sentence-level WER score.
-
-        """
-        cer, wer = None, None
-
-        batchsize = int(encoder_out.size(0))
-
-        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
-
-        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
-        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
-
-        char_pred, char_target = self.convert_to_char(pred, target)
-
-        if self.report_cer:
-            cer = self.calculate_cer(char_pred, char_target)
-
-        if self.report_wer:
-            wer = self.calculate_wer(char_pred, char_target)
-
-        return cer, wer
-
-    def convert_to_char(
-        self, pred: torch.Tensor, target: torch.Tensor
-    ) -> Tuple[List, List]:
-        """Convert label ID sequences to character sequences.
-
-        Args:
-            pred: Prediction label ID sequences. (B, U)
-            target: Target label ID sequences. (B, L)
-
-        Returns:
-            char_pred: Prediction character sequences. (B, ?)
-            char_target: Target character sequences. (B, ?)
-
-        """
-        char_pred, char_target = [], []
-
-        for i, pred_i in enumerate(pred):
-            char_pred_i = [self.token_list[int(h)] for h in pred_i]
-            char_target_i = [self.token_list[int(r)] for r in target[i]]
-
-            char_pred_i = "".join(char_pred_i).replace(self.space, " ")
-            char_pred_i = char_pred_i.replace(self.blank, "")
-
-            char_target_i = "".join(char_target_i).replace(self.space, " ")
-            char_target_i = char_target_i.replace(self.blank, "")
-
-            char_pred.append(char_pred_i)
-            char_target.append(char_target_i)
-
-        return char_pred, char_target
-
-    def calculate_cer(
-        self, char_pred: torch.Tensor, char_target: torch.Tensor
-    ) -> float:
-        """Calculate sentence-level CER score.
-
-        Args:
-            char_pred: Prediction character sequences. (B, ?)
-            char_target: Target character sequences. (B, ?)
-
-        Returns:
-            : Average sentence-level CER score.
-
-        """
-        import editdistance
-
-        distances, lens = [], []
-
-        for i, char_pred_i in enumerate(char_pred):
-            pred = char_pred_i.replace(" ", "")
-            target = char_target[i].replace(" ", "")
-            distances.append(editdistance.eval(pred, target))
-            lens.append(len(target))
-
-        return float(sum(distances)) / sum(lens)
-
-    def calculate_wer(
-        self, char_pred: torch.Tensor, char_target: torch.Tensor
-    ) -> float:
-        """Calculate sentence-level WER score.
-
-        Args:
-            char_pred: Prediction character sequences. (B, ?)
-            char_target: Target character sequences. (B, ?)
-
-        Returns:
-            : Average sentence-level WER score
-
-        """
-        import editdistance
-
-        distances, lens = [], []
-
-        for i, char_pred_i in enumerate(char_pred):
-            pred = char_pred_i.replace("鈻�", " ").split()
-            target = char_target[i].replace("鈻�", " ").split()
-
-            distances.append(editdistance.eval(pred, target))
-            lens.append(len(target))
-
-        return float(sum(distances)) / sum(lens)
diff --git a/funasr/models_transducer/espnet_transducer_model_uni_asr.py b/funasr/models_transducer/espnet_transducer_model_uni_asr.py
deleted file mode 100644
index 2add3fa..0000000
--- a/funasr/models_transducer/espnet_transducer_model_uni_asr.py
+++ /dev/null
@@ -1,485 +0,0 @@
-"""ESPnet2 ASR Transducer model."""
-
-import logging
-from contextlib import contextmanager
-from typing import Dict, List, Optional, Tuple, Union
-
-import torch
-from packaging.version import parse as V
-from typeguard import check_argument_types
-
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.joint_network import JointNetwork
-from funasr.models_transducer.utils import get_transducer_task_io
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
-
-if V(torch.__version__) >= V("1.6.0"):
-    from torch.cuda.amp import autocast
-else:
-
-    @contextmanager
-    def autocast(enabled=True):
-        yield
-
-
-class UniASRTransducerModel(AbsESPnetModel):
-    """ESPnet2ASRTransducerModel module definition.
-
-    Args:
-        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
-        token_list: List of token
-        frontend: Frontend module.
-        specaug: SpecAugment module.
-        normalize: Normalization module.
-        encoder: Encoder module.
-        decoder: Decoder module.
-        joint_network: Joint Network module.
-        transducer_weight: Weight of the Transducer loss.
-        fastemit_lambda: FastEmit lambda value.
-        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
-        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
-        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
-        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
-        ignore_id: Initial padding ID.
-        sym_space: Space symbol.
-        sym_blank: Blank Symbol
-        report_cer: Whether to report Character Error Rate during validation.
-        report_wer: Whether to report Word Error Rate during validation.
-        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
-
-    """
-
-    def __init__(
-        self,
-        vocab_size: int,
-        token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
-        encoder,
-        decoder: AbsDecoder,
-        att_decoder: Optional[AbsAttDecoder],
-        joint_network: JointNetwork,
-        transducer_weight: float = 1.0,
-        fastemit_lambda: float = 0.0,
-        auxiliary_ctc_weight: float = 0.0,
-        auxiliary_ctc_dropout_rate: float = 0.0,
-        auxiliary_lm_loss_weight: float = 0.0,
-        auxiliary_lm_loss_smoothing: float = 0.0,
-        ignore_id: int = -1,
-        sym_space: str = "<space>",
-        sym_blank: str = "<blank>",
-        report_cer: bool = True,
-        report_wer: bool = True,
-        extract_feats_in_collect_stats: bool = True,
-    ) -> None:
-        """Construct an ESPnetASRTransducerModel object."""
-        super().__init__()
-
-        assert check_argument_types()
-
-        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
-        self.blank_id = 0
-        self.vocab_size = vocab_size
-        self.ignore_id = ignore_id
-        self.token_list = token_list.copy()
-
-        self.sym_space = sym_space
-        self.sym_blank = sym_blank
-
-        self.frontend = frontend
-        self.specaug = specaug
-        self.normalize = normalize
-
-        self.encoder = encoder
-        self.decoder = decoder
-        self.joint_network = joint_network
-
-        self.criterion_transducer = None
-        self.error_calculator = None
-
-        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
-        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0
-
-        if self.use_auxiliary_ctc:
-            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
-            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate
-
-        if self.use_auxiliary_lm_loss:
-            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
-            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing
-
-        self.transducer_weight = transducer_weight
-        self.fastemit_lambda = fastemit_lambda
-
-        self.auxiliary_ctc_weight = auxiliary_ctc_weight
-        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight
-
-        self.report_cer = report_cer
-        self.report_wer = report_wer
-
-        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
-
-    def forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
-        decoding_ind: int = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        """Forward architecture and compute loss(es).
-
-        Args:
-            speech: Speech sequences. (B, S)
-            speech_lengths: Speech sequences lengths. (B,)
-            text: Label ID sequences. (B, L)
-            text_lengths: Label ID sequences lengths. (B,)
-            kwargs: Contains "utts_id".
-
-        Return:
-            loss: Main loss value.
-            stats: Task statistics.
-            weight: Task weights.
-
-        """
-        assert text_lengths.dim() == 1, text_lengths.shape
-        assert (
-            speech.shape[0]
-            == speech_lengths.shape[0]
-            == text.shape[0]
-            == text_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-
-        batch_size = speech.shape[0]
-        text = text[:, : text_lengths.max()]
-
-        # 1. Encoder
-        ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
-        # 2. Transducer-related I/O preparation
-        decoder_in, target, t_len, u_len = get_transducer_task_io(
-            text,
-            encoder_out_lens,
-            ignore_id=self.ignore_id,
-        )
-
-        # 3. Decoder
-        self.decoder.set_device(encoder_out.device)
-        decoder_out = self.decoder(decoder_in, u_len)
-
-        # 4. Joint Network
-        joint_out = self.joint_network(
-            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
-        )
-
-        # 5. Losses
-        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
-            encoder_out,
-            joint_out,
-            target,
-            t_len,
-            u_len,
-        )
-
-        loss_ctc, loss_lm = 0.0, 0.0
-
-        if self.use_auxiliary_ctc:
-            loss_ctc = self._calc_ctc_loss(
-                encoder_out,
-                target,
-                t_len,
-                u_len,
-            )
-
-        if self.use_auxiliary_lm_loss:
-            loss_lm = self._calc_lm_loss(decoder_out, target)
-
-        loss = (
-            self.transducer_weight * loss_trans
-            + self.auxiliary_ctc_weight * loss_ctc
-            + self.auxiliary_lm_loss_weight * loss_lm
-        )
-
-        stats = dict(
-            loss=loss.detach(),
-            loss_transducer=loss_trans.detach(),
-            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
-            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
-            cer_transducer=cer_trans,
-            wer_transducer=wer_trans,
-        )
-
-        # force_gatherable: to-device and to-tensor if scalar for DataParallel
-        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-
-        return loss, stats, weight
-
-    def collect_feats(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
-        **kwargs,
-    ) -> Dict[str, torch.Tensor]:
-        """Collect features sequences and features lengths sequences.
-
-        Args:
-            speech: Speech sequences. (B, S)
-            speech_lengths: Speech sequences lengths. (B,)
-            text: Label ID sequences. (B, L)
-            text_lengths: Label ID sequences lengths. (B,)
-            kwargs: Contains "utts_id".
-
-        Return:
-            {}: "feats": Features sequences. (B, T, D_feats),
-                "feats_lengths": Features sequences lengths. (B,)
-
-        """
-        if self.extract_feats_in_collect_stats:
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-        else:
-            # Generate dummy stats if extract_feats_in_collect_stats is False
-            logging.warning(
-                "Generating dummy stats for feats and feats_lengths, "
-                "because encoder_conf.extract_feats_in_collect_stats is "
-                f"{self.extract_feats_in_collect_stats}"
-            )
-
-            feats, feats_lengths = speech, speech_lengths
-
-        return {"feats": feats, "feats_lengths": feats_lengths}
-
-    def encode(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        ind: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Encoder speech sequences.
-
-        Args:
-            speech: Speech sequences. (B, S)
-            speech_lengths: Speech sequences lengths. (B,)
-
-        Return:
-            encoder_out: Encoder outputs. (B, T, D_enc)
-            encoder_out_lens: Encoder outputs lengths. (B,)
-
-        """
-        with autocast(False):
-            # 1. Extract feats
-            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
-
-            # 2. Data augmentation
-            if self.specaug is not None and self.training:
-                feats, feats_lengths = self.specaug(feats, feats_lengths)
-
-            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-            if self.normalize is not None:
-                feats, feats_lengths = self.normalize(feats, feats_lengths)
-
-        # 4. Forward encoder
-        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths, ind=ind)
-
-        assert encoder_out.size(0) == speech.size(0), (
-            encoder_out.size(),
-            speech.size(0),
-        )
-        assert encoder_out.size(1) <= encoder_out_lens.max(), (
-            encoder_out.size(),
-            encoder_out_lens.max(),
-        )
-
-        return encoder_out, encoder_out_lens
-
-    def _extract_feats(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Extract features sequences and features sequences lengths.
-
-        Args:
-            speech: Speech sequences. (B, S)
-            speech_lengths: Speech sequences lengths. (B,)
-
-        Return:
-            feats: Features sequences. (B, T, D_feats)
-            feats_lengths: Features sequences lengths. (B,)
-
-        """
-        assert speech_lengths.dim() == 1, speech_lengths.shape
-
-        # for data-parallel
-        speech = speech[:, : speech_lengths.max()]
-
-        if self.frontend is not None:
-            feats, feats_lengths = self.frontend(speech, speech_lengths)
-        else:
-            feats, feats_lengths = speech, speech_lengths
-
-        return feats, feats_lengths
-
-    def _calc_transducer_loss(
-        self,
-        encoder_out: torch.Tensor,
-        joint_out: torch.Tensor,
-        target: torch.Tensor,
-        t_len: torch.Tensor,
-        u_len: torch.Tensor,
-    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
-        """Compute Transducer loss.
-
-        Args:
-            encoder_out: Encoder output sequences. (B, T, D_enc)
-            joint_out: Joint Network output sequences (B, T, U, D_joint)
-            target: Target label ID sequences. (B, L)
-            t_len: Encoder output sequences lengths. (B,)
-            u_len: Target label ID sequences lengths. (B,)
-
-        Return:
-            loss_transducer: Transducer loss value.
-            cer_transducer: Character error rate for Transducer.
-            wer_transducer: Word Error Rate for Transducer.
-
-        """
-        if self.criterion_transducer is None:
-            try:
-                # from warprnnt_pytorch import RNNTLoss
-	        # self.criterion_transducer = RNNTLoss(
-                    # reduction="mean",
-                    # fastemit_lambda=self.fastemit_lambda,
-                # )
-                from warp_rnnt import rnnt_loss as RNNTLoss
-                self.criterion_transducer = RNNTLoss
-
-            except ImportError:
-                logging.error(
-                    "warp-rnnt was not installed."
-                    "Please consult the installation documentation."
-                )
-                exit(1)
-
-        # loss_transducer = self.criterion_transducer(
-        #     joint_out,
-        #     target,
-        #     t_len,
-        #     u_len,
-        # )
-        log_probs = torch.log_softmax(joint_out, dim=-1)
-
-        loss_transducer = self.criterion_transducer(
-                log_probs,
-                target,
-                t_len,
-                u_len,
-                reduction="mean",
-                blank=self.blank_id,
-                gather=True,
-        )
-
-        if not self.training and (self.report_cer or self.report_wer):
-            if self.error_calculator is None:
-                from espnet2.asr_transducer.error_calculator import ErrorCalculator
-
-                self.error_calculator = ErrorCalculator(
-                    self.decoder,
-                    self.joint_network,
-                    self.token_list,
-                    self.sym_space,
-                    self.sym_blank,
-                    report_cer=self.report_cer,
-                    report_wer=self.report_wer,
-                )
-
-            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
-
-            return loss_transducer, cer_transducer, wer_transducer
-
-        return loss_transducer, None, None
-
-    def _calc_ctc_loss(
-        self,
-        encoder_out: torch.Tensor,
-        target: torch.Tensor,
-        t_len: torch.Tensor,
-        u_len: torch.Tensor,
-    ) -> torch.Tensor:
-        """Compute CTC loss.
-
-        Args:
-            encoder_out: Encoder output sequences. (B, T, D_enc)
-            target: Target label ID sequences. (B, L)
-            t_len: Encoder output sequences lengths. (B,)
-            u_len: Target label ID sequences lengths. (B,)
-
-        Return:
-            loss_ctc: CTC loss value.
-
-        """
-        ctc_in = self.ctc_lin(
-            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
-        )
-        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
-
-        target_mask = target != 0
-        ctc_target = target[target_mask].cpu()
-
-        with torch.backends.cudnn.flags(deterministic=True):
-            loss_ctc = torch.nn.functional.ctc_loss(
-                ctc_in,
-                ctc_target,
-                t_len,
-                u_len,
-                zero_infinity=True,
-                reduction="sum",
-            )
-        loss_ctc /= target.size(0)
-
-        return loss_ctc
-
-    def _calc_lm_loss(
-        self,
-        decoder_out: torch.Tensor,
-        target: torch.Tensor,
-    ) -> torch.Tensor:
-        """Compute LM loss.
-
-        Args:
-            decoder_out: Decoder output sequences. (B, U, D_dec)
-            target: Target label ID sequences. (B, L)
-
-        Return:
-            loss_lm: LM loss value.
-
-        """
-        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
-        lm_target = target.view(-1).type(torch.int64)
-
-        with torch.no_grad():
-            true_dist = lm_loss_in.clone()
-            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
-
-            # Ignore blank ID (0)
-            ignore = lm_target == 0
-            lm_target = lm_target.masked_fill(ignore, 0)
-
-            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
-
-        loss_lm = torch.nn.functional.kl_div(
-            torch.log_softmax(lm_loss_in, dim=1),
-            true_dist,
-            reduction="none",
-        )
-        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
-            0
-        )
-
-        return loss_lm
diff --git a/funasr/models_transducer/utils.py b/funasr/models_transducer/utils.py
deleted file mode 100644
index fd3c531..0000000
--- a/funasr/models_transducer/utils.py
+++ /dev/null
@@ -1,200 +0,0 @@
-"""Utility functions for Transducer models."""
-
-from typing import List, Tuple
-
-import torch
-
-
-class TooShortUttError(Exception):
-    """Raised when the utt is too short for subsampling.
-
-    Args:
-        message: Error message to display.
-        actual_size: The size that cannot pass the subsampling.
-        limit: The size limit for subsampling.
-
-    """
-
-    def __init__(self, message: str, actual_size: int, limit: int) -> None:
-        """Construct a TooShortUttError module."""
-        super().__init__(message)
-
-        self.actual_size = actual_size
-        self.limit = limit
-
-
-def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
-    """Check if the input is too short for subsampling.
-
-    Args:
-        sub_factor: Subsampling factor for Conv2DSubsampling.
-        size: Input size.
-
-    Returns:
-        : Whether an error should be sent.
-        : Size limit for specified subsampling factor.
-
-    """
-    if sub_factor == 2 and size < 3:
-        return True, 7
-    elif sub_factor == 4 and size < 7:
-        return True, 7
-    elif sub_factor == 6 and size < 11:
-        return True, 11
-
-    return False, -1
-
-
-def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
-    """Get conv2D second layer parameters for given subsampling factor.
-
-    Args:
-        sub_factor: Subsampling factor (1/X).
-        input_size: Input size.
-
-    Returns:
-        : Kernel size for second convolution.
-        : Stride for second convolution.
-        : Conv2DSubsampling output size.
-
-    """
-    if sub_factor == 2:
-        return 3, 1, (((input_size - 1) // 2 - 2))
-    elif sub_factor == 4:
-        return 3, 2, (((input_size - 1) // 2 - 1) // 2)
-    elif sub_factor == 6:
-        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
-    else:
-        raise ValueError(
-            "subsampling_factor parameter should be set to either 2, 4 or 6."
-        )
-
-
-def make_chunk_mask(
-    size: int,
-    chunk_size: int,
-    left_chunk_size: int = 0,
-    device: torch.device = None,
-) -> torch.Tensor:
-    """Create chunk mask for the subsequent steps (size, size).
-
-    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
-
-    Args:
-        size: Size of the source mask.
-        chunk_size: Number of frames in chunk.
-        left_chunk_size: Size of the left context in chunks (0 means full context).
-        device: Device for the mask tensor.
-
-    Returns:
-        mask: Chunk mask. (size, size)
-
-    """
-    mask = torch.zeros(size, size, device=device, dtype=torch.bool)
-
-    for i in range(size):
-        if left_chunk_size <= 0:
-            start = 0
-        else:
-            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
-
-        end = min((i // chunk_size + 1) * chunk_size, size)
-        mask[i, start:end] = True
-
-    return ~mask
-
-
-def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
-    """Create source mask for given lengths.
-
-    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
-
-    Args:
-        lengths: Sequence lengths. (B,)
-
-    Returns:
-        : Mask for the sequence lengths. (B, max_len)
-
-    """
-    max_len = lengths.max()
-    batch_size = lengths.size(0)
-
-    expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
-
-    return expanded_lengths >= lengths.unsqueeze(1)
-
-
-def get_transducer_task_io(
-    labels: torch.Tensor,
-    encoder_out_lens: torch.Tensor,
-    ignore_id: int = -1,
-    blank_id: int = 0,
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-    """Get Transducer loss I/O.
-
-    Args:
-        labels: Label ID sequences. (B, L)
-        encoder_out_lens: Encoder output lengths. (B,)
-        ignore_id: Padding symbol ID.
-        blank_id: Blank symbol ID.
-
-    Returns:
-        decoder_in: Decoder inputs. (B, U)
-        target: Target label ID sequences. (B, U)
-        t_len: Time lengths. (B,)
-        u_len: Label lengths. (B,)
-
-    """
-
-    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
-        """Create padded batch of labels from a list of labels sequences.
-
-        Args:
-            labels: Labels sequences. [B x (?)]
-            padding_value: Padding value.
-
-        Returns:
-            labels: Batch of padded labels sequences. (B,)
-
-        """
-        batch_size = len(labels)
-
-        padded = (
-            labels[0]
-            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
-            .fill_(padding_value)
-        )
-
-        for i in range(batch_size):
-            padded[i, : labels[i].size(0)] = labels[i]
-
-        return padded
-
-    device = labels.device
-
-    labels_unpad = [y[y != ignore_id] for y in labels]
-    blank = labels[0].new([blank_id])
-
-    decoder_in = pad_list(
-        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
-    ).to(device)
-
-    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
-
-    encoder_out_lens = list(map(int, encoder_out_lens))
-    t_len = torch.IntTensor(encoder_out_lens).to(device)
-
-    u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
-
-    return decoder_in, target, t_len, u_len
-
-def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
-    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
-    if t.size(dim) == pad_len:
-        return t
-    else:
-        pad_size = list(t.shape)
-        pad_size[dim] = pad_len - t.size(dim)
-        return torch.cat(
-            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
-        )
diff --git a/funasr/models_transducer/activation.py b/funasr/modules/activation.py
similarity index 100%
rename from funasr/models_transducer/activation.py
rename to funasr/modules/activation.py
diff --git a/funasr/models_transducer/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py
similarity index 99%
rename from funasr/models_transducer/beam_search_transducer.py
rename to funasr/modules/beam_search/beam_search_transducer.py
index 8e234e4..eaf5627 100644
--- a/funasr/models_transducer/beam_search_transducer.py
+++ b/funasr/modules/beam_search/beam_search_transducer.py
@@ -6,8 +6,8 @@
 import numpy as np
 import torch
 
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.joint_network import JointNetwork
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.joint_network import JointNetwork
 
 
 @dataclass
diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py
index 92f9079..9b5039c 100644
--- a/funasr/modules/e2e_asr_common.py
+++ b/funasr/modules/e2e_asr_common.py
@@ -6,6 +6,8 @@
 
 """Common functions for ASR."""
 
+from typing import List, Optional, Tuple
+
 import json
 import logging
 import sys
@@ -13,7 +15,11 @@
 from itertools import groupby
 import numpy as np
 import six
+import torch
 
+from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.joint_network import JointNetwork
 
 def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
     """End detection.
@@ -247,3 +253,148 @@
             word_eds.append(editdistance.eval(hyp_words, ref_words))
             word_ref_lens.append(len(ref_words))
         return float(sum(word_eds)) / sum(word_ref_lens)
+
+class ErrorCalculatorTransducer:
+    """Calculate CER and WER for transducer models.
+    Args:
+        decoder: Decoder module.
+        joint_network: Joint Network module.
+        token_list: List of token units.
+        sym_space: Space symbol.
+        sym_blank: Blank symbol.
+        report_cer: Whether to compute CER.
+        report_wer: Whether to compute WER.
+    """
+
+    def __init__(
+        self,
+        decoder: AbsDecoder,
+        joint_network: JointNetwork,
+        token_list: List[int],
+        sym_space: str,
+        sym_blank: str,
+        report_cer: bool = False,
+        report_wer: bool = False,
+    ) -> None:
+        """Construct an ErrorCalculatorTransducer object."""
+        super().__init__()
+
+        self.beam_search = BeamSearchTransducer(
+            decoder=decoder,
+            joint_network=joint_network,
+            beam_size=1,
+            search_type="default",
+            score_norm=False,
+        )
+
+        self.decoder = decoder
+
+        self.token_list = token_list
+        self.space = sym_space
+        self.blank = sym_blank
+
+        self.report_cer = report_cer
+        self.report_wer = report_wer
+
+    def __call__(
+        self, encoder_out: torch.Tensor, target: torch.Tensor
+    ) -> Tuple[Optional[float], Optional[float]]:
+        """Calculate sentence-level WER or/and CER score for Transducer model.
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            target: Target label ID sequences. (B, L)
+        Returns:
+            : Sentence-level CER score.
+            : Sentence-level WER score.
+        """
+        cer, wer = None, None
+
+        batchsize = int(encoder_out.size(0))
+
+        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
+
+        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
+        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
+
+        char_pred, char_target = self.convert_to_char(pred, target)
+
+        if self.report_cer:
+            cer = self.calculate_cer(char_pred, char_target)
+
+        if self.report_wer:
+            wer = self.calculate_wer(char_pred, char_target)
+
+        return cer, wer
+
+    def convert_to_char(
+        self, pred: torch.Tensor, target: torch.Tensor
+    ) -> Tuple[List, List]:
+        """Convert label ID sequences to character sequences.
+        Args:
+            pred: Prediction label ID sequences. (B, U)
+            target: Target label ID sequences. (B, L)
+        Returns:
+            char_pred: Prediction character sequences. (B, ?)
+            char_target: Target character sequences. (B, ?)
+        """
+        char_pred, char_target = [], []
+
+        for i, pred_i in enumerate(pred):
+            char_pred_i = [self.token_list[int(h)] for h in pred_i]
+            char_target_i = [self.token_list[int(r)] for r in target[i]]
+
+            char_pred_i = "".join(char_pred_i).replace(self.space, " ")
+            char_pred_i = char_pred_i.replace(self.blank, "")
+
+            char_target_i = "".join(char_target_i).replace(self.space, " ")
+            char_target_i = char_target_i.replace(self.blank, "")
+
+            char_pred.append(char_pred_i)
+            char_target.append(char_target_i)
+
+        return char_pred, char_target
+
+    def calculate_cer(
+        self, char_pred: torch.Tensor, char_target: torch.Tensor
+    ) -> float:
+        """Calculate sentence-level CER score.
+        Args:
+            char_pred: Prediction character sequences. (B, ?)
+            char_target: Target character sequences. (B, ?)
+        Returns:
+            : Average sentence-level CER score.
+        """
+        import editdistance
+
+        distances, lens = [], []
+
+        for i, char_pred_i in enumerate(char_pred):
+            pred = char_pred_i.replace(" ", "")
+            target = char_target[i].replace(" ", "")
+            distances.append(editdistance.eval(pred, target))
+            lens.append(len(target))
+
+        return float(sum(distances)) / sum(lens)
+
+    def calculate_wer(
+        self, char_pred: torch.Tensor, char_target: torch.Tensor
+    ) -> float:
+        """Calculate sentence-level WER score.
+        Args:
+            char_pred: Prediction character sequences. (B, ?)
+            char_target: Target character sequences. (B, ?)
+        Returns:
+            : Average sentence-level WER score
+        """
+        import editdistance
+
+        distances, lens = [], []
+
+        for i, char_pred_i in enumerate(char_pred):
+            pred = char_pred_i.replace("鈻�", " ").split()
+            target = char_target[i].replace("鈻�", " ").split()
+
+            distances.append(editdistance.eval(pred, target))
+            lens.append(len(target))
+
+        return float(sum(distances)) / sum(lens)
diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py
index 6d77d69..5d4fe1c 100644
--- a/funasr/modules/nets_utils.py
+++ b/funasr/modules/nets_utils.py
@@ -3,7 +3,7 @@
 """Network related utility tools."""
 
 import logging
-from typing import Dict
+from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -506,3 +506,196 @@
     }
 
     return activation_funcs[act]()
+
+class TooShortUttError(Exception):
+    """Raised when the utt is too short for subsampling.
+
+    Args:
+        message: Error message to display.
+        actual_size: The size that cannot pass the subsampling.
+        limit: The size limit for subsampling.
+
+    """
+
+    def __init__(self, message: str, actual_size: int, limit: int) -> None:
+        """Construct a TooShortUttError module."""
+        super().__init__(message)
+
+        self.actual_size = actual_size
+        self.limit = limit
+
+
+def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
+    """Check if the input is too short for subsampling.
+
+    Args:
+        sub_factor: Subsampling factor for Conv2DSubsampling.
+        size: Input size.
+
+    Returns:
+        : Whether an error should be sent.
+        : Size limit for specified subsampling factor.
+
+    """
+    if sub_factor == 2 and size < 3:
+        return True, 7
+    elif sub_factor == 4 and size < 7:
+        return True, 7
+    elif sub_factor == 6 and size < 11:
+        return True, 11
+
+    return False, -1
+
+
+def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
+    """Get conv2D second layer parameters for given subsampling factor.
+
+    Args:
+        sub_factor: Subsampling factor (1/X).
+        input_size: Input size.
+
+    Returns:
+        : Kernel size for second convolution.
+        : Stride for second convolution.
+        : Conv2DSubsampling output size.
+
+    """
+    if sub_factor == 2:
+        return 3, 1, (((input_size - 1) // 2 - 2))
+    elif sub_factor == 4:
+        return 3, 2, (((input_size - 1) // 2 - 1) // 2)
+    elif sub_factor == 6:
+        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
+    else:
+        raise ValueError(
+            "subsampling_factor parameter should be set to either 2, 4 or 6."
+        )
+
+
+def make_chunk_mask(
+    size: int,
+    chunk_size: int,
+    left_chunk_size: int = 0,
+    device: torch.device = None,
+) -> torch.Tensor:
+    """Create chunk mask for the subsequent steps (size, size).
+
+    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+    Args:
+        size: Size of the source mask.
+        chunk_size: Number of frames in chunk.
+        left_chunk_size: Size of the left context in chunks (0 means full context).
+        device: Device for the mask tensor.
+
+    Returns:
+        mask: Chunk mask. (size, size)
+
+    """
+    mask = torch.zeros(size, size, device=device, dtype=torch.bool)
+
+    for i in range(size):
+        if left_chunk_size <= 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
+
+        end = min((i // chunk_size + 1) * chunk_size, size)
+        mask[i, start:end] = True
+
+    return ~mask
+
+def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
+    """Create source mask for given lengths.
+
+    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+    Args:
+        lengths: Sequence lengths. (B,)
+
+    Returns:
+        : Mask for the sequence lengths. (B, max_len)
+
+    """
+    max_len = lengths.max()
+    batch_size = lengths.size(0)
+
+    expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
+
+    return expanded_lengths >= lengths.unsqueeze(1)
+
+
+def get_transducer_task_io(
+    labels: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    ignore_id: int = -1,
+    blank_id: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Get Transducer loss I/O.
+
+    Args:
+        labels: Label ID sequences. (B, L)
+        encoder_out_lens: Encoder output lengths. (B,)
+        ignore_id: Padding symbol ID.
+        blank_id: Blank symbol ID.
+
+    Returns:
+        decoder_in: Decoder inputs. (B, U)
+        target: Target label ID sequences. (B, U)
+        t_len: Time lengths. (B,)
+        u_len: Label lengths. (B,)
+
+    """
+
+    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
+        """Create padded batch of labels from a list of labels sequences.
+
+        Args:
+            labels: Labels sequences. [B x (?)]
+            padding_value: Padding value.
+
+        Returns:
+            labels: Batch of padded labels sequences. (B,)
+
+        """
+        batch_size = len(labels)
+
+        padded = (
+            labels[0]
+            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
+            .fill_(padding_value)
+        )
+
+        for i in range(batch_size):
+            padded[i, : labels[i].size(0)] = labels[i]
+
+        return padded
+
+    device = labels.device
+
+    labels_unpad = [y[y != ignore_id] for y in labels]
+    blank = labels[0].new([blank_id])
+
+    decoder_in = pad_list(
+        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
+    ).to(device)
+
+    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
+
+    encoder_out_lens = list(map(int, encoder_out_lens))
+    t_len = torch.IntTensor(encoder_out_lens).to(device)
+
+    u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
+
+    return decoder_in, target, t_len, u_len
+
+def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
+    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
+    if t.size(dim) == pad_len:
+        return t
+    else:
+        pad_size = list(t.shape)
+        pad_size[dim] = pad_len - t.size(dim)
+        return torch.cat(
+            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
+        )
diff --git a/funasr/tasks/asr_transducer.py b/funasr/tasks/asr_transducer.py
index be14455..cae18c1 100644
--- a/funasr/tasks/asr_transducer.py
+++ b/funasr/tasks/asr_transducer.py
@@ -21,15 +21,13 @@
     LightweightConvolutionTransformerDecoder,
     TransformerDecoder,
 )
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
-from funasr.models_transducer.decoder.rnn_decoder import RNNDecoder
-from funasr.models_transducer.decoder.stateless_decoder import StatelessDecoder
-from funasr.models_transducer.encoder.encoder import Encoder
-from funasr.models_transducer.encoder.sanm_encoder import SANMEncoderChunkOpt
-from funasr.models_transducer.espnet_transducer_model import ESPnetASRTransducerModel
-from funasr.models_transducer.espnet_transducer_model_unified import ESPnetASRUnifiedTransducerModel
-from funasr.models_transducer.espnet_transducer_model_uni_asr import UniASRTransducerModel
-from funasr.models_transducer.joint_network import JointNetwork
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
+from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder
+from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder
+from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder
+from funasr.models.e2e_transducer import TransducerModel
+from funasr.models.e2e_transducer_unified import UnifiedTransducerModel
+from funasr.models.joint_network import JointNetwork
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.layers.global_mvn import GlobalMVN
 from funasr.layers.utterance_mvn import UtteranceMVN
@@ -75,7 +73,6 @@
         "encoder",
         classes=dict(
                 encoder=Encoder,
-                sanm_chunk_opt=SANMEncoderChunkOpt,
         ),
         default="encoder",
 )
@@ -158,7 +155,7 @@
         group.add_argument(
             "--model_conf",
             action=NestedDictAction,
-            default=get_default_kwargs(ESPnetASRTransducerModel),
+            default=get_default_kwargs(TransducerModel),
             help="The keyword arguments for the model class.",
         )
         # group.add_argument(
@@ -354,7 +351,7 @@
         return retval
 
     @classmethod
-    def build_model(cls, args: argparse.Namespace) -> ESPnetASRTransducerModel:
+    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
         """Required data depending on task mode.
         Args:
             cls: ASRTransducerTask object.
@@ -440,22 +437,8 @@
 
         # 7. Build model
 
-        if getattr(args, "encoder", None) is not None and args.encoder == 'sanm_chunk_opt':
-            model = UniASRTransducerModel(
-                vocab_size=vocab_size,
-                token_list=token_list,
-                frontend=frontend,
-                specaug=specaug,
-                normalize=normalize,
-                encoder=encoder,
-                decoder=decoder,
-                att_decoder=att_decoder,
-                joint_network=joint_network,
-                **args.model_conf,
-            )
-
-        elif encoder.unified_model_training:
-            model = ESPnetASRUnifiedTransducerModel(
+        if encoder.unified_model_training:
+            model = UnifiedTransducerModel(
                 vocab_size=vocab_size,
                 token_list=token_list,
                 frontend=frontend,
@@ -469,7 +452,7 @@
             )
 
         else:
-            model = ESPnetASRTransducerModel(
+            model = TransducerModel(
                 vocab_size=vocab_size,
                 token_list=token_list,
                 frontend=frontend,

--
Gitblit v1.9.1