From 0e622e694e6cb4459955f1e5942a7c53349ce640 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 19 Dec 2023 21:58:14 +0800
Subject: [PATCH] funasr2

---
 funasr/frontends/utils/log_mel.py                                 |    0 
 funasr/frontends/wav_frontend_kaldifeat.py                        |    0 
 funasr/models/paraformer/search.py                                |    2 
 funasr/models/ct_transformer/vad_realtime_transformer.py          |    2 
 funasr/datasets/audio_datasets/__init__.py                        |    0 
 funasr/models/neat_contextual_paraformer/model.py                 |  533 ++
 funasr/models/transformer/model.py                                |   71 
 examples/industrial_data_pretraining/paraformer-large/finetune.sh |    0 
 funasr/models/uniasr/e2e_uni_asr.py                               |    8 
 funasr/bin/inference.py                                           |  218 
 funasr/models/sond/attention.py                                   |  328 +
 funasr/models/transformer/encoder.py                              |  332 +
 funasr/frontends/__init__.py                                      |    0 
 funasr/models/scama/sanm_decoder.py                               |  552 +
 funasr/models/e_branchformer/encoder.py                           |   15 
 funasr/models/bici_paraformer/cif_predictor.py                    |  340 +
 funasr/models/branchformer/model.py                               |   52 
 funasr/tokenizer/char_tokenizer.py                                |    4 
 funasr/models/sanm/model.py                                       |   18 
 funasr/models/mfcca/mfcca_encoder.py                              |   14 
 funasr/models/paraformer/cif_predictor.py                         |  376 -
 funasr/models/conformer/model.py                                  |   52 
 funasr/models/bat/model.py                                        |   39 
 funasr/models/eend/encoder.py                                     |    2 
 examples/industrial_data_pretraining/paraformer-large/infer.sh    |   10 
 funasr/frontends/fused.py                                         |    8 
 funasr/datasets/audio_datasets/load_audio_extract_fbank.py        |   11 
 funasr/utils/timestamp_tools.py                                   |    2 
 funasr/tokenizer/phoneme_tokenizer.py                             |    2 
 funasr/models/ct_transformer/attention.py                         | 1091 ++++
 funasr/models/paraformer_online/sanm_decoder.py                   |  507 +
 funasr/models/paraformer/model.py                                 | 1322 ----
 funasr/train_utils/trainer.py                                     |   13 
 funasr/utils/register.py                                          |   72 
 funasr/models/mfcca/e2e_asr_mfcca.py                              |    6 
 funasr/frontends/utils/stft.py                                    |    2 
 funasr/datasets/audio_datasets/index_ds.py                        |   64 
 funasr/frontends/utils/__init__.py                                |    0 
 funasr/models/language_model/rnn/encoders.py                      |    2 
 funasr/models/sa_asr/e2e_sa_asr.py                                |    6 
 funasr/models/paraformer/template.yaml                            |  126 
 funasr/models/transformer/utils/nets_utils.py                     |   21 
 funasr/bin/train.py                                               |   89 
 funasr/models/language_model/rnn/decoders.py                      |    2 
 funasr/models/transducer/rnn_decoder.py                           |    7 
 funasr/datasets/audio_datasets/datasets.py                        |   83 
 funasr/datasets/__init__.py                                       |    0 
 funasr/models/tp_aligner/e2e_tp.py                                |    8 
 funasr/models/sond/encoder/conv_encoder.py                        |    2 
 funasr/frontends/default.py                                       |   17 
 funasr/schedulers/__init__.py                                     |    2 
 funasr/frontends/utils/frontend.py                                |    4 
 funasr/datasets/audio_datasets/samplers.py                        |   39 
 funasr/frontends/utils/dnn_beamformer.py                          |    8 
 funasr/models/paraformer_online/model.py                          | 1284 ++++
 funasr/models/bat/conformer_chunk_encoder.py                      |  701 ++
 funasr/models/language_model/transformer_encoder.py               |  231 
 funasr/models/cnn/ResNet_aug.py                                   |    2 
 funasr/models/data2vec/data2vec.py                                |   25 
 funasr/download/download_from_hub.py                              |   10 
 funasr/models/bat/cif_predictor.py                                |  220 
 funasr/models/sond/e2e_diar_sond.py                               |    2 
 funasr/models/transformer/decoder.py                              |  647 ++
 funasr/train_utils/model_summary.py                               |    8 
 funasr/optimizers/__init__.py                                     |    2 
 funasr/frontends/s3prl.py                                         |    9 
 funasr/__init__.py                                                |   23 
 funasr/download/runtime_sdk_download_tool.py                      |   73 
 funasr/frontends/windowing.py                                     |    4 
 funasr/models/transformer/positionwise_feed_forward.py            |   22 
 funasr/models/e_branchformer/model.py                             |   52 
 funasr/tokenizer/abs_tokenizer.py                                 |   12 
 funasr/frontends/utils/mask_estimator.py                          |    0 
 funasr/metrics/compute_acc.py                                     |   23 
 funasr/models/normalize/global_mvn.py                             |    3 
 funasr/models/sa_asr/transformer_decoder.py                       |  442 -
 funasr/models/specaug/specaug.py                                  |    5 
 funasr/frontends/utils/feature_transform.py                       |    0 
 funasr/models/ct_transformer/target_delay_transformer.py          |    2 
 funasr/models/sond/encoder/self_attention_encoder.py              |   16 
 funasr/frontends/eend_ola_feature.py                              |    0 
 funasr/models/neat_contextual_paraformer/decoder.py               |   12 
 funasr/models/sanm/positionwise_feed_forward.py                   |   34 
 funasr/models/sond/encoder/fsmn_encoder.py                        |    2 
 funasr/models/bat/attention.py                                    |  238 
 funasr/models/ct_transformer/sanm_encoder.py                      |  383 +
 funasr/models/transformer/attention.py                            |  774 --
 funasr/models/data2vec/data2vec_encoder.py                        |    3 
 funasr/models/fsmn_vad/model.py                                   |    2 
 funasr/models/sanm/decoder.py                                     |  474 +
 funasr/models/sanm/attention.py                                   |  641 ++
 funasr/frontends/wav_frontend.py                                  |   18 
 funasr/models/cnn/DTDNN.py                                        |    2 
 funasr/frontends/utils/complex_utils.py                           |    0 
 funasr/models/bici_paraformer/model.py                            |  328 +
 funasr/models/normalize/utterance_mvn.py                          |    3 
 funasr/models/eend/e2e_diar_eend_ola.py                           |    2 
 funasr/models/branchformer/encoder.py                             |   11 
 funasr/models/scama/sanm_encoder.py                               |  613 ++
 funasr/models/transducer/model.py                                 |   27 
 funasr/models/paraformer/decoder.py                               |  625 ++
 /dev/null                                                         |   14 
 funasr/frontends/utils/beamformer.py                              |    0 
 funasr/models/sa_asr/attention.py                                 |   51 
 funasr/bin/tokenize_text.py                                       |    4 
 funasr/models/conformer/encoder.py                                |  613 ++
 funasr/frontends/utils/dnn_wpe.py                                 |    2 
 funasr/models/cnn/ResNet.py                                       |    2 
 funasr/models/xvector/e2e_sv.py                                   |    6 
 funasr/models/sanm/encoder.py                                     |  454 +
 110 files changed, 11688 insertions(+), 3952 deletions(-)

diff --git a/examples/industrial_data_pretraining/paraformer-large/run.sh b/examples/industrial_data_pretraining/paraformer-large/finetune.sh
similarity index 100%
rename from examples/industrial_data_pretraining/paraformer-large/run.sh
rename to examples/industrial_data_pretraining/paraformer-large/finetune.sh
diff --git a/examples/industrial_data_pretraining/paraformer-large/infer.sh b/examples/industrial_data_pretraining/paraformer-large/infer.sh
index b7fbe75..48ad3bf 100644
--- a/examples/industrial_data_pretraining/paraformer-large/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer-large/infer.sh
@@ -2,6 +2,12 @@
 cmd="funasr/bin/inference.py"
 
 python $cmd \
++model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
++input="/Users/zhifu/Downloads/asr_example.wav" \
++output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
++device="cpu" \
+
+python $cmd \
 +model="/Users/zhifu/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
 +input="/Users/zhifu/Downloads/asr_example.wav" \
 +output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
@@ -12,4 +18,6 @@
 #+input="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
 #+model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
 
-#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
\ No newline at end of file
+#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
+#+"hotword='杈鹃瓟闄� 榄旀惌'"
\ No newline at end of file
diff --git a/funasr/__init__.py b/funasr/__init__.py
index ac1591d..c7cc3b6 100644
--- a/funasr/__init__.py
+++ b/funasr/__init__.py
@@ -1,10 +1,31 @@
 """Initialize funasr package."""
 
 import os
+import pkgutil
+import importlib
 
 dirname = os.path.dirname(__file__)
 version_file = os.path.join(dirname, "version.txt")
 with open(version_file, "r") as f:
     __version__ = f.read().strip()
 
-from funasr.bin.inference import infer
\ No newline at end of file
+
+import importlib
+import pkgutil
+
+def import_submodules(package, recursive=True):
+    if isinstance(package, str):
+        package = importlib.import_module(package)
+    results = {}
+    for loader, name, is_pkg in pkgutil.walk_packages(package.__path__, package.__name__ + '.'):
+        try:
+            results[name] = importlib.import_module(name)
+        except Exception as e:
+            # To see details of why an import failed, uncomment the line below
+            # print(f"Failed to import {name}: {e}")
+            pass
+        if recursive and is_pkg:
+            results.update(import_submodules(name))
+    return results
+
+import_submodules(__name__)
diff --git a/funasr/bin/aggregate_stats_dirs.py b/funasr/bin/aggregate_stats_dirs.py
deleted file mode 100755
index 94cbdf8..0000000
--- a/funasr/bin/aggregate_stats_dirs.py
+++ /dev/null
@@ -1,108 +0,0 @@
-#!/usr/bin/env python3
-import argparse
-import logging
-import sys
-from pathlib import Path
-from typing import Iterable
-from typing import Union
-
-import numpy as np
-
-from funasr.utils.cli_utils import get_commandline_args
-
-
-def aggregate_stats_dirs(
-        input_dir: Iterable[Union[str, Path]],
-        output_dir: Union[str, Path],
-        log_level: str,
-        skip_sum_stats: bool,
-):
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) (levelname)s: %(message)s",
-    )
-
-    input_dirs = [Path(p) for p in input_dir]
-    output_dir = Path(output_dir)
-
-    for mode in ["train", "valid"]:
-        with (input_dirs[0] / mode / "batch_keys").open("r", encoding="utf-8") as f:
-            batch_keys = [line.strip() for line in f if line.strip() != ""]
-        with (input_dirs[0] / mode / "stats_keys").open("r", encoding="utf-8") as f:
-            stats_keys = [line.strip() for line in f if line.strip() != ""]
-        (output_dir / mode).mkdir(parents=True, exist_ok=True)
-
-        for key in batch_keys:
-            with (output_dir / mode / f"{key}_shape").open(
-                    "w", encoding="utf-8"
-            ) as fout:
-                for idir in input_dirs:
-                    with (idir / mode / f"{key}_shape").open(
-                            "r", encoding="utf-8"
-                    ) as fin:
-                        # Read to the last in order to sort keys
-                        # because the order can be changed if num_workers>=1
-                        lines = fin.readlines()
-                        lines = sorted(lines, key=lambda x: x.split()[0])
-                        for line in lines:
-                            fout.write(line)
-
-        for key in stats_keys:
-            if not skip_sum_stats:
-                sum_stats = None
-                for idir in input_dirs:
-                    stats = np.load(idir / mode / f"{key}_stats.npz")
-                    if sum_stats is None:
-                        sum_stats = dict(**stats)
-                    else:
-                        for k in stats:
-                            sum_stats[k] += stats[k]
-
-                np.savez(output_dir / mode / f"{key}_stats.npz", **sum_stats)
-
-            # if --write_collected_feats=true
-            p = Path(mode) / "collect_feats" / f"{key}.scp"
-            scp = input_dirs[0] / p
-            if scp.exists():
-                (output_dir / p).parent.mkdir(parents=True, exist_ok=True)
-                with (output_dir / p).open("w", encoding="utf-8") as fout:
-                    for idir in input_dirs:
-                        with (idir / p).open("r", encoding="utf-8") as fin:
-                            for line in fin:
-                                fout.write(line)
-
-
-def get_parser() -> argparse.ArgumentParser:
-    parser = argparse.ArgumentParser(
-        description="Aggregate statistics directories into one directory",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-    parser.add_argument(
-        "--skip_sum_stats",
-        default=False,
-        action="store_true",
-        help="Skip computing the sum of statistics.",
-    )
-
-    parser.add_argument("--input_dir", action="append", help="Input directories")
-    parser.add_argument("--output_dir", required=True, help="Output directory")
-    return parser
-
-
-def main(cmd=None):
-    print(get_commandline_args(), file=sys.stderr)
-    parser = get_parser()
-    args = parser.parse_args(cmd)
-    kwargs = vars(args)
-    aggregate_stats_dirs(**kwargs)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/funasr/bin/export_model.py b/funasr/bin/export_model.py
deleted file mode 100644
index 6ab9408..0000000
--- a/funasr/bin/export_model.py
+++ /dev/null
@@ -1,296 +0,0 @@
-import os
-import torch
-import random
-import logging
-import numpy as np
-from pathlib import Path
-from typing import Union, Dict, List
-from funasr.export.models import get_model
-from funasr.utils.types import str2bool, str2triple_str
-# torch_version = float(".".join(torch.__version__.split(".")[:2]))
-# assert torch_version > 1.9
-
-class ModelExport:
-    def __init__(
-        self,
-        cache_dir: Union[Path, str] = None,
-        onnx: bool = True,
-        device: str = "cpu",
-        quant: bool = True,
-        fallback_num: int = 0,
-        audio_in: str = None,
-        calib_num: int = 200,
-        model_revision: str = None,
-    ):
-        self.set_all_random_seed(0)
-
-        self.cache_dir = cache_dir
-        self.export_config = dict(
-            feats_dim=560,
-            onnx=False,
-        )
-        
-        self.onnx = onnx
-        self.device = device
-        self.quant = quant
-        self.fallback_num = fallback_num
-        self.frontend = None
-        self.audio_in = audio_in
-        self.calib_num = calib_num
-        self.model_revision = model_revision
-        
-
-    def _export(
-        self,
-        model,
-        tag_name: str = None,
-        verbose: bool = False,
-    ):
-
-        export_dir = self.cache_dir
-        os.makedirs(export_dir, exist_ok=True)
-
-        # export encoder1
-        self.export_config["model_name"] = "model"
-        model = get_model(
-            model,
-            self.export_config,
-        )
-        if isinstance(model, List):
-            for m in model:
-                m.eval()
-                if self.onnx:
-                    self._export_onnx(m, verbose, export_dir)
-                else:
-                    self._export_torchscripts(m, verbose, export_dir)
-                print("output dir: {}".format(export_dir))
-        else:
-            model.eval()
-            # self._export_onnx(model, verbose, export_dir)
-            if self.onnx:
-                self._export_onnx(model, verbose, export_dir)
-            else:
-                self._export_torchscripts(model, verbose, export_dir)
-            print("output dir: {}".format(export_dir))
-
-
-    def _torch_quantize(self, model):
-        def _run_calibration_data(m):
-            # using dummy inputs for a example
-            if self.audio_in is not None:
-                feats, feats_len = self.load_feats(self.audio_in)
-                for i, (feat, len) in enumerate(zip(feats, feats_len)):
-                    with torch.no_grad():
-                        m(feat, len)
-            else:
-                dummy_input = model.get_dummy_inputs()
-                m(*dummy_input)
-            
-
-        from torch_quant.module import ModuleFilter
-        from torch_quant.quantizer import Backend, Quantizer
-        from funasr.export.models.modules.decoder_layer import DecoderLayerSANM
-        from funasr.export.models.modules.encoder_layer import EncoderLayerSANM
-        module_filter = ModuleFilter(include_classes=[EncoderLayerSANM, DecoderLayerSANM])
-        module_filter.exclude_op_types = [torch.nn.Conv1d]
-        quantizer = Quantizer(
-            module_filter=module_filter,
-            backend=Backend.FBGEMM,
-        )
-        model.eval()
-        calib_model = quantizer.calib(model)
-        _run_calibration_data(calib_model)
-        if self.fallback_num > 0:
-            # perform automatic mixed precision quantization
-            amp_model = quantizer.amp(model)
-            _run_calibration_data(amp_model)
-            quantizer.fallback(amp_model, num=self.fallback_num)
-            print('Fallback layers:')
-            print('\n'.join(quantizer.module_filter.exclude_names))
-        quant_model = quantizer.quantize(model)
-        return quant_model
-
-
-    def _export_torchscripts(self, model, verbose, path, enc_size=None):
-        if enc_size:
-            dummy_input = model.get_dummy_inputs(enc_size)
-        else:
-            dummy_input = model.get_dummy_inputs()
-
-        if self.device == 'cuda':
-            model = model.cuda()
-            dummy_input = tuple([i.cuda() for i in dummy_input])
-
-        # model_script = torch.jit.script(model)
-        model_script = torch.jit.trace(model, dummy_input)
-        model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))
-
-        if self.quant:
-            quant_model = self._torch_quantize(model)
-            model_script = torch.jit.trace(quant_model, dummy_input)
-            model_script.save(os.path.join(path, f'{model.model_name}_quant.torchscripts'))
-
-
-    def set_all_random_seed(self, seed: int):
-        random.seed(seed)
-        np.random.seed(seed)
-        torch.random.manual_seed(seed)
-
-    def parse_audio_in(self, audio_in):
-        
-        wav_list, name_list = [], []
-        if audio_in.endswith(".scp"):
-            f = open(audio_in, 'r')
-            lines = f.readlines()[:self.calib_num]
-            for line in lines:
-                name, path = line.strip().split()
-                name_list.append(name)
-                wav_list.append(path)
-        else:
-            wav_list = [audio_in,]
-            name_list = ["test",]
-        return wav_list, name_list
-    
-    def load_feats(self, audio_in: str = None):
-        import torchaudio
-
-        wav_list, name_list = self.parse_audio_in(audio_in)
-        feats = []
-        feats_len = []
-        for line in wav_list:
-            path = line.strip()
-            waveform, sampling_rate = torchaudio.load(path)
-            if sampling_rate != self.frontend.fs:
-                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
-                                                          new_freq=self.frontend.fs)(waveform)
-            fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
-            feats.append(fbank)
-            feats_len.append(fbank_len)
-        return feats, feats_len
-    
-    def export(self,
-               tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
-               mode: str = None,
-               ):
-        
-        model_dir = tag_name
-        if model_dir.startswith('damo'):
-            from modelscope.hub.snapshot_download import snapshot_download
-            model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir, revision=self.model_revision)
-        self.cache_dir = model_dir
-
-        if mode is None:
-            import json
-            json_file = os.path.join(model_dir, 'configuration.json')
-            with open(json_file, 'r') as f:
-                config_data = json.load(f)
-                if config_data['task'] == "punctuation":
-                    mode = config_data['model']['punc_model_config']['mode']
-                else:
-                    mode = config_data['model']['model_config']['mode']
-        if mode.startswith('paraformer'):
-            from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-            config = os.path.join(model_dir, 'config.yaml')
-            model_file = os.path.join(model_dir, 'model.pb')
-            cmvn_file = os.path.join(model_dir, 'am.mvn')
-            model, asr_train_args = ASRTask.build_model_from_file(
-                config, model_file, cmvn_file, 'cpu'
-            )
-            self.frontend = model.frontend
-            self.export_config["feats_dim"] = 560
-        elif mode.startswith('offline'):
-            from funasr.tasks.vad import VADTask
-            config = os.path.join(model_dir, 'vad.yaml')
-            model_file = os.path.join(model_dir, 'vad.pb')
-            cmvn_file = os.path.join(model_dir, 'vad.mvn')
-            
-            model, vad_infer_args = VADTask.build_model_from_file(
-                config, model_file, cmvn_file=cmvn_file, device='cpu'
-            )
-            self.export_config["feats_dim"] = 400
-            self.frontend = model.frontend
-        elif mode.startswith('punc'):
-            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
-            punc_train_config = os.path.join(model_dir, 'config.yaml')
-            punc_model_file = os.path.join(model_dir, 'punc.pb')
-            model, punc_train_args = PUNCTask.build_model_from_file(
-                punc_train_config, punc_model_file, 'cpu'
-            )
-        elif mode.startswith('punc_VadRealtime'):
-            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
-            punc_train_config = os.path.join(model_dir, 'config.yaml')
-            punc_model_file = os.path.join(model_dir, 'punc.pb')
-            model, punc_train_args = PUNCTask.build_model_from_file(
-                punc_train_config, punc_model_file, 'cpu'
-            )
-        self._export(model, tag_name)
-            
-
-    def _export_onnx(self, model, verbose, path, enc_size=None):
-        if enc_size:
-            dummy_input = model.get_dummy_inputs(enc_size)
-        else:
-            dummy_input = model.get_dummy_inputs()
-
-        # model_script = torch.jit.script(model)
-        model_script = model #torch.jit.trace(model)
-        model_path = os.path.join(path, f'{model.model_name}.onnx')
-        # if not os.path.exists(model_path):
-        torch.onnx.export(
-            model_script,
-            dummy_input,
-            model_path,
-            verbose=verbose,
-            opset_version=14,
-            input_names=model.get_input_names(),
-            output_names=model.get_output_names(),
-            dynamic_axes=model.get_dynamic_axes()
-        )
-
-        if self.quant:
-            from onnxruntime.quantization import QuantType, quantize_dynamic
-            import onnx
-            quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx')
-            if not os.path.exists(quant_model_path):
-                onnx_model = onnx.load(model_path)
-                nodes = [n.name for n in onnx_model.graph.node]
-                nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m  or 'bias_decoder' in m]
-                quantize_dynamic(
-                    model_input=model_path,
-                    model_output=quant_model_path,
-                    op_types_to_quantize=['MatMul'],
-                    per_channel=True,
-                    reduce_range=False,
-                    weight_type=QuantType.QUInt8,
-                    nodes_to_exclude=nodes_to_exclude,
-                )
-
-
-if __name__ == '__main__':
-    import argparse
-    parser = argparse.ArgumentParser()
-    # parser.add_argument('--model-name', type=str, required=True)
-    parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
-    parser.add_argument('--export-dir', type=str, required=True)
-    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
-    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
-    parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
-    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
-    parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
-    parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
-    parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
-    args = parser.parse_args()
-
-    export_model = ModelExport(
-        cache_dir=args.export_dir,
-        onnx=args.type == 'onnx',
-        device=args.device,
-        quant=args.quantize,
-        fallback_num=args.fallback_num,
-        audio_in=args.audio_in,
-        calib_num=args.calib_num,
-        model_revision=args.model_revision,
-    )
-    for model_name in args.model_name:
-        print("export model: {}".format(model_name))
-        export_model.export(model_name)
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index d63ebc9..09e28f3 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -5,123 +5,25 @@
 import hydra
 import json
 from omegaconf import DictConfig, OmegaConf
-from funasr.utils.dynamic_import import dynamic_import
 import logging
 from funasr.download.download_from_hub import download_model
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
-from funasr.tokenizer.funtoken import build_tokenizer
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_bytes
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_bytes
 from funasr.train_utils.device_funcs import to_device
 from tqdm import tqdm
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
 import time
 import random
 import string
+from funasr.utils.register import registry_tables
 
-@hydra.main(config_name=None, version_base=None)
-def main_hydra(kwargs: DictConfig):
-	assert "model" in kwargs
 
-	pipeline = infer(**kwargs)
-	res = pipeline(input=kwargs["input"])
-	print(res)
-	
-def infer(**kwargs):
-	
-	if ":" not in kwargs["model"]:
-		logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
-		kwargs = download_model(**kwargs)
-	
-	set_all_random_seed(kwargs.get("seed", 0))
-
-	
-	device = kwargs.get("device", "cuda")
-	if not torch.cuda.is_available() or kwargs.get("ngpu", 1):
-		device = "cpu"
-		batch_size = 1
-	kwargs["device"] = device
-	
-	# build_tokenizer
-	tokenizer = build_tokenizer(
-		token_type=kwargs.get("token_type", "char"),
-		bpemodel=kwargs.get("bpemodel", None),
-		delimiter=kwargs.get("delimiter", None),
-		space_symbol=kwargs.get("space_symbol", "<space>"),
-		non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
-		g2p_type=kwargs.get("g2p_type", None),
-		token_list=kwargs.get("token_list", None),
-		unk_symbol=kwargs.get("unk_symbol", "<unk>"),
-	)
-
-	import pdb;
-	pdb.set_trace()
-	# build model
-	model_class = dynamic_import(kwargs.get("model"))
-	model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
-	model.eval()
-	model.to(device)
-	frontend = model.frontend
-	kwargs["token_list"] = tokenizer.token_list
-	
-	
-	# init_param
-	init_param = kwargs.get("init_param", None)
-	if init_param is not None:
-		logging.info(f"Loading pretrained params from {init_param}")
-		load_pretrained_model(
-			model=model,
-			init_param=init_param,
-			ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
-			oss_bucket=kwargs.get("oss_bucket", None),
-		)
-	
-	def _forward(input, input_len=None, **cfg):
-		cfg = OmegaConf.merge(kwargs, cfg)
-		date_type = cfg.get("date_type", "sound")
-		
-		key_list, data_list = build_iter_for_infer(input, input_len=input_len, date_type=date_type, frontend=frontend)
-		
-		speed_stats = {}
-		asr_result_list = []
-		num_samples = len(data_list)
-		pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
-		for beg_idx in range(0, num_samples, batch_size):
-
-			end_idx = min(num_samples, beg_idx + batch_size)
-			data_batch = data_list[beg_idx:end_idx]
-			key_batch = key_list[beg_idx:end_idx]
-			batch = {"data_in": data_batch, "key": key_batch}
-			
-			time1 = time.perf_counter()
-			results, meta_data = model.generate(**batch, tokenizer=tokenizer, **cfg)
-			time2 = time.perf_counter()
-			
-			asr_result_list.append(results)
-			pbar.update(1)
-			
-			# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
-			batch_data_time = meta_data.get("batch_data_time", -1)
-			speed_stats["load_data"] = meta_data["load_data"]
-			speed_stats["extract_feat"] = meta_data["extract_feat"]
-			speed_stats["forward"] = f"{time2 - time1:0.3f}"
-			speed_stats["rtf"] = f"{(time2 - time1)/batch_data_time:0.3f}"
-			description = (
-				f"{speed_stats}, "
-			)
-			pbar.set_description(description)
-		
-		torch.cuda.empty_cache()
-		return asr_result_list
-	
-	return _forward
-	
-
-def build_iter_for_infer(data_in, input_len=None, date_type="sound", frontend=None):
+def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
 	"""
 	
 	:param input:
 	:param input_len:
-	:param date_type:
+	:param data_type:
 	:param frontend:
 	:return:
 	"""
@@ -131,7 +33,7 @@
 	
 	chars = string.ascii_letters + string.digits
 	
-	if isinstance(data_in, str) and os.path.exists(data_in): # wav_pat; filelist: wav.scp, file.jsonl;text.txt;
+	if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
 		_, file_extension = os.path.splitext(data_in)
 		file_extension = file_extension.lower()
 		if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
@@ -153,10 +55,10 @@
 			key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
 			data_list = [data_in]
 			key_list = [key]
-	elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank, wav_path]
+	elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank]
 		data_list = data_in
 		key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
-	else: # raw text; audio sample point, fbank
+	else: # raw text; audio sample point, fbank; bytes
 		if isinstance(data_in, bytes): # audio bytes
 			data_in = load_bytes(data_in)
 		key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
@@ -165,6 +67,112 @@
 	
 	return key_list, data_list
 
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
+	log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
+
+	logging.basicConfig(level=log_level)
+
+	import pdb;
+	pdb.set_trace()
+	model = AutoModel(**kwargs)
+	res = model.generate(input=kwargs["input"])
+	print(res)
+
+class AutoModel:
+	def __init__(self, **kwargs):
+		registry_tables.print_register_tables()
+		assert "model" in kwargs
+		if "model_conf" not in kwargs:
+			logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+			kwargs = download_model(**kwargs)
+		
+		set_all_random_seed(kwargs.get("seed", 0))
+		
+		device = kwargs.get("device", "cuda")
+		if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+			device = "cpu"
+			kwargs["batch_size"] = 1
+		kwargs["device"] = device
+
+		# build tokenizer
+		tokenizer = kwargs.get("tokenizer", None)
+		if tokenizer is not None:
+			tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
+			tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+			kwargs["tokenizer"] = tokenizer
+		
+		# build frontend
+		frontend = kwargs.get("frontend", None)
+		if frontend is not None:
+			frontend_class = registry_tables.frontend_classes.get(frontend.lower())
+			frontend = frontend_class(**kwargs["frontend_conf"])
+			kwargs["frontend"] = frontend
+		
+		# build model
+		model_class = registry_tables.model_classes.get(kwargs["model"].lower())
+		model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+		model.eval()
+		model.to(device)
+		
+		kwargs["token_list"] = tokenizer.token_list
+		
+		# init_param
+		init_param = kwargs.get("init_param", None)
+		if init_param is not None:
+			logging.info(f"Loading pretrained params from {init_param}")
+			load_pretrained_model(
+				model=model,
+				init_param=init_param,
+				ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
+				oss_bucket=kwargs.get("oss_bucket", None),
+			)
+		self.kwargs = kwargs
+		self.model = model
+		self.tokenizer = tokenizer
+	
+	def generate(self, input, input_len=None, **cfg):
+		self.kwargs.update(cfg)
+		data_type = self.kwargs.get("data_type", "sound")
+		batch_size = self.kwargs.get("batch_size", 1)
+		if self.kwargs.get("device", "cpu") == "cpu":
+			batch_size = 1
+		
+		key_list, data_list = build_iter_for_infer(input, input_len=input_len, data_type=data_type)
+		
+		speed_stats = {}
+		asr_result_list = []
+		num_samples = len(data_list)
+		pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+		for beg_idx in range(0, num_samples, batch_size):
+			end_idx = min(num_samples, beg_idx + batch_size)
+			data_batch = data_list[beg_idx:end_idx]
+			key_batch = key_list[beg_idx:end_idx]
+			batch = {"data_in": data_batch, "key": key_batch}
+			if (end_idx - beg_idx) == 1 and isinstance(data_batch[0], torch.Tensor): # fbank
+				batch["data_batch"] = data_batch[0]
+				batch["data_lengths"] = input_len
+		
+			time1 = time.perf_counter()
+			results, meta_data = self.model.generate(**batch, **self.kwargs)
+			time2 = time.perf_counter()
+			
+			asr_result_list.append(results)
+			pbar.update(1)
+			
+			# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+			batch_data_time = meta_data.get("batch_data_time", -1)
+			speed_stats["load_data"] = meta_data.get("load_data", 0.0)
+			speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
+			speed_stats["forward"] = f"{time2 - time1:0.3f}"
+			speed_stats["rtf"] = f"{(time2 - time1) / batch_data_time:0.3f}"
+			description = (
+				f"{speed_stats}, "
+			)
+			pbar.set_description(description)
+		
+		torch.cuda.empty_cache()
+		return asr_result_list
 
 if __name__ == '__main__':
 	main_hydra()
\ No newline at end of file
diff --git a/funasr/bin/tokenize_text.py b/funasr/bin/tokenize_text.py
index 674c1b9..0ecf2f6 100755
--- a/funasr/bin/tokenize_text.py
+++ b/funasr/bin/tokenize_text.py
@@ -11,7 +11,7 @@
 from funasr.utils.cli_utils import get_commandline_args
 from funasr.tokenizer.build_tokenizer import build_tokenizer
 from funasr.tokenizer.cleaner import TextCleaner
-from funasr.tokenizer.phoneme_tokenizer import g2p_choices
+from funasr.tokenizer.phoneme_tokenizer import g2p_classes
 from funasr.utils.types import str2bool
 from funasr.utils.types import str_or_none
 
@@ -239,7 +239,7 @@
     parser.add_argument(
         "--g2p",
         type=str_or_none,
-        choices=g2p_choices,
+        choices=g2p_classes,
         default=None,
         help="Specify g2p method if --token_type=phn",
     )
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 6a88233..72fa9fa 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -8,34 +8,30 @@
 import hydra
 from omegaconf import DictConfig, OmegaConf
 from funasr.train_utils.set_all_random_seed import set_all_random_seed
-# from funasr.model_class_factory1 import model_choices
 from funasr.models.lora.utils import mark_only_lora_as_trainable
-from funasr.optimizers import optim_choices
-from funasr.schedulers import scheduler_choices
+from funasr.optimizers import optim_classes
+from funasr.schedulers import scheduler_classes
 from funasr.train_utils.load_pretrained_model import load_pretrained_model
 from funasr.train_utils.initialize import initialize
-from funasr.datasets.fun_datasets.data_sampler import BatchSampler
 # from funasr.tokenizer.build_tokenizer import build_tokenizer
 # from funasr.tokenizer.token_id_converter import TokenIDConverter
-from funasr.tokenizer.funtoken import build_tokenizer
-from funasr.datasets.fun_datasets.dataset_jsonl import AudioDataset
+# from funasr.tokenizer.funtoken import build_tokenizer
 from funasr.train_utils.trainer import Trainer
-# from funasr.utils.load_fr_py import load_class_from_path
-from funasr.utils.dynamic_import import dynamic_import
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from funasr.download.download_from_hub import download_model
+from funasr.utils.register import registry_tables
 
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
 	import pdb; pdb.set_trace()
-	if ":" in kwargs["model"]:
+	assert "model" in kwargs
+	if "model_conf" not in kwargs:
 		logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
 		kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
 	
-	import pdb;
-	pdb.set_trace()
+
 	main(**kwargs)
 
 
@@ -43,6 +39,7 @@
 	# preprocess_config(kwargs)
 	# import pdb; pdb.set_trace()
 	# set random seed
+	registry_tables.print_register_tables()
 	set_all_random_seed(kwargs.get("seed", 0))
 	torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
 	torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
@@ -56,31 +53,38 @@
 		dist.init_process_group(backend=kwargs.get("backend", "nccl"), init_method='env://')
 		torch.cuda.set_device(local_rank)
 	
-	
-	# build_tokenizer
-	tokenizer = build_tokenizer(
-		token_type=kwargs.get("token_type", "char"),
-		bpemodel=kwargs.get("bpemodel", None),
-		delimiter=kwargs.get("delimiter", None),
-		space_symbol=kwargs.get("space_symbol", "<space>"),
-		non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
-		g2p_type=kwargs.get("g2p_type", None),
-		token_list=kwargs.get("token_list", None),
-		unk_symbol=kwargs.get("unk_symbol", "<unk>"),
-	)
+	# save config.yaml
+	if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
+		os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
+		yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
+		OmegaConf.save(config=kwargs, f=yaml_file)
+		logging.info("config.yaml is saved to: %s", yaml_file)
 
+	tokenizer = kwargs.get("tokenizer", None)
+	if tokenizer is not None:
+		tokenizer_class = registry_tables.tokenizer_classes.get(tokenizer.lower())
+		tokenizer = tokenizer_class(**kwargs["tokenizer_conf"])
+		kwargs["tokenizer"] = tokenizer
+	
+	# build frontend if frontend is not None
+	frontend = kwargs.get("frontend", None)
+	if frontend is not None:
+		frontend_class = registry_tables.frontend_classes.get(frontend.lower())
+		frontend = frontend_class(**kwargs["frontend_conf"])
+		kwargs["frontend"] = frontend
+	
 	# import pdb;
 	# pdb.set_trace()
 	# build model
-	# model_class = model_choices.get_class(kwargs.get("model", "asr"))
-	# model_class = load_class_from_path(kwargs.get("model").split(":"))
-	model_class = dynamic_import(kwargs.get("model"))
+	model_class = registry_tables.model_classes.get(kwargs["model"].lower())
 	model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
-	frontend = model.frontend
+
+
+
 	# init_param
 	init_param = kwargs.get("init_param", None)
 	if init_param is not None:
-		if not isinstance(init_param, Sequence):
+		if not isinstance(init_param, (list, tuple)):
 			init_param = (init_param,)
 		logging.info("init_param is not None: %s", init_param)
 		for p in init_param:
@@ -93,9 +97,8 @@
 			)
 	else:
 		initialize(model, kwargs.get("init", "kaiming_normal"))
-	
-	# import pdb;
-	# pdb.set_trace()
+
+
 	# freeze_param
 	freeze_param = kwargs.get("freeze_param", None)
 	if freeze_param is not None:
@@ -122,33 +125,33 @@
 		
 	# optim
 	optim = kwargs.get("optim", "adam")
-	assert optim in optim_choices
-	optim_class = optim_choices.get(optim)
+	assert optim in optim_classes
+	optim_class = optim_classes.get(optim)
 	optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))
 	
 	# scheduler
 	scheduler = kwargs.get("scheduler", "warmuplr")
-	assert scheduler in scheduler_choices
-	scheduler_class = scheduler_choices.get(scheduler)
+	assert scheduler in scheduler_classes
+	scheduler_class = scheduler_classes.get(scheduler)
 	scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
 
-
+	# import pdb;
+	# pdb.set_trace()
 	# dataset
-	dataset_tr = AudioDataset(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
+	dataset_class = registry_tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset").lower())
+	dataset_tr = dataset_class(kwargs.get("train_data_set_list"), frontend=frontend, tokenizer=tokenizer, **kwargs.get("dataset_conf"))
 
 	# dataloader
-	batch_sampler = BatchSampler(dataset_tr, **kwargs.get("dataset_conf"), **kwargs.get("dataset_conf").get("batch_conf"))
+	batch_sampler = kwargs["dataset_conf"].get("batch_sampler", "DynamicBatchLocalShuffleSampler")
+	batch_sampler_class = registry_tables.batch_sampler_classes.get(batch_sampler.lower())
+	batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
 	dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
 	                                            collate_fn=dataset_tr.collator,
 	                                            batch_sampler=batch_sampler,
 	                                            num_workers=kwargs.get("dataset_conf").get("num_workers", 4),
 	                                            pin_memory=True)
 	
-	if (use_ddp or use_fsdp) and dist.get_rank() == 0 or not (use_ddp or use_fsdp) and local_rank == 0:
-		os.makedirs(kwargs.get("output_dir", "./"), exist_ok=True)
-		yaml_file = os.path.join(kwargs.get("output_dir", "./"), "config.yaml")
-		OmegaConf.save(config=kwargs, f=yaml_file)
-		logging.info("config.yaml is saved to: %s", yaml_file)
+
 	
 	trainer = Trainer(
 	    model=model,
diff --git a/funasr/datasets/fun_datasets/__init__.py b/funasr/datasets/__init__.py
similarity index 100%
rename from funasr/datasets/fun_datasets/__init__.py
rename to funasr/datasets/__init__.py
diff --git a/funasr/datasets/fun_datasets/__init__.py b/funasr/datasets/audio_datasets/__init__.py
similarity index 100%
copy from funasr/datasets/fun_datasets/__init__.py
copy to funasr/datasets/audio_datasets/__init__.py
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
new file mode 100644
index 0000000..353a3a0
--- /dev/null
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -0,0 +1,83 @@
+import torch
+import json
+import torch.distributed as dist
+import numpy as np
+import kaldiio
+import librosa
+import torchaudio
+import time
+import logging
+
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils.register import register_class, registry_tables
+
+@register_class("dataset_classes", "AudioDataset")
+class AudioDataset(torch.utils.data.Dataset):
+	def __init__(self,
+	             path,
+	             index_ds: str = None,
+	             frontend=None,
+	             tokenizer=None,
+	             int_pad_value: int = -1,
+	             float_pad_value: float = 0.0,
+	              **kwargs):
+		super().__init__()
+		index_ds_class = registry_tables.index_ds_classes.get(index_ds.lower())
+		self.index_ds = index_ds_class(path)
+		self.frontend = frontend
+		self.fs = 16000 if frontend is None else frontend.fs
+		self.data_type = "sound"
+		self.tokenizer = tokenizer
+
+		self.int_pad_value = int_pad_value
+		self.float_pad_value = float_pad_value
+	
+	def get_source_len(self, index):
+		item = self.index_ds[index]
+		return self.index_ds.get_source_len(item)
+	
+	def get_target_len(self, index):
+		item = self.index_ds[index]
+		return self.index_ds.get_target_len(item)
+	
+	def __len__(self):
+		return len(self.index_ds)
+	
+	def __getitem__(self, index):
+		item = self.index_ds[index]
+		# import pdb;
+		# pdb.set_trace()
+		source = item["source"]
+		data_src = load_audio(source, fs=self.fs)
+		speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
+		target = item["target"]
+		ids = self.tokenizer.encode(target)
+		ids_lengths = len(ids)
+		text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
+
+		return {"speech": speech[0, :, :],
+		        "speech_lengths": speech_lengths,
+		        "text": text,
+		        "text_lengths": text_lengths,
+		        }
+	
+	
+	def collator(self, samples: list=None):
+
+
+		outputs = {}
+		for sample in samples:
+			for key in sample.keys():
+				if key not in outputs:
+					outputs[key] = []
+				outputs[key].append(sample[key])
+
+		for key, data_list in outputs.items():
+			if data_list[0].dtype == torch.int64:
+
+				pad_value = self.int_pad_value
+			else:
+				pad_value = self.float_pad_value
+			outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+		return outputs
+
diff --git a/funasr/datasets/audio_datasets/index_ds.py b/funasr/datasets/audio_datasets/index_ds.py
new file mode 100644
index 0000000..33b309a
--- /dev/null
+++ b/funasr/datasets/audio_datasets/index_ds.py
@@ -0,0 +1,64 @@
+import torch
+import json
+import torch.distributed as dist
+import time
+import logging
+
+from funasr.utils.register import register_class
+
+@register_class("index_ds_classes", "IndexDSJsonl")
+class IndexDSJsonl(torch.utils.data.Dataset):
+	
+	def __init__(self, path):
+		super().__init__()
+		
+		contents = []
+		with open(path, encoding='utf-8') as fin:
+			for line in fin:
+				data = json.loads(line.strip())
+				if "text" in data:  # for sft
+					contents.append(data['text'])
+				if "source" in data:  # for speech lab pretrain
+					prompt = data["prompt"]
+					source = data["source"]
+					target = data["target"]
+					source_len = data["source_len"]
+					target_len = data["target_len"]
+
+					contents.append({"source": source,
+					                 "prompt": prompt,
+					                 "target": target,
+					                 "source_len": source_len,
+					                 "target_len": target_len,
+					                 }
+					                )
+		
+		self.contents = []
+		total_num = len(contents)
+		try:
+			rank = dist.get_rank()
+			world_size = dist.get_world_size()
+		except:
+			rank = 0
+			world_size = 1
+			logging.warning("distributed is not initialized, only single shard")
+		num_per_rank = total_num // world_size
+		
+		# rank = 0
+		# import ipdb; ipdb.set_trace()
+		self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
+	
+		logging.info("in rank: {}, num of samples: {}, total num of samples across ranks: {}".format(rank, len(self.contents), len(contents)))
+
+	def __len__(self):
+		return len(self.contents)
+	
+	def __getitem__(self, index):
+		return self.contents[index]
+	
+	def get_source_len(self, data_dict):
+		return data_dict["source_len"]
+
+	def get_target_len(self, data_dict):
+		
+		return data_dict["target_len"] if "target_len" in data_dict else 0
diff --git a/funasr/datasets/fun_datasets/load_audio_extract_fbank.py b/funasr/datasets/audio_datasets/load_audio_extract_fbank.py
similarity index 93%
rename from funasr/datasets/fun_datasets/load_audio_extract_fbank.py
rename to funasr/datasets/audio_datasets/load_audio_extract_fbank.py
index c76f346..c8883ee 100644
--- a/funasr/datasets/fun_datasets/load_audio_extract_fbank.py
+++ b/funasr/datasets/audio_datasets/load_audio_extract_fbank.py
@@ -46,15 +46,16 @@
 	array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
 	return array
 
-def extract_fbank(data, data_len = None, date_type: str="sound", frontend=None):
-	
+def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None):
+	# import pdb;
+	# pdb.set_trace()
 	if isinstance(data, np.ndarray):
 		data = torch.from_numpy(data)
-		if len(data) < 2:
+		if len(data.shape) < 2:
 			data = data[None, :] # data: [batch, N]
 		data_len = [data.shape[1]] if data_len is None else data_len
 	elif isinstance(data, torch.Tensor):
-		if len(data) < 2:
+		if len(data.shape) < 2:
 			data = data[None, :] # data: [batch, N]
 		data_len = [data.shape[1]] if data_len is None else data_len
 	elif isinstance(data, (list, tuple)):
@@ -67,7 +68,7 @@
 		data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
 	# import pdb;
 	# pdb.set_trace()
-	if date_type == "sound":
+	if data_type == "sound":
 		data, data_len = frontend(data, data_len)
 	
 	if isinstance(data_len, (list, tuple)):
diff --git a/funasr/datasets/fun_datasets/data_sampler.py b/funasr/datasets/audio_datasets/samplers.py
similarity index 61%
rename from funasr/datasets/fun_datasets/data_sampler.py
rename to funasr/datasets/audio_datasets/samplers.py
index 3a19a17..7d3a941 100644
--- a/funasr/datasets/fun_datasets/data_sampler.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -2,31 +2,38 @@
 
 import numpy as np
 
+from funasr.utils.register import register_class
+
+@register_class("batch_sampler_classes", "DynamicBatchLocalShuffleSampler")
 class BatchSampler(torch.utils.data.BatchSampler):
 	
-	def __init__(self, dataset, batch_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs):
+	def __init__(self, dataset,
+	             batch_type: str="example",
+	             batch_size: int=100,
+	             buffer_size: int=30,
+	             drop_last: bool=False,
+	             shuffle: bool=True,
+	             **kwargs):
 		
 		self.drop_last = drop_last
 		self.pre_idx = -1
 		self.dataset = dataset
 		self.total_samples = len(dataset)
-		# self.batch_type = args.batch_type
-		# self.batch_size = args.batch_size
-		# self.sort_size = args.sort_size
-		# self.max_length_token = args.max_length_token
 		self.batch_type = batch_type
 		self.batch_size = batch_size
-		self.sort_size = sort_size
-		self.max_length_token = kwargs.get("max_length_token", 5000)
+		self.buffer_size = buffer_size
+		self.max_token_length = kwargs.get("max_token_length", 5000)
 		self.shuffle_idx = np.arange(self.total_samples)
 		self.shuffle = shuffle
 
 	
 	def __len__(self):
 		return self.total_samples
-
+	
+	def set_epoch(self, epoch):
+		np.random.seed(epoch)
+		
 	def __iter__(self):
-		# print("in sampler")
 		
 		if self.shuffle:
 			np.random.shuffle(self.shuffle_idx)
@@ -35,31 +42,31 @@
 		max_token = 0
 		num_sample = 0
 
-		iter_num = (self.total_samples-1) // self.sort_size + 1
+		iter_num = (self.total_samples-1) // self.buffer_size + 1
 		# print("iter_num: ", iter_num)
 		for iter in range(self.pre_idx + 1, iter_num):
 			datalen_with_index = []
-			for i in range(self.sort_size):
-				idx = iter * self.sort_size + i
+			for i in range(self.buffer_size):
+				idx = iter * self.buffer_size + i
 				if idx >= self.total_samples:
 					continue
 
 				idx_map = self.shuffle_idx[idx]
 				# prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-				sample_len_cur = self.dataset.indexed_dataset.get_source_len(self.dataset.indexed_dataset[idx_map]) + \
-				                 self.dataset.indexed_dataset.get_target_len(self.dataset.indexed_dataset[idx_map])
+				sample_len_cur = self.dataset.get_source_len(idx_map) + \
+				                 self.dataset.get_target_len(idx_map)
 
 				datalen_with_index.append([idx, sample_len_cur])
 			
 			datalen_with_index_sort = sorted(datalen_with_index, key=lambda x: x[1])
 			for item in datalen_with_index_sort:
 				idx, sample_len_cur_raw = item
-				if sample_len_cur_raw > self.max_length_token:
+				if sample_len_cur_raw > self.max_token_length:
 					continue
 
 				max_token_cur = max(max_token, sample_len_cur_raw)
 				max_token_padding = 1 + num_sample
-				if self.batch_type == 'token':
+				if self.batch_type == 'length':
 					max_token_padding *= max_token_cur
 				if max_token_padding <= self.batch_size:
 					batch.append(idx)
diff --git a/funasr/datasets/fun_datasets/dataloader_fn.py b/funasr/datasets/fun_datasets/dataloader_fn.py
deleted file mode 100644
index 601cbeb..0000000
--- a/funasr/datasets/fun_datasets/dataloader_fn.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import time
-import torch
-from funasr.datasets.fun_datasets.dataset_jsonl import AudioDataset
-from funasr.datasets.fun_datasets.data_sampler import BatchSampler
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.tokenizer.build_tokenizer import build_tokenizer
-from funasr.tokenizer.token_id_converter import TokenIDConverter
-collate_fn = None
-# collate_fn = collate_fn,
-
-jsonl = "/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl"
-
-frontend = WavFrontend()
-token_type = 'char'
-bpemodel = None
-delimiter = None
-space_symbol = "<space>"
-non_linguistic_symbols = None
-g2p_type = None
-
-tokenizer = build_tokenizer(
-    token_type=token_type,
-    bpemodel=bpemodel,
-    delimiter=delimiter,
-    space_symbol=space_symbol,
-    non_linguistic_symbols=non_linguistic_symbols,
-    g2p_type=g2p_type,
-)
-token_list = "/Users/zhifu/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt"
-unk_symbol = "<unk>"
-
-token_id_converter = TokenIDConverter(
-    token_list=token_list,
-    unk_symbol=unk_symbol,
-)
-
-dataset = AudioDataset(jsonl, frontend=frontend, tokenizer=tokenizer, token_id_converter=token_id_converter)
-batch_sampler = BatchSampler(dataset)
-
-
-if __name__ == "__main__":
-    
-    dataloader_tr = torch.utils.data.DataLoader(dataset,
-                                                collate_fn=dataset.collator,
-                                                batch_sampler=batch_sampler,
-                                                shuffle=False,
-                                                num_workers=0,
-                                                pin_memory=True)
-    
-    print(len(dataset))
-    for i in range(3):
-        print(i)
-        beg = time.time()
-        for j, data in enumerate(dataloader_tr):
-            end = time.time()
-            time_cost = end - beg
-            beg = end
-            print(j, time_cost)
-    # data_iter = iter(dataloader_tr)
-    # data = next(data_iter)
-    pass
-
-    
diff --git a/funasr/datasets/fun_datasets/dataset_jsonl.py b/funasr/datasets/fun_datasets/dataset_jsonl.py
deleted file mode 100644
index 21df89e..0000000
--- a/funasr/datasets/fun_datasets/dataset_jsonl.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import torch
-import json
-import torch.distributed as dist
-import numpy as np
-import kaldiio
-import librosa
-import torchaudio
-import time
-import logging
-
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
-	
-	
-
-class IndexedDatasetJsonl(torch.utils.data.Dataset):
-	
-	def __init__(self, path):
-		super().__init__()
-		
-		contents = []
-		with open(path, encoding='utf-8') as fin:
-			for line in fin:
-				data = json.loads(line.strip())
-				if "text" in data:  # for sft
-					self.contents.append(data['text'])
-				if "source" in data:  # for speech lab pretrain
-					prompt = data["prompt"]
-					source = data["source"]
-					target = data["target"]
-					source_len = data["source_len"]
-					target_len = data["target_len"]
-
-					contents.append({"source": source,
-					                 "prompt": prompt,
-					                 "target": target,
-					                 "source_len": source_len,
-					                 "target_len": target_len,
-					                 }
-					                )
-		
-		self.contents = []
-		total_num = len(contents)
-		try:
-			rank = dist.get_rank()
-			world_size = dist.get_world_size()
-		except:
-			rank = 0
-			world_size = 1
-			logging.warning("distributed is not initialized, only single shard")
-		num_per_rank = total_num // world_size
-		
-		# rank = 0
-		# import ipdb; ipdb.set_trace()
-		self.contents = contents[rank * num_per_rank:(rank + 1) * num_per_rank]
-	
-		logging.info("in rank: {}, num of samplers: {}, total_num of samplers across ranks: {}".format(rank, len(self.contents), len(contents)))
-
-	def __len__(self):
-		return len(self.contents)
-	
-	def __getitem__(self, index):
-		return self.contents[index]
-	
-	def get_source_len(self, data_dict):
-		return data_dict["source_len"]
-
-	def get_target_len(self, data_dict):
-		
-		return data_dict["target_len"] if "target_len" in data_dict else 0
-
-
-class AudioDataset(torch.utils.data.Dataset):
-	def __init__(self, path, frontend=None, tokenizer=None, int_pad_value: int = -1, float_pad_value: float = 0.0, **kwargs):
-		super().__init__()
-		self.indexed_dataset = IndexedDatasetJsonl(path)
-		self.frontend = frontend.forward
-		self.fs = 16000 if frontend is None else frontend.fs
-		self.data_type = "sound"
-		self.tokenizer = tokenizer
-
-		self.int_pad_value = int_pad_value
-		self.float_pad_value = float_pad_value
-
-	
-
-	
-	def __len__(self):
-		return len(self.indexed_dataset)
-	
-	def __getitem__(self, index):
-		item = self.indexed_dataset[index]
-
-		source = item["source"]
-		data_src = load_audio(source, fs=self.fs)
-		speech, speech_lengths = extract_fbank(data_src, self.data_type, self.frontend) # speech: [b, T, d]
-		target = item["target"]
-		ids = self.tokenizer.encode(target)
-		ids_lengths = len(ids)
-		text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
-
-		return {"speech": speech[0, :, :],
-		        "speech_lengths": speech_lengths,
-		        "text": text,
-		        "text_lengths": text_lengths,
-		        }
-	
-	
-	def collator(self, samples: list=None):
-		
-		# return samples
-		
-		outputs = {}
-		for sample in samples:
-			for key in sample.keys():
-				if key not in outputs:
-					outputs[key] = []
-				outputs[key].append(sample[key])
-
-		for key, data_list in outputs.items():
-			if data_list[0].dtype == torch.int64:
-
-				pad_value = self.int_pad_value
-			else:
-				pad_value = self.float_pad_value
-			outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
-		return outputs
-
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 24ab797..eeb5d0c 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -22,11 +22,17 @@
 	kwargs = OmegaConf.merge(cfg, kwargs)
 	init_param = os.path.join(model_or_path, "model.pb")
 	kwargs["init_param"] = init_param
-	kwargs["token_list"] = os.path.join(model_or_path, "tokens.txt")
+	if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+		kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+	if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+		kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+	if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+		kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+	
 	kwargs["model"] = cfg["model"]
 	kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
 	
-	return kwargs
+	return OmegaConf.to_container(kwargs, resolve=True)
 
 def get_or_download_model_dir(
                               model,
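
Note on the `download_from_hub.py` change above: returning `OmegaConf.to_container(kwargs, resolve=True)` hands callers a plain Python dict with `${...}` interpolations expanded, instead of a `DictConfig`. A small illustrative sketch (made-up paths; only the config keys visible in the hunk are taken from FunASR):

```python
# OmegaConf.merge() yields a DictConfig, while to_container(..., resolve=True)
# returns a plain dict with interpolations expanded. Paths below are made up.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "model_path": "/tmp/model_dir",
    "frontend_conf": {"cmvn_file": "${model_path}/am.mvn"},
    "tokenizer_conf": {},
})
overrides = OmegaConf.create({"tokenizer_conf": {"token_list": "/tmp/model_dir/tokens.txt"}})

merged = OmegaConf.merge(cfg, overrides)
print(type(merged))                          # <class 'omegaconf.dictconfig.DictConfig'>
plain = OmegaConf.to_container(merged, resolve=True)
print(type(plain))                           # <class 'dict'>
print(plain["frontend_conf"]["cmvn_file"])   # /tmp/model_dir/am.mvn
```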
diff --git a/funasr/download/runtime_sdk_download_tool.py b/funasr/download/runtime_sdk_download_tool.py
index 91c5844..92416f4 100644
--- a/funasr/download/runtime_sdk_download_tool.py
+++ b/funasr/download/runtime_sdk_download_tool.py
@@ -3,38 +3,43 @@
 import argparse
 from funasr.utils.types import str2bool
 
-parser = argparse.ArgumentParser()
-parser.add_argument('--model-name', type=str, required=True)
-parser.add_argument('--export-dir', type=str, required=True)
-parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
-parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
-parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
-parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
-parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
-parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
-parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
-parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
-args = parser.parse_args()
+def main():
+	parser = argparse.ArgumentParser()
+	parser.add_argument('--model-name', type=str, required=True)
+	parser.add_argument('--export-dir', type=str, required=True)
+	parser.add_argument('--export', type=str2bool, default=True, help='whether to export model')
+	parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
+	parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
+	parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
+	parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
+	parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
+	parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
+	parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
+	args = parser.parse_args()
+	
+	model_dir = args.model_name
+	if not Path(args.model_name).exists():
+		from modelscope.hub.snapshot_download import snapshot_download
+		try:
+			model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
+		except Exception as exc:
+			msg = "model_dir must be a model name in modelscope or a local path downloaded from modelscope, but got {}".format(model_dir)
+			raise ValueError(msg) from exc
+	if args.export:
+		model_file = os.path.join(model_dir, 'model.onnx')
+		if args.quantize:
+			model_file = os.path.join(model_dir, 'model_quant.onnx')
+		if not os.path.exists(model_file):
+			print(".onnx is not exist, begin to export onnx")
+			from funasr.bin.export_model import ModelExport
+			export_model = ModelExport(
+				cache_dir=args.export_dir,
+				onnx=True,
+				device="cpu",
+				quant=args.quantize,
+			)
+			export_model.export(model_dir)
 
-model_dir = args.model_name
-if not Path(args.model_name).exists():
-	from modelscope.hub.snapshot_download import snapshot_download
-	try:
-		model_dir = snapshot_download(args.model_name, cache_dir=args.export_dir, revision=args.model_revision)
-	except:
-		raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format \
-			(model_dir)
-if args.export:
-	model_file = os.path.join(model_dir, 'model.onnx')
-	if args.quantize:
-		model_file = os.path.join(model_dir, 'model_quant.onnx')
-	if not os.path.exists(model_file):
-		print(".onnx is not exist, begin to export onnx")
-		from funasr.bin.export_model import ModelExport
-		export_model = ModelExport(
-			cache_dir=args.export_dir,
-			onnx=True,
-			device="cpu",
-			quant=args.quantize,
-		)
-		export_model.export(model_dir)
\ No newline at end of file
+
+if __name__ == "__main__":
+	main()
\ No newline at end of file
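
Note on the `runtime_sdk_download_tool.py` change above: argument parsing now lives inside `main()` behind an `if __name__ == "__main__"` guard, so importing the module no longer parses `sys.argv`; parsing only happens when the script is executed directly, e.g. `python -m funasr.download.runtime_sdk_download_tool --model-name <name> --export-dir <dir>` (assumed invocation). The boolean flags (`--export`, `--quantize`) go through `str2bool` from `funasr.utils.types`; a hypothetical sketch of that kind of helper (the real implementation may differ) is:

```python
# Hypothetical str2bool-style argparse helper; shown only to clarify how
# flags like --quantize true/false are parsed. Not the FunASR implementation.
import argparse

def str2bool(value):
    if isinstance(value, bool):
        return value
    if value.lower() in ("yes", "true", "t", "1"):
        return True
    if value.lower() in ("no", "false", "f", "0"):
        return False
    raise argparse.ArgumentTypeError("boolean value expected, got %r" % value)

parser = argparse.ArgumentParser()
parser.add_argument("--quantize", type=str2bool, default=False)
print(parser.parse_args(["--quantize", "true"]).quantize)  # True
```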
diff --git a/funasr/export/README.md b/funasr/export/README.md
deleted file mode 100644
index bbb8bf8..0000000
--- a/funasr/export/README.md
+++ /dev/null
@@ -1,93 +0,0 @@
-# Export models
-
-## Environments
-### Install modelscope and funasr
-
-The installation is the same as [funasr](https://github.com/alibaba-damo-academy/FunASR/blob/main/README.md#installation)
-```shell
-# pip3 install torch torchaudio
-pip install -U modelscope funasr
-# For the users in China, you could install with the command:
-# pip install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
-```
-### Install the quantization tools
-```shell
-pip install torch-quant # Optional, for torchscript quantization
-pip install onnx onnxruntime # Optional, for onnx quantization
-```
-
-## Usage
-   `Tips`: torch>=1.11.0
-
-   ```shell
-   python -m funasr.export.export_model \
-       --model-name [model_name] \
-       --export-dir [export_dir] \
-       --type [onnx, torch] \
-       --quantize [true, false] \
-       --fallback-num [fallback_num]
-   ```
-   `model-name`: the model is to export. It could be the models from modelscope, or local finetuned model(named: model.pb).
-
-   `export-dir`: the dir where the onnx is export.
-
-   `type`: `onnx` or `torch`, export onnx format model or torchscript format model.
-
-   `quantize`: `true`, export quantized model at the same time; `false`, export fp32 model only.
-
-   `fallback-num`: specify the number of fallback layers to perform automatic mixed precision quantization.
-
-
-### Export onnx format model
-#### Export model from modelscope
-```shell
-python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize false
-```
-#### Export model from local path
-The model'name must be `model.pb`
-```shell
-python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize false
-```
-#### Test onnx model
-Ref to [test](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export/test)
-
-### Export torchscripts format model
-#### Export model from modelscope
-```shell
-python -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch --quantize false
-```
-
-#### Export model from local path
-The model'name must be `model.pb`
-```shell
-python -m funasr.export.export_model --model-name /mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type torch --quantize false
-```
-#### Test onnx model
-Ref to [test](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export/test)
-
-## Runtime
-### ONNXRuntime
-#### ONNXRuntime-python
-Ref to [funasr-onnx](../../runtime/python/onnxruntime/README.md)
-#### ONNXRuntime-cpp
-Ref to [docs](../../runtime/readme.md)
-### Libtorch
-#### Libtorch-python
-Ref to [funasr-torch](../../runtime/python/libtorch/README.md)
-#### Libtorch-cpp
-Undo
-## Performance Benchmark
-
-### Paraformer on CPU
-
-[onnx runtime](../../runtime/docs/benchmark_onnx_cpp.md)
-
-[libtorch runtime](../../runtime/docs/benchmark_libtorch.md)
-
-### Paraformer on GPU
-[nv-triton](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/triton_gpu)
-
-
-## Acknowledge
-Torch model quantization is supported by [BladeDISC](https://github.com/alibaba/BladeDISC), an end-to-end DynamIc Shape Compiler project for machine learning workloads. BladeDISC provides general, transparent, and ease of use performance optimization for TensorFlow/PyTorch workloads on GPGPU and CPU backends. If you are interested, please contact us.
-
diff --git a/funasr/export/models/CT_Transformer.py b/funasr/export/models/CT_Transformer.py
deleted file mode 100644
index 2319c4a..0000000
--- a/funasr/export/models/CT_Transformer.py
+++ /dev/null
@@ -1,162 +0,0 @@
-from typing import Tuple
-
-import torch
-import torch.nn as nn
-
-from funasr.models.encoder.sanm_encoder import SANMEncoder
-from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
-from funasr.models.encoder.sanm_encoder import SANMVadEncoder
-from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
-
-class CT_Transformer(nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
-    https://arxiv.org/pdf/2003.01309.pdf
-    """
-    def __init__(
-            self,
-            model,
-            max_seq_len=512,
-            model_name='punc_model',
-            **kwargs,
-    ):
-        super().__init__()
-        onnx = False
-        if "onnx" in kwargs:
-            onnx = kwargs["onnx"]
-        self.embed = model.embed
-        self.decoder = model.decoder
-        # self.model = model
-        self.feats_dim = self.embed.embedding_dim
-        self.num_embeddings = self.embed.num_embeddings
-        self.model_name = model_name
-
-        if isinstance(model.encoder, SANMEncoder):
-            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
-        else:
-            assert False, "Only support samn encode."
-
-    def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
-        """Compute loss value from buffer sequences.
-
-        Args:
-            input (torch.Tensor): Input ids. (batch, len)
-            hidden (torch.Tensor): Target ids. (batch, len)
-
-        """
-        x = self.embed(inputs)
-        # mask = self._target_mask(input)
-        h, _ = self.encoder(x, text_lengths)
-        y = self.decoder(h)
-        return y
-
-    def get_dummy_inputs(self):
-        length = 120
-        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
-        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
-        return (text_indexes, text_lengths)
-
-    def get_input_names(self):
-        return ['inputs', 'text_lengths']
-
-    def get_output_names(self):
-        return ['logits']
-
-    def get_dynamic_axes(self):
-        return {
-            'inputs': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'text_lengths': {
-                0: 'batch_size',
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-        }
-
-
-class CT_Transformer_VadRealtime(nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
-    https://arxiv.org/pdf/2003.01309.pdf
-    """
-    def __init__(
-        self,
-        model,
-        max_seq_len=512,
-        model_name='punc_model',
-        **kwargs,
-    ):
-        super().__init__()
-        onnx = False
-        if "onnx" in kwargs:
-            onnx = kwargs["onnx"]
-
-        self.embed = model.embed
-        if isinstance(model.encoder, SANMVadEncoder):
-            self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
-        else:
-            assert False, "Only support samn encode."
-        self.decoder = model.decoder
-        self.model_name = model_name
-
-
-
-    def forward(self, inputs: torch.Tensor,
-                text_lengths: torch.Tensor,
-                vad_indexes: torch.Tensor,
-                sub_masks: torch.Tensor,
-                ) -> Tuple[torch.Tensor, None]:
-        """Compute loss value from buffer sequences.
-
-        Args:
-            input (torch.Tensor): Input ids. (batch, len)
-            hidden (torch.Tensor): Target ids. (batch, len)
-
-        """
-        x = self.embed(inputs)
-        # mask = self._target_mask(input)
-        h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
-        y = self.decoder(h)
-        return y
-
-    def with_vad(self):
-        return True
-
-    def get_dummy_inputs(self):
-        length = 120
-        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
-        text_lengths = torch.tensor([length], dtype=torch.int32)
-        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
-        sub_masks = torch.ones(length, length, dtype=torch.float32)
-        sub_masks = torch.tril(sub_masks).type(torch.float32)
-        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
-
-    def get_input_names(self):
-        return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
-
-    def get_output_names(self):
-        return ['logits']
-
-    def get_dynamic_axes(self):
-        return {
-            'inputs': {
-                1: 'feats_length'
-            },
-            'vad_masks': {
-                2: 'feats_length1',
-                3: 'feats_length2'
-            },
-            'sub_masks': {
-                2: 'feats_length1',
-                3: 'feats_length2'
-            },
-            'logits': {
-                1: 'logits_length'
-            },
-        }
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
deleted file mode 100644
index b7b0889..0000000
--- a/funasr/export/models/__init__.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
-from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
-from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
-# from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
-
-from funasr.models.e2e_vad import E2EVadModel
-from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
-from funasr.models.target_delay_transformer import TargetDelayTransformer
-from funasr.export.models.CT_Transformer import CT_Transformer as CT_Transformer_export
-from funasr.train.abs_model import PunctuationModel
-from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
-from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
-from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_encoder_predictor as ParaformerOnline_encoder_predictor_export
-from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_decoder as ParaformerOnline_decoder_export
-from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_backbone as ContextualParaformer_backbone_export
-from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_embedder as ContextualParaformer_embedder_export
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-
-
-def get_model(model, export_config=None):
-    if isinstance(model, NeatContextualParaformer):
-        backbone = ContextualParaformer_backbone_export(model, **export_config)
-        embedder = ContextualParaformer_embedder_export(model, **export_config)
-        return [embedder, backbone]
-    elif isinstance(model, BiCifParaformer):
-        return BiCifParaformer_export(model, **export_config)
-    elif isinstance(model, ParaformerOnline):
-        encoder = ParaformerOnline_encoder_predictor_export(model, model_name="model")
-        decoder = ParaformerOnline_decoder_export(model, model_name="decoder")
-        return [encoder, decoder]
-    elif isinstance(model, Paraformer):
-        return Paraformer_export(model, **export_config)
-    # elif isinstance(model, Conformer_export):
-    #     return Conformer_export(model, **export_config)
-    elif isinstance(model, E2EVadModel):
-        return E2EVadModel_export(model, **export_config)
-    elif isinstance(model, PunctuationModel):
-        if isinstance(model.punc_model, TargetDelayTransformer):
-            return CT_Transformer_export(model.punc_model, **export_config)
-        elif isinstance(model.punc_model, VadRealtimeTransformer):
-            return CT_Transformer_VadRealtime_export(model.punc_model, **export_config)
-    else:
-        raise "Funasr does not support the given model type currently."
diff --git a/funasr/export/models/decoder/__init__.py b/funasr/export/models/decoder/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/models/decoder/__init__.py
+++ /dev/null
diff --git a/funasr/export/models/decoder/contextual_decoder.py b/funasr/export/models/decoder/contextual_decoder.py
deleted file mode 100644
index c6f83d0..0000000
--- a/funasr/export/models/decoder/contextual_decoder.py
+++ /dev/null
@@ -1,191 +0,0 @@
-import os
-import torch
-import torch.nn as nn
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
-from funasr.models.transformer.attention import MultiHeadedAttentionCrossAtt
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
-from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
-
-
-class ContextualSANMDecoder(nn.Module):
-    def __init__(self, model,
-                 max_seq_len=512,
-                 model_name='decoder',
-                 onnx: bool = True,):
-        super().__init__()
-        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
-        self.model = model
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
-        for i, d in enumerate(self.model.decoders):
-            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
-                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
-            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
-                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
-            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
-                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
-            self.model.decoders[i] = DecoderLayerSANM_export(d)
-
-        if self.model.decoders2 is not None:
-            for i, d in enumerate(self.model.decoders2):
-                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
-                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
-                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
-                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
-                self.model.decoders2[i] = DecoderLayerSANM_export(d)
-
-        for i, d in enumerate(self.model.decoders3):
-            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
-                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
-            self.model.decoders3[i] = DecoderLayerSANM_export(d)
-        
-        self.output_layer = model.output_layer
-        self.after_norm = model.after_norm
-        self.model_name = model_name
-
-        # bias decoder
-        if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
-            self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
-        self.bias_decoder = self.model.bias_decoder
-        # last decoder
-        if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
-            self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
-        if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
-            self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
-        if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
-            self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
-        self.last_decoder = self.model.last_decoder
-        self.bias_output = self.model.bias_output
-        self.dropout = self.model.dropout
-        
-
-    def prepare_mask(self, mask):
-        mask_3d_btd = mask[:, :, None]
-        if len(mask.shape) == 2:
-            mask_4d_bhlt = 1 - mask[:, None, None, :]
-        elif len(mask.shape) == 3:
-            mask_4d_bhlt = 1 - mask[:, None, :]
-        mask_4d_bhlt = mask_4d_bhlt * -10000.0
-    
-        return mask_3d_btd, mask_4d_bhlt
-
-    def forward(
-        self,
-        hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-        bias_embed: torch.Tensor,
-    ):
-
-        tgt = ys_in_pad
-        tgt_mask = self.make_pad_mask(ys_in_lens)
-        tgt_mask, _ = self.prepare_mask(tgt_mask)
-        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
-        memory = hs_pad
-        memory_mask = self.make_pad_mask(hlens)
-        _, memory_mask = self.prepare_mask(memory_mask)
-        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-
-        x = tgt
-        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
-
-        _, _, x_self_attn, x_src_attn = self.last_decoder(
-            x, tgt_mask, memory, memory_mask
-        )
-
-        # contextual paraformer related
-        contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
-        # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
-        contextual_mask = self.make_pad_mask(contextual_length)
-        contextual_mask, _ = self.prepare_mask(contextual_mask)
-        # import pdb; pdb.set_trace()
-        contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
-        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)
-
-        if self.bias_output is not None:
-            x = torch.cat([x_src_attn, cx], dim=2)
-            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
-            x = x_self_attn + self.dropout(x)
-
-        if self.model.decoders2 is not None:
-            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
-                x, tgt_mask, memory, memory_mask
-            )
-        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
-            x, tgt_mask, memory, memory_mask
-        )
-        x = self.after_norm(x)
-        x = self.output_layer(x)
-
-        return x, ys_in_lens
-
-
-    def get_dummy_inputs(self, enc_size):
-        tgt = torch.LongTensor([0]).unsqueeze(0)
-        memory = torch.randn(1, 100, enc_size)
-        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        cache = [
-            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
-            for _ in range(cache_num)
-        ]
-        return (tgt, memory, pre_acoustic_embeds, cache)
-
-    def is_optimizable(self):
-        return True
-
-    def get_input_names(self):
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
-               + ['cache_%d' % i for i in range(cache_num)]
-
-    def get_output_names(self):
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        return ['y'] \
-               + ['out_cache_%d' % i for i in range(cache_num)]
-
-    def get_dynamic_axes(self):
-        ret = {
-            'tgt': {
-                0: 'tgt_batch',
-                1: 'tgt_length'
-            },
-            'memory': {
-                0: 'memory_batch',
-                1: 'memory_length'
-            },
-            'pre_acoustic_embeds': {
-                0: 'acoustic_embeds_batch',
-                1: 'acoustic_embeds_length',
-            }
-        }
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        ret.update({
-            'cache_%d' % d: {
-                0: 'cache_%d_batch' % d,
-                2: 'cache_%d_length' % d
-            }
-            for d in range(cache_num)
-        })
-        return ret
-
-    def get_model_config(self, path):
-        return {
-            "dec_type": "XformerDecoder",
-            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
-            "n_layers": len(self.model.decoders) + len(self.model.decoders2),
-            "odim": self.model.decoders[0].size
-        }
diff --git a/funasr/export/models/decoder/sanm_decoder.py b/funasr/export/models/decoder/sanm_decoder.py
deleted file mode 100644
index 8f6e553..0000000
--- a/funasr/export/models/decoder/sanm_decoder.py
+++ /dev/null
@@ -1,314 +0,0 @@
-import os
-
-import torch
-import torch.nn as nn
-
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
-from funasr.models.transformer.attention import MultiHeadedAttentionCrossAtt
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
-from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
-
-
-class ParaformerSANMDecoder(nn.Module):
-    def __init__(self, model,
-                 max_seq_len=512,
-                 model_name='decoder',
-                 onnx: bool = True,):
-        super().__init__()
-        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
-        self.model = model
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
-        for i, d in enumerate(self.model.decoders):
-            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
-                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
-            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
-                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
-            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
-                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
-            self.model.decoders[i] = DecoderLayerSANM_export(d)
-
-        if self.model.decoders2 is not None:
-            for i, d in enumerate(self.model.decoders2):
-                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
-                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
-                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
-                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
-                self.model.decoders2[i] = DecoderLayerSANM_export(d)
-
-        for i, d in enumerate(self.model.decoders3):
-            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
-                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
-            self.model.decoders3[i] = DecoderLayerSANM_export(d)
-        
-        self.output_layer = model.output_layer
-        self.after_norm = model.after_norm
-        self.model_name = model_name
-        
-
-    def prepare_mask(self, mask):
-        mask_3d_btd = mask[:, :, None]
-        if len(mask.shape) == 2:
-            mask_4d_bhlt = 1 - mask[:, None, None, :]
-        elif len(mask.shape) == 3:
-            mask_4d_bhlt = 1 - mask[:, None, :]
-        mask_4d_bhlt = mask_4d_bhlt * -10000.0
-    
-        return mask_3d_btd, mask_4d_bhlt
-
-    def forward(
-        self,
-        hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-    ):
-
-        tgt = ys_in_pad
-        tgt_mask = self.make_pad_mask(ys_in_lens)
-        tgt_mask, _ = self.prepare_mask(tgt_mask)
-        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
-        memory = hs_pad
-        memory_mask = self.make_pad_mask(hlens)
-        _, memory_mask = self.prepare_mask(memory_mask)
-        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-
-        x = tgt
-        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
-        if self.model.decoders2 is not None:
-            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
-                x, tgt_mask, memory, memory_mask
-            )
-        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
-            x, tgt_mask, memory, memory_mask
-        )
-        x = self.after_norm(x)
-        x = self.output_layer(x)
-
-        return x, ys_in_lens
-
-
-    def get_dummy_inputs(self, enc_size):
-        tgt = torch.LongTensor([0]).unsqueeze(0)
-        memory = torch.randn(1, 100, enc_size)
-        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        cache = [
-            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
-            for _ in range(cache_num)
-        ]
-        return (tgt, memory, pre_acoustic_embeds, cache)
-
-    def is_optimizable(self):
-        return True
-
-    def get_input_names(self):
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
-               + ['cache_%d' % i for i in range(cache_num)]
-
-    def get_output_names(self):
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        return ['y'] \
-               + ['out_cache_%d' % i for i in range(cache_num)]
-
-    def get_dynamic_axes(self):
-        ret = {
-            'tgt': {
-                0: 'tgt_batch',
-                1: 'tgt_length'
-            },
-            'memory': {
-                0: 'memory_batch',
-                1: 'memory_length'
-            },
-            'pre_acoustic_embeds': {
-                0: 'acoustic_embeds_batch',
-                1: 'acoustic_embeds_length',
-            }
-        }
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        ret.update({
-            'cache_%d' % d: {
-                0: 'cache_%d_batch' % d,
-                2: 'cache_%d_length' % d
-            }
-            for d in range(cache_num)
-        })
-        return ret
-
-    def get_model_config(self, path):
-        return {
-            "dec_type": "XformerDecoder",
-            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
-            "n_layers": len(self.model.decoders) + len(self.model.decoders2),
-            "odim": self.model.decoders[0].size
-        }
-
-
-class ParaformerSANMDecoderOnline(nn.Module):
-    def __init__(self, model,
-                 max_seq_len=512,
-                 model_name='decoder',
-                 onnx: bool = True, ):
-        super().__init__()
-        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
-        self.model = model
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-        
-        for i, d in enumerate(self.model.decoders):
-            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
-                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
-            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
-                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
-            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
-                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
-            self.model.decoders[i] = DecoderLayerSANM_export(d)
-        
-        if self.model.decoders2 is not None:
-            for i, d in enumerate(self.model.decoders2):
-                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
-                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
-                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
-                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
-                self.model.decoders2[i] = DecoderLayerSANM_export(d)
-        
-        for i, d in enumerate(self.model.decoders3):
-            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
-                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
-            self.model.decoders3[i] = DecoderLayerSANM_export(d)
-        
-        self.output_layer = model.output_layer
-        self.after_norm = model.after_norm
-        self.model_name = model_name
-    
-    def prepare_mask(self, mask):
-        mask_3d_btd = mask[:, :, None]
-        if len(mask.shape) == 2:
-            mask_4d_bhlt = 1 - mask[:, None, None, :]
-        elif len(mask.shape) == 3:
-            mask_4d_bhlt = 1 - mask[:, None, :]
-        mask_4d_bhlt = mask_4d_bhlt * -10000.0
-        
-        return mask_3d_btd, mask_4d_bhlt
-    
-    def forward(
-        self,
-        hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-        *args,
-    ):
-        
-        tgt = ys_in_pad
-        tgt_mask = self.make_pad_mask(ys_in_lens)
-        tgt_mask, _ = self.prepare_mask(tgt_mask)
-        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-        
-        memory = hs_pad
-        memory_mask = self.make_pad_mask(hlens)
-        _, memory_mask = self.prepare_mask(memory_mask)
-        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-        
-        x = tgt
-        out_caches = list()
-        for i, decoder in enumerate(self.model.decoders):
-            in_cache = args[i]
-            x, tgt_mask, memory, memory_mask, out_cache = decoder(
-                x, tgt_mask, memory, memory_mask, cache=in_cache
-            )
-            out_caches.append(out_cache)
-        if self.model.decoders2 is not None:
-            for i, decoder in enumerate(self.model.decoders2):
-                in_cache = args[i+len(self.model.decoders)]
-                x, tgt_mask, memory, memory_mask, out_cache = decoder(
-                    x, tgt_mask, memory, memory_mask, cache=in_cache
-                )
-                out_caches.append(out_cache)
-        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
-            x, tgt_mask, memory, memory_mask
-        )
-        x = self.after_norm(x)
-        x = self.output_layer(x)
-        
-        return x, out_caches
-    
-    def get_dummy_inputs(self, enc_size):
-        enc = torch.randn(2, 100, enc_size).type(torch.float32)
-        enc_len = torch.tensor([30, 100], dtype=torch.int32)
-        acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
-        acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
-        cache_num = len(self.model.decoders)
-        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
-            cache_num += len(self.model.decoders2)
-        cache = [
-            torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size-1), dtype=torch.float32)
-            for _ in range(cache_num)
-        ]
-        return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
-
-    def get_input_names(self):
-        cache_num = len(self.model.decoders)
-        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
-            cache_num += len(self.model.decoders2)
-        return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
-               + ['in_cache_%d' % i for i in range(cache_num)]
-
-    def get_output_names(self):
-        cache_num = len(self.model.decoders)
-        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
-            cache_num += len(self.model.decoders2)
-        return ['logits', 'sample_ids'] \
-               + ['out_cache_%d' % i for i in range(cache_num)]
-
-    def get_dynamic_axes(self):
-        ret = {
-            'enc': {
-                0: 'batch_size',
-                1: 'enc_length'
-            },
-            'acoustic_embeds': {
-                0: 'batch_size',
-                1: 'token_length'
-            },
-            'enc_len': {
-                0: 'batch_size',
-            },
-            'acoustic_embeds_len': {
-                0: 'batch_size',
-            },
-        
-        }
-        cache_num = len(self.model.decoders)
-        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
-            cache_num += len(self.model.decoders2)
-        ret.update({
-            'in_cache_%d' % d: {
-                0: 'batch_size',
-            }
-            for d in range(cache_num)
-        })
-        ret.update({
-            'out_cache_%d' % d: {
-                0: 'batch_size',
-            }
-            for d in range(cache_num)
-        })
-        return ret
diff --git a/funasr/export/models/decoder/transformer_decoder.py b/funasr/export/models/decoder/transformer_decoder.py
deleted file mode 100644
index 727f5d7..0000000
--- a/funasr/export/models/decoder/transformer_decoder.py
+++ /dev/null
@@ -1,143 +0,0 @@
-import os
-from funasr.export import models
-
-import torch
-import torch.nn as nn
-
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
-from funasr.models.transformer.attention import MultiHeadedAttentionCrossAtt, MultiHeadedAttention
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
-from funasr.export.models.modules.multihead_att import OnnxMultiHeadedAttention
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
-from funasr.export.models.modules.decoder_layer import DecoderLayer as DecoderLayer_export
-
-
-class ParaformerDecoderSAN(nn.Module):
-    def __init__(self, model,
-                 max_seq_len=512,
-                 model_name='decoder',
-                 onnx: bool = True,):
-        super().__init__()
-        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
-        self.model = model
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
-        for i, d in enumerate(self.model.decoders):
-            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
-                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
-            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
-                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
-            # if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
-            #     d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
-            if isinstance(d.src_attn, MultiHeadedAttention):
-                d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
-            self.model.decoders[i] = DecoderLayer_export(d)
-        
-        self.output_layer = model.output_layer
-        self.after_norm = model.after_norm
-        self.model_name = model_name
-        
-
-    def prepare_mask(self, mask):
-        mask_3d_btd = mask[:, :, None]
-        if len(mask.shape) == 2:
-            mask_4d_bhlt = 1 - mask[:, None, None, :]
-        elif len(mask.shape) == 3:
-            mask_4d_bhlt = 1 - mask[:, None, :]
-        mask_4d_bhlt = mask_4d_bhlt * -10000.0
-    
-        return mask_3d_btd, mask_4d_bhlt
-
-    def forward(
-        self,
-        hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-    ):
-
-        tgt = ys_in_pad
-        tgt_mask = self.make_pad_mask(ys_in_lens)
-        tgt_mask, _ = self.prepare_mask(tgt_mask)
-        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
-        memory = hs_pad
-        memory_mask = self.make_pad_mask(hlens)
-        _, memory_mask = self.prepare_mask(memory_mask)
-        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-
-        x = tgt
-        x, tgt_mask, memory, memory_mask = self.model.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
-        x = self.after_norm(x)
-        x = self.output_layer(x)
-
-        return x, ys_in_lens
-
-
-    def get_dummy_inputs(self, enc_size):
-        tgt = torch.LongTensor([0]).unsqueeze(0)
-        memory = torch.randn(1, 100, enc_size)
-        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        cache = [
-            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
-            for _ in range(cache_num)
-        ]
-        return (tgt, memory, pre_acoustic_embeds, cache)
-
-    def is_optimizable(self):
-        return True
-
-    def get_input_names(self):
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
-               + ['cache_%d' % i for i in range(cache_num)]
-
-    def get_output_names(self):
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        return ['y'] \
-               + ['out_cache_%d' % i for i in range(cache_num)]
-
-    def get_dynamic_axes(self):
-        ret = {
-            'tgt': {
-                0: 'tgt_batch',
-                1: 'tgt_length'
-            },
-            'memory': {
-                0: 'memory_batch',
-                1: 'memory_length'
-            },
-            'pre_acoustic_embeds': {
-                0: 'acoustic_embeds_batch',
-                1: 'acoustic_embeds_length',
-            }
-        }
-        cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        ret.update({
-            'cache_%d' % d: {
-                0: 'cache_%d_batch' % d,
-                2: 'cache_%d_length' % d
-            }
-            for d in range(cache_num)
-        })
-        return ret
-
-    def get_model_config(self, path):
-        return {
-            "dec_type": "XformerDecoder",
-            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
-            "n_layers": len(self.model.decoders) + len(self.model.decoders2),
-            "odim": self.model.decoders[0].size
-        }
\ No newline at end of file
diff --git a/funasr/export/models/decoder/xformer_decoder.py b/funasr/export/models/decoder/xformer_decoder.py
deleted file mode 100644
index 9215c00..0000000
--- a/funasr/export/models/decoder/xformer_decoder.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import os
-
-import torch
-import torch.nn as nn
-
-from funasr.models.transformer.attention import MultiHeadedAttention
-
-from funasr.export.models.modules.decoder_layer import DecoderLayer as OnnxDecoderLayer
-from funasr.export.models.language_models.embed import Embedding
-from funasr.export.models.modules.multihead_att import \
-    OnnxMultiHeadedAttention
-
-from funasr.export.utils.torch_function import MakePadMask, subsequent_mask
-
-class XformerDecoder(nn.Module):
-    def __init__(self,
-                 model,
-                 max_seq_len = 512,
-                 model_name = 'decoder',
-                 onnx: bool = True,):
-        super().__init__()
-        self.embed = Embedding(model.embed, max_seq_len)
-        self.model = model
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = subsequent_mask(max_seq_len, flip=False)
-
-        if isinstance(self.model.decoders[0].self_attn, MultiHeadedAttention):
-            self.num_heads = self.model.decoders[0].self_attn.h
-            self.hidden_size = self.model.decoders[0].self_attn.linear_out.out_features
-
-        # replace multi-head attention module into customized module.
-        for i, d in enumerate(self.model.decoders):
-            # d is DecoderLayer
-            if isinstance(d.self_attn, MultiHeadedAttention):
-                d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
-            if isinstance(d.src_attn, MultiHeadedAttention):
-                d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
-            self.model.decoders[i] = OnnxDecoderLayer(d)
-
-        self.model_name = model_name
-
-    def prepare_mask(self, mask):
-        mask_3d_btd = mask[:, :, None]
-        if len(mask.shape) == 2:
-            mask_4d_bhlt = 1 - mask[:, None, None, :]
-        elif len(mask.shape) == 3:
-            mask_4d_bhlt = 1 - mask[:, None, :]
-
-        mask_4d_bhlt = mask_4d_bhlt * -10000.0
-        return mask_3d_btd, mask_4d_bhlt
-
-    def forward(self,
-                tgt,
-                memory,
-                cache):
-
-        mask = subsequent_mask(tgt.size(-1)).unsqueeze(0)  # (B, T)
-
-        x = self.embed(tgt)
-        mask = self.prepare_mask(mask)
-        new_cache = []
-        for c, decoder in zip(cache, self.model.decoders):
-            x, mask = decoder(x, mask, memory, None, c)
-            new_cache.append(x)
-            x = x[:, 1:, :]
-
-        if self.model.normalize_before:
-            y = self.model.after_norm(x[:, -1])
-        else:
-            y = x[:, -1]
-
-        if self.model.output_layer is not None:
-            y = torch.log_softmax(self.model.output_layer(y), dim=-1)
-        return y, new_cache
-
-    def get_dummy_inputs(self, enc_size):
-        tgt = torch.LongTensor([0]).unsqueeze(0)
-        memory = torch.randn(1, 100, enc_size)
-        cache_num = len(self.model.decoders)
-        cache = [
-            torch.zeros((1, 1, self.model.decoders[0].size))
-            for _ in range(cache_num)
-        ]
-        return (tgt, memory, cache)
-
-    def is_optimizable(self):
-        return True
-
-    def get_input_names(self):
-        cache_num = len(self.model.decoders)
-        return ["tgt", "memory"] + [
-            "cache_%d" % i for i in range(cache_num)
-        ]
-
-    def get_output_names(self):
-        cache_num = len(self.model.decoders)
-        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]
-
-    def get_dynamic_axes(self):
-        ret = {
-            "tgt": {0: "tgt_batch", 1: "tgt_length"},
-            "memory": {0: "memory_batch", 1: "memory_length"},
-        }
-        cache_num = len(self.model.decoders)
-        ret.update(
-            {
-                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
-                for d in range(cache_num)
-            }
-        )
-        return ret
-
-    def get_model_config(self, path):
-        return {
-            "dec_type": "XformerDecoder",
-            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
-            "n_layers": len(self.model.decoders),
-            "odim": self.model.decoders[0].size,
-        }
diff --git a/funasr/export/models/e2e_asr_contextual_paraformer.py b/funasr/export/models/e2e_asr_contextual_paraformer.py
deleted file mode 100644
index 0a3eba6..0000000
--- a/funasr/export/models/e2e_asr_contextual_paraformer.py
+++ /dev/null
@@ -1,174 +0,0 @@
-from audioop import bias
-import logging
-import torch
-import torch.nn as nn
-import numpy as np
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
-from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
-from funasr.models.predictor.cif import CifPredictorV2
-from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
-from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
-from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
-from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
-from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
-from funasr.export.models.decoder.contextual_decoder import ContextualSANMDecoder as ContextualSANMDecoder_export
-from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
-
-
-class ContextualParaformer_backbone(nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
-    https://arxiv.org/abs/2206.08317
-    """
-
-    def __init__(
-            self,
-            model,
-            max_seq_len=512,
-            feats_dim=560,
-            model_name='model',
-            **kwargs,
-    ):
-        super().__init__()
-        onnx = False
-        if "onnx" in kwargs:
-            onnx = kwargs["onnx"]
-        if isinstance(model.encoder, SANMEncoder):
-            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
-        elif isinstance(model.encoder, ConformerEncoder):
-            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
-        if isinstance(model.predictor, CifPredictorV2):
-            self.predictor = CifPredictorV2_export(model.predictor)
-        
-        # decoder
-        if isinstance(model.decoder, ContextualParaformerDecoder):
-            self.decoder = ContextualSANMDecoder_export(model.decoder, onnx=onnx)
-        elif isinstance(model.decoder, ParaformerSANMDecoder):
-            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
-        elif isinstance(model.decoder, ParaformerDecoderSAN):
-            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
-        
-        self.feats_dim = feats_dim
-        self.model_name = model_name
-
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-        
-    def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-            bias_embed: torch.Tensor,
-    ):
-        # a. To device
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
-        # batch = to_device(batch, device=self.device)
-    
-        enc, enc_len = self.encoder(**batch)
-        mask = self.make_pad_mask(enc_len)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
-        pre_token_length = pre_token_length.floor().type(torch.int32)
-
-        # bias_embed = bias_embed. squeeze(0).repeat([enc.shape[0], 1, 1])
-
-        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, bias_embed)
-        decoder_out = torch.log_softmax(decoder_out, dim=-1)
-        # sample_ids = decoder_out.argmax(dim=-1)
-        return decoder_out, pre_token_length
-
-    def get_dummy_inputs(self):
-        speech = torch.randn(2, 30, self.feats_dim)
-        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
-        bias_embed = torch.randn(2, 1, 512)
-        return (speech, speech_lengths, bias_embed)
-
-    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
-        import numpy as np
-        fbank = np.loadtxt(txt_file)
-        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
-        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
-        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
-        return (speech, speech_lengths)
-
-    def get_input_names(self):
-        return ['speech', 'speech_lengths', 'bias_embed']
-
-    def get_output_names(self):
-        return ['logits', 'token_num']
-
-    def get_dynamic_axes(self):
-        return {
-            'speech': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'speech_lengths': {
-                0: 'batch_size',
-            },
-            'bias_embed': {
-                0: 'batch_size',
-                1: 'num_hotwords'
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-        }
-
-
-class ContextualParaformer_embedder(nn.Module):
-    def __init__(self,
-                 model,
-                 max_seq_len=512,
-                 feats_dim=560,
-                 model_name='model',
-                 **kwargs,):
-        super().__init__()
-        self.embedding = model.bias_embed
-        model.bias_encoder.batch_first = False
-        self.bias_encoder = model.bias_encoder
-        # self.bias_encoder.batch_first = False
-        self.feats_dim = feats_dim
-        self.model_name = "{}_eb".format(model_name)
-    
-    def forward(self, hotword):
-        hotword = self.embedding(hotword).transpose(0, 1) # batch second
-        hw_embed, (_, _) = self.bias_encoder(hotword)
-        return hw_embed
-    
-    def get_dummy_inputs(self):
-        hotword = torch.tensor([
-                                [10, 11, 12, 13, 14, 10, 11, 12, 13, 14], 
-                                [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
-                                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                                [10, 11, 12, 13, 14, 10, 11, 12, 13, 14], 
-                                [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
-                                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-                               ], 
-                                dtype=torch.int32)
-        # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
-        return (hotword)
-
-    def get_input_names(self):
-        return ['hotword']
-
-    def get_output_names(self):
-        return ['hw_embed']
-
-    def get_dynamic_axes(self):
-        return {
-            'hotword': {
-                0: 'num_hotwords',
-            },
-            'hw_embed': {
-                0: 'num_hotwords',
-            },
-        }
\ No newline at end of file
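
Both wrappers above (ContextualParaformer_backbone and ContextualParaformer_embedder) expose the same export protocol: get_dummy_inputs(), get_input_names(), get_output_names() and get_dynamic_axes(). A minimal sketch of how such a wrapper could be traced to ONNX through that protocol; the export_onnx helper, output directory and opset choice are illustrative assumptions, not part of the deleted code:

import os
import torch

def export_onnx(wrapper: torch.nn.Module, out_dir: str, opset: int = 14) -> str:
    # Trace any wrapper that follows the get_dummy_inputs / get_*_names protocol.
    os.makedirs(out_dir, exist_ok=True)
    onnx_path = os.path.join(out_dir, f"{wrapper.model_name}.onnx")
    wrapper.eval()
    torch.onnx.export(
        wrapper,
        wrapper.get_dummy_inputs(),            # e.g. (speech, speech_lengths, bias_embed)
        onnx_path,
        input_names=wrapper.get_input_names(),
        output_names=wrapper.get_output_names(),
        dynamic_axes=wrapper.get_dynamic_axes(),
        opset_version=opset,
    )
    return onnx_path

# Usage (assuming `backbone` and `embedder` wrap a loaded contextual Paraformer):
# export_onnx(backbone, "export/contextual_paraformer")
# export_onnx(embedder, "export/contextual_paraformer")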
diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py
deleted file mode 100644
index 5697b77..0000000
--- a/funasr/export/models/e2e_asr_paraformer.py
+++ /dev/null
@@ -1,366 +0,0 @@
-import logging
-import torch
-import torch.nn as nn
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
-from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
-from funasr.models.predictor.cif import CifPredictorV2, CifPredictorV3
-from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
-from funasr.export.models.predictor.cif import CifPredictorV3 as CifPredictorV3_export
-from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
-from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
-from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
-from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
-from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoderOnline as ParaformerSANMDecoderOnline_export
-
-
-class Paraformer(nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
-    https://arxiv.org/abs/2206.08317
-    """
-
-    def __init__(
-            self,
-            model,
-            max_seq_len=512,
-            feats_dim=560,
-            model_name='model',
-            **kwargs,
-    ):
-        super().__init__()
-        onnx = False
-        if "onnx" in kwargs:
-            onnx = kwargs["onnx"]
-        if isinstance(model.encoder, SANMEncoder):
-            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
-        elif isinstance(model.encoder, ConformerEncoder):
-            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
-        if isinstance(model.predictor, CifPredictorV2):
-            self.predictor = CifPredictorV2_export(model.predictor)
-        if isinstance(model.decoder, ParaformerSANMDecoder):
-            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
-        elif isinstance(model.decoder, ParaformerDecoderSAN):
-            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
-        
-        self.feats_dim = feats_dim
-        self.model_name = model_name
-
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-        
-    def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-    ):
-        # a. To device
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
-        # batch = to_device(batch, device=self.device)
-    
-        enc, enc_len = self.encoder(**batch)
-        mask = self.make_pad_mask(enc_len)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
-        pre_token_length = pre_token_length.floor().type(torch.int32)
-
-        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
-        decoder_out = torch.log_softmax(decoder_out, dim=-1)
-        # sample_ids = decoder_out.argmax(dim=-1)
-
-        return decoder_out, pre_token_length
-
-    def get_dummy_inputs(self):
-        speech = torch.randn(2, 30, self.feats_dim)
-        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
-        return (speech, speech_lengths)
-
-    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
-        import numpy as np
-        fbank = np.loadtxt(txt_file)
-        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
-        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
-        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
-        return (speech, speech_lengths)
-
-    def get_input_names(self):
-        return ['speech', 'speech_lengths']
-
-    def get_output_names(self):
-        return ['logits', 'token_num']
-
-    def get_dynamic_axes(self):
-        return {
-            'speech': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'speech_lengths': {
-                0: 'batch_size',
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-        }
-
-
-class BiCifParaformer(nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
-    https://arxiv.org/abs/2206.08317
-    """
-
-    def __init__(
-            self,
-            model,
-            max_seq_len=512,
-            feats_dim=560,
-            model_name='model',
-            **kwargs,
-    ):
-        super().__init__()
-        onnx = False
-        if "onnx" in kwargs:
-            onnx = kwargs["onnx"]
-        if isinstance(model.encoder, SANMEncoder):
-            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
-        elif isinstance(model.encoder, ConformerEncoder):
-            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
-        else:
-            logging.warning("Unsupported encoder type to export.")
-        if isinstance(model.predictor, CifPredictorV3):
-            self.predictor = CifPredictorV3_export(model.predictor)
-        else:
-            logging.warning("Wrong predictor type to export.")
-        if isinstance(model.decoder, ParaformerSANMDecoder):
-            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
-        elif isinstance(model.decoder, ParaformerDecoderSAN):
-            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
-        else:
-            logging.warning("Unsupported decoder type to export.")
-        
-        self.feats_dim = feats_dim
-        self.model_name = model_name
-
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-        
-    def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-    ):
-        # a. To device
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
-        # batch = to_device(batch, device=self.device)
-    
-        enc, enc_len = self.encoder(**batch)
-        mask = self.make_pad_mask(enc_len)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
-        pre_token_length = pre_token_length.round().type(torch.int32)
-
-        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
-        decoder_out = torch.log_softmax(decoder_out, dim=-1)
-        
-        # get predicted timestamps
-        us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
-
-        return decoder_out, pre_token_length, us_alphas, us_cif_peak
-
-    def get_dummy_inputs(self):
-        speech = torch.randn(2, 30, self.feats_dim)
-        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
-        return (speech, speech_lengths)
-
-    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
-        import numpy as np
-        fbank = np.loadtxt(txt_file)
-        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
-        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
-        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
-        return (speech, speech_lengths)
-
-    def get_input_names(self):
-        return ['speech', 'speech_lengths']
-
-    def get_output_names(self):
-        return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
-
-    def get_dynamic_axes(self):
-        return {
-            'speech': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'speech_lengths': {
-                0: 'batch_size',
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-            'us_alphas': {
-                0: 'batch_size',
-                1: 'alphas_length'
-            },
-            'us_cif_peak': {
-                0: 'batch_size',
-                1: 'alphas_length'
-            },
-        }
-
-
-class ParaformerOnline_encoder_predictor(nn.Module):
-    """
-    Author: Speech Lab, Alibaba Group, China
-    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
-    https://arxiv.org/abs/2206.08317
-    """
-    
-    def __init__(
-        self,
-        model,
-        max_seq_len=512,
-        feats_dim=560,
-        model_name='model',
-        **kwargs,
-    ):
-        super().__init__()
-        onnx = False
-        if "onnx" in kwargs:
-            onnx = kwargs["onnx"]
-        if isinstance(model.encoder, SANMEncoder) or isinstance(model.encoder, SANMEncoderChunkOpt):
-            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
-        elif isinstance(model.encoder, ConformerEncoder):
-            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
-        if isinstance(model.predictor, CifPredictorV2):
-            self.predictor = CifPredictorV2_export(model.predictor)
-        
-        self.feats_dim = feats_dim
-        self.model_name = model_name
-        
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-    
-    def forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-    ):
-        # a. To device
-        batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
-        # batch = to_device(batch, device=self.device)
-        
-        enc, enc_len = self.encoder(**batch)
-        mask = self.make_pad_mask(enc_len)[:, None, :]
-        alphas, _ = self.predictor.forward_cnn(enc, mask)
-        
-        return enc, enc_len, alphas
-    
-    def get_dummy_inputs(self):
-        speech = torch.randn(2, 30, self.feats_dim)
-        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
-        return (speech, speech_lengths)
-    
-    def get_input_names(self):
-        return ['speech', 'speech_lengths']
-    
-    def get_output_names(self):
-        return ['enc', 'enc_len', 'alphas']
-    
-    def get_dynamic_axes(self):
-        return {
-            'speech': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'speech_lengths': {
-                0: 'batch_size',
-            },
-            'enc': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'enc_len': {
-                0: 'batch_size',
-            },
-            'alphas': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-        }
-
-
-class ParaformerOnline_decoder(nn.Module):
-    """
-    Author: Speech Lab, Alibaba Group, China
-    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
-    https://arxiv.org/abs/2206.08317
-    """
-    
-    def __init__(
-        self,
-        model,
-        max_seq_len=512,
-        feats_dim=560,
-        model_name='model',
-        **kwargs,
-    ):
-        super().__init__()
-        onnx = False
-        if "onnx" in kwargs:
-            onnx = kwargs["onnx"]
-
-        if isinstance(model.decoder, ParaformerDecoderSAN):
-            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
-        elif isinstance(model.decoder, ParaformerSANMDecoder):
-            self.decoder = ParaformerSANMDecoderOnline_export(model.decoder, onnx=onnx)
-        
-        self.feats_dim = feats_dim
-        self.model_name = model_name
-        self.enc_size = model.encoder._output_size
-        
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-    
-    def forward(
-        self,
-        enc: torch.Tensor,
-        enc_len: torch.Tensor,
-        acoustic_embeds: torch.Tensor,
-        acoustic_embeds_len: torch.Tensor,
-        *args,
-    ):
-        decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
-        sample_ids = decoder_out.argmax(dim=-1)
-        
-        return decoder_out, sample_ids, out_caches
-    
-    def get_dummy_inputs(self, ):
-        dummy_inputs = self.decoder.get_dummy_inputs(enc_size=self.enc_size)
-        return dummy_inputs
-
-    def get_input_names(self):
-        
-        return self.decoder.get_input_names()
-
-    def get_output_names(self):
-        
-        return self.decoder.get_output_names()
-
-    def get_dynamic_axes(self):
-        return self.decoder.get_dynamic_axes()
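
The Paraformer and BiCifParaformer wrappers return log-softmax logits plus the predicted token count per utterance; the commented-out argmax line above hints at the greedy decoding step. A small standalone sketch of that step (trimming by pre_token_length is an assumption about the caller, not part of the deleted code):

import torch

def greedy_decode(decoder_out: torch.Tensor, token_lengths: torch.Tensor):
    # decoder_out: (B, L, V) log-probabilities; token_lengths: (B,) predicted token counts.
    sample_ids = decoder_out.argmax(dim=-1)                        # (B, L)
    return [ids[:int(n)].tolist() for ids, n in zip(sample_ids, token_lengths)]

# e.g. with the forward() outputs of the Paraformer wrapper:
# decoder_out, pre_token_length = paraformer_wrapper(speech, speech_lengths)
# hyps = greedy_decode(decoder_out, pre_token_length)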
diff --git a/funasr/export/models/e2e_vad.py b/funasr/export/models/e2e_vad.py
deleted file mode 100644
index d3e8f30..0000000
--- a/funasr/export/models/e2e_vad.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from enum import Enum
-from typing import List, Tuple, Dict, Any
-
-import torch
-from torch import nn
-import math
-
-from funasr.models.encoder.fsmn_encoder import FSMN
-from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
-
-class E2EVadModel(nn.Module):
-    def __init__(self, model,
-                max_seq_len=512,
-                feats_dim=400,
-                model_name='model',
-                **kwargs,):
-        super(E2EVadModel, self).__init__()
-        self.feats_dim = feats_dim
-        self.max_seq_len = max_seq_len
-        self.model_name = model_name
-        if isinstance(model.encoder, FSMN):
-            self.encoder = FSMN_export(model.encoder)
-        else:
-            raise "unsupported encoder"
-        
-
-    def forward(self, feats: torch.Tensor, *args, ):
-
-        scores, out_caches = self.encoder(feats, *args)
-        return scores, out_caches
-
-    def get_dummy_inputs(self, frame=30):
-        speech = torch.randn(1, frame, self.feats_dim)
-        in_cache0 = torch.randn(1, 128, 19, 1)
-        in_cache1 = torch.randn(1, 128, 19, 1)
-        in_cache2 = torch.randn(1, 128, 19, 1)
-        in_cache3 = torch.randn(1, 128, 19, 1)
-        
-        return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
-
-    # def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
-    #     import numpy as np
-    #     fbank = np.loadtxt(txt_file)
-    #     fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
-    #     speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
-    #     speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
-    #     return (speech, speech_lengths)
-
-    def get_input_names(self):
-        return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
-
-    def get_output_names(self):
-        return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
-
-    def get_dynamic_axes(self):
-        return {
-            'speech': {
-                1: 'feats_length'
-            },
-        }
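
E2EVadModel streams frames through the FSMN encoder and carries four caches of shape (1, 128, 19, 1) between calls, as the dummy inputs above show. A minimal chunked-inference sketch under that assumption (the zero cache initialisation and 30-frame chunking are illustrative, not part of the deleted code):

import torch

def run_streaming_vad(vad: torch.nn.Module, feats: torch.Tensor, chunk: int = 30) -> torch.Tensor:
    # feats: (1, T, feats_dim); caches follow the (1, 128, 19, 1) dummy-input shape.
    caches = [torch.zeros(1, 128, 19, 1) for _ in range(4)]
    scores = []
    for start in range(0, feats.shape[1], chunk):
        piece = feats[:, start:start + chunk, :]
        logits, caches = vad(piece, *caches)   # forward() returns (scores, out_caches)
        scores.append(logits)
    return torch.cat(scores, dim=1)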
diff --git a/funasr/export/models/encoder/__init__.py b/funasr/export/models/encoder/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/models/encoder/__init__.py
+++ /dev/null
diff --git a/funasr/export/models/encoder/conformer_encoder.py b/funasr/export/models/encoder/conformer_encoder.py
deleted file mode 100644
index 76719c5..0000000
--- a/funasr/export/models/encoder/conformer_encoder.py
+++ /dev/null
@@ -1,105 +0,0 @@
-import torch
-import torch.nn as nn
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.transformer.attention import MultiHeadedAttentionSANM
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
-from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
-from funasr.export.models.modules.encoder_layer import EncoderLayerConformer as EncoderLayerConformer_export
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
-from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
-from funasr.export.models.encoder.sanm_encoder import SANMEncoder
-from funasr.models.transformer.attention import RelPositionMultiHeadedAttention
-# from funasr.export.models.modules.multihead_att import RelPositionMultiHeadedAttention as RelPositionMultiHeadedAttention_export
-from funasr.export.models.modules.multihead_att import OnnxRelPosMultiHeadedAttention as RelPositionMultiHeadedAttention_export
-
-
-class ConformerEncoder(nn.Module):
-    def __init__(
-        self,
-        model,
-        max_seq_len=512,
-        feats_dim=560,
-        model_name='encoder',
-        onnx: bool = True,
-    ):
-        super().__init__()
-        self.embed = model.embed
-        self.model = model
-        self.feats_dim = feats_dim
-        self._output_size = model._output_size
-
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
-        for i, d in enumerate(self.model.encoders):
-            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
-                d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
-            if isinstance(d.self_attn, RelPositionMultiHeadedAttention):
-                d.self_attn = RelPositionMultiHeadedAttention_export(d.self_attn)
-            if isinstance(d.feed_forward, PositionwiseFeedForward):
-                d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
-            self.model.encoders[i] = EncoderLayerConformer_export(d)
-        
-        self.model_name = model_name
-        self.num_heads = model.encoders[0].self_attn.h
-        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
-
-    
-    def prepare_mask(self, mask):
-        if len(mask.shape) == 2:
-            mask = 1 - mask[:, None, None, :]
-        elif len(mask.shape) == 3:
-            mask = 1 - mask[:, None, :]
-        
-        return mask * -10000.0
-
-    def forward(self,
-                speech: torch.Tensor,
-                speech_lengths: torch.Tensor,
-                ):
-        mask = self.make_pad_mask(speech_lengths)
-        mask = self.prepare_mask(mask)
-        if self.embed is None:
-            xs_pad = speech
-        else:
-            xs_pad = self.embed(speech)
-
-        encoder_outs = self.model.encoders(xs_pad, mask)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
-        if isinstance(xs_pad, tuple):
-            xs_pad = xs_pad[0]
-        xs_pad = self.model.after_norm(xs_pad)
-
-        return xs_pad, speech_lengths
-
-    def get_output_size(self):
-        return self.model.encoders[0].size
-
-    def get_dummy_inputs(self):
-        feats = torch.randn(1, 100, self.feats_dim)
-        return (feats)
-
-    def get_input_names(self):
-        return ['feats']
-
-    def get_output_names(self):
-        return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
-
-    def get_dynamic_axes(self):
-        return {
-            'feats': {
-                1: 'feats_length'
-            },
-            'encoder_out': {
-                1: 'enc_out_length'
-            },
-            'predictor_weight':{
-                1: 'pre_out_length'
-            }
-
-        }
diff --git a/funasr/export/models/encoder/fsmn_encoder.py b/funasr/export/models/encoder/fsmn_encoder.py
deleted file mode 100755
index b8e6433..0000000
--- a/funasr/export/models/encoder/fsmn_encoder.py
+++ /dev/null
@@ -1,296 +0,0 @@
-from typing import Tuple, Dict
-import copy
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from funasr.models.encoder.fsmn_encoder import BasicBlock
-
-class LinearTransform(nn.Module):
-
-    def __init__(self, input_dim, output_dim):
-        super(LinearTransform, self).__init__()
-        self.input_dim = input_dim
-        self.output_dim = output_dim
-        self.linear = nn.Linear(input_dim, output_dim, bias=False)
-
-    def forward(self, input):
-        output = self.linear(input)
-
-        return output
-
-
-class AffineTransform(nn.Module):
-
-    def __init__(self, input_dim, output_dim):
-        super(AffineTransform, self).__init__()
-        self.input_dim = input_dim
-        self.output_dim = output_dim
-        self.linear = nn.Linear(input_dim, output_dim)
-
-    def forward(self, input):
-        output = self.linear(input)
-
-        return output
-
-
-class RectifiedLinear(nn.Module):
-
-    def __init__(self, input_dim, output_dim):
-        super(RectifiedLinear, self).__init__()
-        self.dim = input_dim
-        self.relu = nn.ReLU()
-        self.dropout = nn.Dropout(0.1)
-
-    def forward(self, input):
-        out = self.relu(input)
-        return out
-
-
-class FSMNBlock(nn.Module):
-
-    def __init__(
-            self,
-            input_dim: int,
-            output_dim: int,
-            lorder=None,
-            rorder=None,
-            lstride=1,
-            rstride=1,
-    ):
-        super(FSMNBlock, self).__init__()
-
-        self.dim = input_dim
-
-        if lorder is None:
-            return
-
-        self.lorder = lorder
-        self.rorder = rorder
-        self.lstride = lstride
-        self.rstride = rstride
-
-        self.conv_left = nn.Conv2d(
-            self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
-
-        if self.rorder > 0:
-            self.conv_right = nn.Conv2d(
-                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
-        else:
-            self.conv_right = None
-
-    def forward(self, input: torch.Tensor, cache: torch.Tensor):
-        x = torch.unsqueeze(input, 1)
-        x_per = x.permute(0, 3, 2, 1)  # B D T C
-        
-        cache = cache.to(x_per.device)
-        y_left = torch.cat((cache, x_per), dim=2)
-        cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
-        y_left = self.conv_left(y_left)
-        out = x_per + y_left
-
-        if self.conv_right is not None:
-            # maybe need to check
-            y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
-            y_right = y_right[:, :, self.rstride:, :]
-            y_right = self.conv_right(y_right)
-            out += y_right
-
-        out_per = out.permute(0, 3, 2, 1)
-        output = out_per.squeeze(1)
-
-        return output, cache
-
-
-class BasicBlock_export(nn.Module):
-    def __init__(self,
-                 model,
-                 ):
-        super(BasicBlock_export, self).__init__()
-        self.linear = model.linear
-        self.fsmn_block = model.fsmn_block
-        self.affine = model.affine
-        self.relu = model.relu
-
-    def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
-        x = self.linear(input)  # B T D
-        # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
-        # if cache_layer_name not in in_cache:
-        #     in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
-        x, out_cache = self.fsmn_block(x, in_cache)
-        x = self.affine(x)
-        x = self.relu(x)
-        return x, out_cache
-
-
-# class FsmnStack(nn.Sequential):
-#     def __init__(self, *args):
-#         super(FsmnStack, self).__init__(*args)
-#
-#     def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
-#         x = input
-#         for module in self._modules.values():
-#             x = module(x, in_cache)
-#         return x
-
-
-'''
-FSMN net for keyword spotting
-input_dim:              input dimension
-linear_dim:             fsmn input dimension
-proj_dim:               fsmn projection dimension
-lorder:                 fsmn left order
-rorder:                 fsmn right order
-num_syn:                output dimension
-fsmn_layers:            no. of sequential fsmn layers
-'''
-
-
-class FSMN(nn.Module):
-    def __init__(
-            self, model,
-    ):
-        super(FSMN, self).__init__()
-        
-        # self.input_dim = input_dim
-        # self.input_affine_dim = input_affine_dim
-        # self.fsmn_layers = fsmn_layers
-        # self.linear_dim = linear_dim
-        # self.proj_dim = proj_dim
-        # self.output_affine_dim = output_affine_dim
-        # self.output_dim = output_dim
-        #
-        # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
-        # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
-        # self.relu = RectifiedLinear(linear_dim, linear_dim)
-        # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
-        #                         range(fsmn_layers)])
-        # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
-        # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
-        # self.softmax = nn.Softmax(dim=-1)
-        self.in_linear1 = model.in_linear1
-        self.in_linear2 = model.in_linear2
-        self.relu = model.relu
-        # self.fsmn = model.fsmn
-        self.out_linear1 = model.out_linear1
-        self.out_linear2 = model.out_linear2
-        self.softmax = model.softmax
-        self.fsmn = model.fsmn
-        for i, d in enumerate(model.fsmn):
-            if isinstance(d, BasicBlock):
-                self.fsmn[i] = BasicBlock_export(d)
-
-    def fuse_modules(self):
-        pass
-
-    def forward(
-            self,
-            input: torch.Tensor,
-            *args,
-    ):
-        """
-        Args:
-            input (torch.Tensor): Input tensor (B, T, D)
-            in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
-            {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
-        """
-
-        x = self.in_linear1(input)
-        x = self.in_linear2(x)
-        x = self.relu(x)
-        # x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
-        out_caches = list()
-        for i, d in enumerate(self.fsmn):
-            in_cache = args[i]
-            x, out_cache = d(x, in_cache)
-            out_caches.append(out_cache)
-        x = self.out_linear1(x)
-        x = self.out_linear2(x)
-        x = self.softmax(x)
-
-        return x, out_caches
-
-
-'''
-one deep fsmn layer
-dimproj:                projection dimension, input and output dimension of memory blocks
-dimlinear:              dimension of mapping layer
-lorder:                 left order
-rorder:                 right order
-lstride:                left stride
-rstride:                right stride
-'''
-
-
-class DFSMN(nn.Module):
-
-    def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
-        super(DFSMN, self).__init__()
-
-        self.lorder = lorder
-        self.rorder = rorder
-        self.lstride = lstride
-        self.rstride = rstride
-
-        self.expand = AffineTransform(dimproj, dimlinear)
-        self.shrink = LinearTransform(dimlinear, dimproj)
-
-        self.conv_left = nn.Conv2d(
-            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
-
-        if rorder > 0:
-            self.conv_right = nn.Conv2d(
-                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
-        else:
-            self.conv_right = None
-
-    def forward(self, input):
-        f1 = F.relu(self.expand(input))
-        p1 = self.shrink(f1)
-
-        x = torch.unsqueeze(p1, 1)
-        x_per = x.permute(0, 3, 2, 1)
-
-        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
-
-        if self.conv_right is not None:
-            y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
-            y_right = y_right[:, :, self.rstride:, :]
-            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
-        else:
-            out = x_per + self.conv_left(y_left)
-
-        out1 = out.permute(0, 3, 2, 1)
-        output = input + out1.squeeze(1)
-
-        return output
-
-
-'''
-build stacked dfsmn layers
-'''
-
-
-def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
-    repeats = [
-        nn.Sequential(
-            DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
-        for i in range(fsmn_layers)
-    ]
-
-    return nn.Sequential(*repeats)
-
-
-if __name__ == '__main__':
-    fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
-    print(fsmn)
-
-    num_params = sum(p.numel() for p in fsmn.parameters())
-    print('the number of model params: {}'.format(num_params))
-    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
-    y, _ = fsmn(x)  # batch-size * time * dim
-    print('input shape: {}'.format(x.shape))
-    print('output shape: {}'.format(y.shape))
-
-    print(fsmn.to_kaldi_net())
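
The FSMNBlock above keeps a per-layer cache of the last (lorder - 1) * lstride frames so the left convolution sees full history across chunk boundaries; with lorder = 20 and lstride = 1 that is exactly the (1, 128, 19, 1) cache shape used by the VAD dummy inputs. A standalone shape check of that update rule (not funasr code):

import torch

B, D, lorder, lstride = 1, 128, 20, 1
cache = torch.zeros(B, D, (lorder - 1) * lstride, 1)      # (1, 128, 19, 1), as in the dummies
chunk = torch.randn(B, D, 30, 1)                           # one 30-frame chunk, laid out B D T C

y_left = torch.cat((cache, chunk), dim=2)                  # history + current chunk
cache = y_left[:, :, -(lorder - 1) * lstride:, :]          # carried over to the next chunk
print(y_left.shape, cache.shape)                           # torch.Size([1, 128, 49, 1]) torch.Size([1, 128, 19, 1])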
diff --git a/funasr/export/models/encoder/sanm_encoder.py b/funasr/export/models/encoder/sanm_encoder.py
deleted file mode 100644
index 7ef863e..0000000
--- a/funasr/export/models/encoder/sanm_encoder.py
+++ /dev/null
@@ -1,218 +0,0 @@
-import torch
-import torch.nn as nn
-
-from funasr.export.utils.torch_function import MakePadMask
-from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.transformer.attention import MultiHeadedAttentionSANM
-from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANM as MultiHeadedAttentionSANM_export
-from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
-from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
-from funasr.models.transformer.embedding import StreamSinusoidalPositionEncoder
-
-
-class SANMEncoder(nn.Module):
-    def __init__(
-        self,
-        model,
-        max_seq_len=512,
-        feats_dim=560,
-        model_name='encoder',
-        onnx: bool = True,
-    ):
-        super().__init__()
-        self.embed = model.embed
-        if isinstance(self.embed, StreamSinusoidalPositionEncoder):
-            self.embed = None
-        self.model = model
-        self.feats_dim = feats_dim
-        self._output_size = model._output_size
-
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-
-        if hasattr(model, 'encoders0'):
-            for i, d in enumerate(self.model.encoders0):
-                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
-                    d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
-                if isinstance(d.feed_forward, PositionwiseFeedForward):
-                    d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
-                self.model.encoders0[i] = EncoderLayerSANM_export(d)
-
-        for i, d in enumerate(self.model.encoders):
-            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
-                d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
-            if isinstance(d.feed_forward, PositionwiseFeedForward):
-                d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
-            self.model.encoders[i] = EncoderLayerSANM_export(d)
-        
-        self.model_name = model_name
-        self.num_heads = model.encoders[0].self_attn.h
-        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
-
-    
-    def prepare_mask(self, mask):
-        mask_3d_btd = mask[:, :, None]
-        if len(mask.shape) == 2:
-            mask_4d_bhlt = 1 - mask[:, None, None, :]
-        elif len(mask.shape) == 3:
-            mask_4d_bhlt = 1 - mask[:, None, :]
-        mask_4d_bhlt = mask_4d_bhlt * -10000.0
-        
-        return mask_3d_btd, mask_4d_bhlt
-
-    def forward(self,
-                speech: torch.Tensor,
-                speech_lengths: torch.Tensor,
-                online: bool = False 
-                ):
-        if not online:
-            speech = speech * self._output_size ** 0.5
-        mask = self.make_pad_mask(speech_lengths)
-        mask = self.prepare_mask(mask)
-        if self.embed is None:
-            xs_pad = speech
-        else:
-            xs_pad = self.embed(speech)
-
-        encoder_outs = self.model.encoders0(xs_pad, mask)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
-        encoder_outs = self.model.encoders(xs_pad, mask)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
-        xs_pad = self.model.after_norm(xs_pad)
-
-        return xs_pad, speech_lengths
-
-    def get_output_size(self):
-        return self.model.encoders[0].size
-
-    def get_dummy_inputs(self):
-        feats = torch.randn(1, 100, self.feats_dim)
-        return (feats)
-
-    def get_input_names(self):
-        return ['feats']
-
-    def get_output_names(self):
-        return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
-
-    def get_dynamic_axes(self):
-        return {
-            'feats': {
-                1: 'feats_length'
-            },
-            'encoder_out': {
-                1: 'enc_out_length'
-            },
-            'predictor_weight':{
-                1: 'pre_out_length'
-            }
-
-        }
-
-
-class SANMVadEncoder(nn.Module):
-    def __init__(
-        self,
-        model,
-        max_seq_len=512,
-        feats_dim=560,
-        model_name='encoder',
-        onnx: bool = True,
-    ):
-        super().__init__()
-        self.embed = model.embed
-        self.model = model
-        self.feats_dim = feats_dim
-        self._output_size = model._output_size
-        
-        if onnx:
-            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
-        else:
-            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-        
-        if hasattr(model, 'encoders0'):
-            for i, d in enumerate(self.model.encoders0):
-                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
-                    d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
-                if isinstance(d.feed_forward, PositionwiseFeedForward):
-                    d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
-                self.model.encoders0[i] = EncoderLayerSANM_export(d)
-        
-        for i, d in enumerate(self.model.encoders):
-            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
-                d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
-            if isinstance(d.feed_forward, PositionwiseFeedForward):
-                d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
-            self.model.encoders[i] = EncoderLayerSANM_export(d)
-        
-        self.model_name = model_name
-        self.num_heads = model.encoders[0].self_attn.h
-        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
-    
-    def prepare_mask(self, mask, sub_masks):
-        mask_3d_btd = mask[:, :, None]
-        mask_4d_bhlt = (1 - sub_masks) * -10000.0
-        
-        return mask_3d_btd, mask_4d_bhlt
-    
-    def forward(self,
-                speech: torch.Tensor,
-                speech_lengths: torch.Tensor,
-                vad_masks: torch.Tensor,
-                sub_masks: torch.Tensor,
-                ):
-        speech = speech * self._output_size ** 0.5
-        mask = self.make_pad_mask(speech_lengths)
-        vad_masks = self.prepare_mask(mask, vad_masks)
-        mask = self.prepare_mask(mask, sub_masks)
-        
-        if self.embed is None:
-            xs_pad = speech
-        else:
-            xs_pad = self.embed(speech)
-        
-        encoder_outs = self.model.encoders0(xs_pad, mask)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        
-        # encoder_outs = self.model.encoders(xs_pad, mask)
-        for layer_idx, encoder_layer in enumerate(self.model.encoders):
-            if layer_idx == len(self.model.encoders) - 1:
-                mask = vad_masks
-            encoder_outs = encoder_layer(xs_pad, mask)
-            xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        
-        xs_pad = self.model.after_norm(xs_pad)
-        
-        return xs_pad, speech_lengths
-    
-    def get_output_size(self):
-        return self.model.encoders[0].size
-    
-    # def get_dummy_inputs(self):
-    #     feats = torch.randn(1, 100, self.feats_dim)
-    #     return (feats)
-    #
-    # def get_input_names(self):
-    #     return ['feats']
-    #
-    # def get_output_names(self):
-    #     return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
-    #
-    # def get_dynamic_axes(self):
-    #     return {
-    #         'feats': {
-    #             1: 'feats_length'
-    #         },
-    #         'encoder_out': {
-    #             1: 'enc_out_length'
-    #         },
-    #         'predictor_weight': {
-    #             1: 'pre_out_length'
-    #         }
-    #
-    #     }
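
prepare_mask() in the SANM encoder turns one (B, T) padding mask into two views: a (B, T, 1) multiplicative mask for the FSMN memory branch and a (B, 1, 1, T) additive mask (0 for valid frames, -10000 for padded ones) that is added to the attention scores. A small numeric re-derivation of that convention (standalone, not the deleted code itself):

import torch

lengths = torch.tensor([3, 5])
mask = (torch.arange(5)[None, :] < lengths[:, None]).float()    # (B, T) pad mask

mask_3d_btd = mask[:, :, None]                                   # (B, T, 1), multiplied into the FSMN memory
mask_4d_bhlt = (1 - mask[:, None, None, :]) * -10000.0           # (B, 1, 1, T), added to attention scores

print(mask_4d_bhlt[0, 0, 0])   # 0 for the 3 valid frames, -10000 for the 2 padded ones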
diff --git a/funasr/export/models/modules/__init__.py b/funasr/export/models/modules/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/models/modules/__init__.py
+++ /dev/null
diff --git a/funasr/export/models/modules/decoder_layer.py b/funasr/export/models/modules/decoder_layer.py
deleted file mode 100644
index 9a464a4..0000000
--- a/funasr/export/models/modules/decoder_layer.py
+++ /dev/null
@@ -1,71 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import torch
-from torch import nn
-
-
-class DecoderLayerSANM(nn.Module):
-
-    def __init__(
-        self,
-        model
-    ):
-        super().__init__()
-        self.self_attn = model.self_attn
-        self.src_attn = model.src_attn
-        self.feed_forward = model.feed_forward
-        self.norm1 = model.norm1
-        self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
-        self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
-        self.size = model.size
-
-
-    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
-
-        residual = tgt
-        tgt = self.norm1(tgt)
-        tgt = self.feed_forward(tgt)
-
-        x = tgt
-        if self.self_attn is not None:
-            tgt = self.norm2(tgt)
-            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
-            x = residual + x
-
-        if self.src_attn is not None:
-            residual = x
-            x = self.norm3(x)
-            x = residual + self.src_attn(x, memory, memory_mask)
-
-
-        return x, tgt_mask, memory, memory_mask, cache
-
-
-class DecoderLayer(nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.self_attn = model.self_attn
-        self.src_attn = model.src_attn
-        self.feed_forward = model.feed_forward
-        self.norm1 = model.norm1
-        self.norm2 = model.norm2
-        self.norm3 = model.norm3
-    
-    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
-        residual = tgt
-        tgt = self.norm1(tgt)
-        tgt_q = tgt
-        tgt_q_mask = tgt_mask
-        x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
-
-        residual = x
-        x = self.norm2(x)
-        
-        x = residual + self.src_attn(x, memory, memory, memory_mask)
-
-        residual = x
-        x = self.norm3(x)
-        x = residual + self.feed_forward(x)
-
-        return x, tgt_mask, memory, memory_mask
diff --git a/funasr/export/models/modules/encoder_layer.py b/funasr/export/models/modules/encoder_layer.py
deleted file mode 100644
index 7d01397..0000000
--- a/funasr/export/models/modules/encoder_layer.py
+++ /dev/null
@@ -1,91 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import torch
-from torch import nn
-
-
-class EncoderLayerSANM(nn.Module):
-    def __init__(
-        self,
-        model,
-    ):
-        """Construct an EncoderLayer object."""
-        super().__init__()
-        self.self_attn = model.self_attn
-        self.feed_forward = model.feed_forward
-        self.norm1 = model.norm1
-        self.norm2 = model.norm2
-        self.in_size = model.in_size
-        self.size = model.size
-
-    def forward(self, x, mask):
-
-        residual = x
-        x = self.norm1(x)
-        x = self.self_attn(x, mask)
-        if self.in_size == self.size:
-            x = x + residual
-        residual = x
-        x = self.norm2(x)
-        x = self.feed_forward(x)
-        x = x + residual
-
-        return x, mask
-
-
-class EncoderLayerConformer(nn.Module):
-    def __init__(
-        self,
-        model,
-    ):
-        """Construct an EncoderLayer object."""
-        super().__init__()
-        self.self_attn = model.self_attn
-        self.feed_forward = model.feed_forward
-        self.feed_forward_macaron = model.feed_forward_macaron
-        self.conv_module = model.conv_module
-        self.norm_ff = model.norm_ff
-        self.norm_mha = model.norm_mha
-        self.norm_ff_macaron = model.norm_ff_macaron
-        self.norm_conv = model.norm_conv
-        self.norm_final = model.norm_final
-        self.size = model.size
-
-    def forward(self, x, mask):
-        if isinstance(x, tuple):
-            x, pos_emb = x[0], x[1]
-        else:
-            x, pos_emb = x, None
-
-        if self.feed_forward_macaron is not None:
-            residual = x
-            x = self.norm_ff_macaron(x)
-            x = residual + self.feed_forward_macaron(x) * 0.5
-
-        residual = x
-        x = self.norm_mha(x)
-
-        x_q = x
-
-        if pos_emb is not None:
-            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
-        else:
-            x_att = self.self_attn(x_q, x, x, mask)
-        x = residual + x_att
-
-        if self.conv_module is not None:
-            residual = x
-            x = self.norm_conv(x)
-            x = residual +  self.conv_module(x)
-
-        residual = x
-        x = self.norm_ff(x)
-        x = residual + self.feed_forward(x) * 0.5
-
-        x = self.norm_final(x)
-
-        if pos_emb is not None:
-            return (x, pos_emb), mask
-
-        return x, mask
diff --git a/funasr/export/models/modules/feedforward.py b/funasr/export/models/modules/feedforward.py
deleted file mode 100644
index 9388ae1..0000000
--- a/funasr/export/models/modules/feedforward.py
+++ /dev/null
@@ -1,31 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-import torch
-import torch.nn as nn
-
-
-class PositionwiseFeedForward(nn.Module):
-	def __init__(self, model):
-		super().__init__()
-		self.w_1 = model.w_1
-		self.w_2 = model.w_2
-		self.activation = model.activation
-	
-	def forward(self, x):
-		x = self.activation(self.w_1(x))
-		x = self.w_2(x)
-		return x
-
-
-class PositionwiseFeedForwardDecoderSANM(nn.Module):
-	def __init__(self, model):
-		super().__init__()
-		self.w_1 = model.w_1
-		self.w_2 = model.w_2
-		self.activation = model.activation
-		self.norm = model.norm
-	
-	def forward(self, x):
-		x = self.activation(self.w_1(x))
-		x = self.w_2(self.norm(x))
-		return x
\ No newline at end of file
diff --git a/funasr/export/models/modules/multihead_att.py b/funasr/export/models/modules/multihead_att.py
deleted file mode 100644
index 4885c4e..0000000
--- a/funasr/export/models/modules/multihead_att.py
+++ /dev/null
@@ -1,243 +0,0 @@
-import os
-import math
-
-import torch
-import torch.nn as nn
-
-
-class MultiHeadedAttentionSANM(nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.d_k = model.d_k
-        self.h = model.h
-        self.linear_out = model.linear_out
-        self.linear_q_k_v = model.linear_q_k_v
-        self.fsmn_block = model.fsmn_block
-        self.pad_fn = model.pad_fn
-
-        self.attn = None
-        self.all_head_size = self.h * self.d_k
-
-    def forward(self, x, mask):
-        mask_3d_btd, mask_4d_bhlt = mask
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        fsmn_memory = self.forward_fsmn(v, mask_3d_btd)
-        q_h = q_h * self.d_k**(-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, mask_4d_bhlt)
-        return att_outs + fsmn_memory
-
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
-    def forward_qkv(self, x):
-        q_k_v = self.linear_q_k_v(x)
-        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = self.transpose_for_scores(q)
-        k_h = self.transpose_for_scores(k)
-        v_h = self.transpose_for_scores(v)
-        return q_h, k_h, v_h, v
-
-    def forward_fsmn(self, inputs, mask):
-        # b, t, d = inputs.size()
-        # mask = torch.reshape(mask, (b, -1, 1))
-        inputs = inputs * mask
-        x = inputs.transpose(1, 2)
-        x = self.pad_fn(x)
-        x = self.fsmn_block(x)
-        x = x.transpose(1, 2)
-        x = x + inputs
-        x = x * mask
-        return x
-
-    def forward_attention(self, value, scores, mask):
-        scores = scores + mask
-
-        self.attn = torch.softmax(scores, dim=-1)
-        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
-
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(new_context_layer_shape)
-        return self.linear_out(context_layer)  # (batch, time1, d_model)
-
-
-def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
-    x = x * mask
-    x = x.transpose(1, 2)
-    if cache is None:
-        x = pad_fn(x)
-    else:
-        x = torch.cat((cache, x), dim=2)
-        cache = x[:, :, -(kernel_size-1):]
-    return x, cache
-
-
-torch_version = tuple([int(i) for i in torch.__version__.split(".")[:2]])
-if torch_version >= (1, 8):
-    import torch.fx
-    torch.fx.wrap('preprocess_for_attn')
-
-
-class MultiHeadedAttentionSANMDecoder(nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.fsmn_block = model.fsmn_block
-        self.pad_fn = model.pad_fn
-        self.kernel_size = model.kernel_size
-        self.attn = None
-
-    def forward(self, inputs, mask, cache=None):
-        x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
-        x = self.fsmn_block(x)
-        x = x.transpose(1, 2)
-
-        x = x + inputs
-        x = x * mask
-        return x, cache
-
-
-class MultiHeadedAttentionCrossAtt(nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.d_k = model.d_k
-        self.h = model.h
-        self.linear_q = model.linear_q
-        self.linear_k_v = model.linear_k_v
-        self.linear_out = model.linear_out
-        self.attn = None
-        self.all_head_size = self.h * self.d_k
-
-    def forward(self, x, memory, memory_mask):
-        q, k, v = self.forward_qkv(x, memory)
-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, memory_mask)
-
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
-    def forward_qkv(self, x, memory):
-        q = self.linear_q(x)
-
-        k_v = self.linear_k_v(memory)
-        k, v = torch.split(k_v, int(self.h * self.d_k), dim=-1)
-        q = self.transpose_for_scores(q)
-        k = self.transpose_for_scores(k)
-        v = self.transpose_for_scores(v)
-        return q, k, v
-
-    def forward_attention(self, value, scores, mask):
-        scores = scores + mask
-
-        self.attn = torch.softmax(scores, dim=-1)
-        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
-
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(new_context_layer_shape)
-        return self.linear_out(context_layer)  # (batch, time1, d_model)
-
-
-class OnnxMultiHeadedAttention(nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.d_k = model.d_k
-        self.h = model.h
-        self.linear_q = model.linear_q
-        self.linear_k = model.linear_k
-        self.linear_v = model.linear_v
-        self.linear_out = model.linear_out
-        self.attn = None
-        self.all_head_size = self.h * self.d_k
-    
-    def forward(self, query, key, value, mask):
-        q, k, v = self.forward_qkv(query, key, value)
-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, mask)
-
-    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
-        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
-        x = x.view(new_x_shape)
-        return x.permute(0, 2, 1, 3)
-
-    def forward_qkv(self, query, key, value):
-        q = self.linear_q(query)
-        k = self.linear_k(key)
-        v = self.linear_v(value)
-        q = self.transpose_for_scores(q)
-        k = self.transpose_for_scores(k)
-        v = self.transpose_for_scores(v)
-        return q, k, v
-    
-    def forward_attention(self, value, scores, mask):
-        scores = scores + mask
-
-        self.attn = torch.softmax(scores, dim=-1)
-        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
-        
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(new_context_layer_shape)
-        return self.linear_out(context_layer)  # (batch, time1, d_model)
-
-
-class OnnxRelPosMultiHeadedAttention(OnnxMultiHeadedAttention):
-    def __init__(self, model):
-        super().__init__(model)
-        self.linear_pos = model.linear_pos
-        self.pos_bias_u = model.pos_bias_u
-        self.pos_bias_v = model.pos_bias_v
-    
-    def forward(self, query, key, value, pos_emb, mask):
-        q, k, v = self.forward_qkv(query, key, value)
-        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
-
-        p = self.transpose_for_scores(self.linear_pos(pos_emb)) # (batch, head, time1, d_k)
-
-        # (batch, head, time1, d_k)
-        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
-        # (batch, head, time1, d_k)
-        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
-
-        # compute attention score
-        # first compute matrix a and matrix c
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        # (batch, head, time1, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
-
-        # compute matrix b and matrix d
-        # (batch, head, time1, time1)
-        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
-        matrix_bd = self.rel_shift(matrix_bd)
-
-        scores = (matrix_ac + matrix_bd) / math.sqrt(
-            self.d_k
-        )  # (batch, head, time1, time2)
-
-        return self.forward_attention(v, scores, mask)
-
-    def rel_shift(self, x):
-        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
-        x_padded = torch.cat([zero_pad, x], dim=-1)
-
-        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
-        x = x_padded[:, :, 1:].view_as(x)[
-            :, :, :, : x.size(-1) // 2 + 1
-        ]  # only keep the positions from 0 to time2
-        return x
-
-    def forward_attention(self, value, scores, mask):
-        scores = scores + mask
-
-        self.attn = torch.softmax(scores, dim=-1)
-        context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)
-        
-        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
-        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-        context_layer = context_layer.view(new_context_layer_shape)
-        return self.linear_out(context_layer)  # (batch, time1, d_model)
-        
diff --git a/funasr/export/models/predictor/__init__.py b/funasr/export/models/predictor/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/models/predictor/__init__.py
+++ /dev/null
diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
deleted file mode 100644
index 03c4433..0000000
--- a/funasr/export/models/predictor/cif.py
+++ /dev/null
@@ -1,295 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-import torch
-from torch import nn
-
-
-def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
-	if maxlen is None:
-		maxlen = lengths.max()
-	row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
-	matrix = torch.unsqueeze(lengths, dim=-1)
-	mask = row_vector < matrix
-	mask = mask.detach()
-	
-	return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
-
-def sequence_mask_scripts(lengths, maxlen:int):
-	row_vector = torch.arange(0, maxlen, 1).type(lengths.dtype).to(lengths.device)
-	matrix = torch.unsqueeze(lengths, dim=-1)
-	mask = row_vector < matrix
-	return mask.type(torch.float32).to(lengths.device)
-
-class CifPredictorV2(nn.Module):
-	def __init__(self, model):
-		super().__init__()
-		
-		self.pad = model.pad
-		self.cif_conv1d = model.cif_conv1d
-		self.cif_output = model.cif_output
-		self.threshold = model.threshold
-		self.smooth_factor = model.smooth_factor
-		self.noise_threshold = model.noise_threshold
-		self.tail_threshold = model.tail_threshold
-	
-	def forward(self, hidden: torch.Tensor,
-	            mask: torch.Tensor,
-	            ):
-		alphas, token_num = self.forward_cnn(hidden, mask)
-		mask = mask.transpose(-1, -2).float()
-		mask = mask.squeeze(-1)
-		hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
-		acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-		
-		return acoustic_embeds, token_num, alphas, cif_peak
-	
-	def forward_cnn(self, hidden: torch.Tensor,
-	            mask: torch.Tensor,
-	            ):
-		h = hidden
-		context = h.transpose(1, 2)
-		queries = self.pad(context)
-		output = torch.relu(self.cif_conv1d(queries))
-		output = output.transpose(1, 2)
-		
-		output = self.cif_output(output)
-		alphas = torch.sigmoid(output)
-		alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
-		mask = mask.transpose(-1, -2).float()
-		alphas = alphas * mask
-		alphas = alphas.squeeze(-1)
-		token_num = alphas.sum(-1)
-
-		return alphas, token_num
-	
-	def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
-		b, t, d = hidden.size()
-		tail_threshold = self.tail_threshold
-		
-		zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
-		ones_t = torch.ones_like(zeros_t)
-
-		mask_1 = torch.cat([mask, zeros_t], dim=1)
-		mask_2 = torch.cat([ones_t, mask], dim=1)
-		mask = mask_2 - mask_1
-		tail_threshold = mask * tail_threshold
-		alphas = torch.cat([alphas, zeros_t], dim=1)
-		alphas = torch.add(alphas, tail_threshold)
-
-		zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
-		hidden = torch.cat([hidden, zeros], dim=1)
-		token_num = alphas.sum(dim=-1)
-		token_num_floor = torch.floor(token_num)
-		
-		return hidden, alphas, token_num_floor
-
-
-# @torch.jit.script
-# def cif(hidden, alphas, threshold: float):
-# 	batch_size, len_time, hidden_size = hidden.size()
-# 	threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
-#
-# 	# loop varss
-# 	integrate = torch.zeros([batch_size], device=hidden.device)
-# 	frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
-# 	# intermediate vars along time
-# 	list_fires = []
-# 	list_frames = []
-#
-# 	for t in range(len_time):
-# 		alpha = alphas[:, t]
-# 		distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
-#
-# 		integrate += alpha
-# 		list_fires.append(integrate)
-#
-# 		fire_place = integrate >= threshold
-# 		integrate = torch.where(fire_place,
-# 		                        integrate - torch.ones([batch_size], device=hidden.device),
-# 		                        integrate)
-# 		cur = torch.where(fire_place,
-# 		                  distribution_completion,
-# 		                  alpha)
-# 		remainds = alpha - cur
-#
-# 		frame += cur[:, None] * hidden[:, t, :]
-# 		list_frames.append(frame)
-# 		frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
-# 		                    remainds[:, None] * hidden[:, t, :],
-# 		                    frame)
-#
-# 	fires = torch.stack(list_fires, 1)
-# 	frames = torch.stack(list_frames, 1)
-# 	list_ls = []
-# 	len_labels = torch.floor(alphas.sum(-1)).int()
-# 	max_label_len = len_labels.max()
-# 	for b in range(batch_size):
-# 		fire = fires[b, :]
-# 		l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
-# 		pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device)
-# 		list_ls.append(torch.cat([l, pad_l], 0))
-# 	return torch.stack(list_ls, 0), fires
-
-
-@torch.jit.script
-def cif(hidden, alphas, threshold: float):
-	batch_size, len_time, hidden_size = hidden.size()
-	threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device)
-	
-	# loop varss
-	integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=hidden.device)
-	frame = torch.zeros([batch_size, hidden_size], dtype=hidden.dtype, device=hidden.device)
-	# intermediate vars along time
-	list_fires = []
-	list_frames = []
-	
-	for t in range(len_time):
-		alpha = alphas[:, t]
-		distribution_completion = torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device) - integrate
-		
-		integrate += alpha
-		list_fires.append(integrate)
-		
-		fire_place = integrate >= threshold
-		integrate = torch.where(fire_place,
-		                        integrate - torch.ones([batch_size], dtype=alphas.dtype, device=hidden.device),
-		                        integrate)
-		cur = torch.where(fire_place,
-		                  distribution_completion,
-		                  alpha)
-		remainds = alpha - cur
-		
-		frame += cur[:, None] * hidden[:, t, :]
-		list_frames.append(frame)
-		frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
-		                    remainds[:, None] * hidden[:, t, :],
-		                    frame)
-	
-	fires = torch.stack(list_fires, 1)
-	frames = torch.stack(list_frames, 1)
-
-	fire_idxs = fires >= threshold
-	frame_fires = torch.zeros_like(hidden)
-	max_label_len = frames[0, fire_idxs[0]].size(0)
-	for b in range(batch_size):
-		frame_fire = frames[b, fire_idxs[b]]
-		frame_len = frame_fire.size(0)
-		frame_fires[b, :frame_len, :] = frame_fire
-	
-		if frame_len >= max_label_len:
-			max_label_len = frame_len
-	frame_fires = frame_fires[:, :max_label_len, :]
-	return frame_fires, fires
-
-
-class CifPredictorV3(nn.Module):
-	def __init__(self, model):
-		super().__init__()
-		
-		self.pad = model.pad
-		self.cif_conv1d = model.cif_conv1d
-		self.cif_output = model.cif_output
-		self.threshold = model.threshold
-		self.smooth_factor = model.smooth_factor
-		self.noise_threshold = model.noise_threshold
-		self.tail_threshold = model.tail_threshold
-
-		self.upsample_times = model.upsample_times
-		self.upsample_cnn = model.upsample_cnn
-		self.blstm = model.blstm
-		self.cif_output2 = model.cif_output2
-		self.smooth_factor2 = model.smooth_factor2
-		self.noise_threshold2 = model.noise_threshold2
-	
-	def forward(self, hidden: torch.Tensor,
-	            mask: torch.Tensor,
-	            ):
-		h = hidden
-		context = h.transpose(1, 2)
-		queries = self.pad(context)
-		output = torch.relu(self.cif_conv1d(queries))
-		output = output.transpose(1, 2)
-		
-		output = self.cif_output(output)
-		alphas = torch.sigmoid(output)
-		alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
-		mask = mask.transpose(-1, -2).float()
-		alphas = alphas * mask
-		alphas = alphas.squeeze(-1)
-		token_num = alphas.sum(-1)
-		
-		mask = mask.squeeze(-1)
-		hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
-		acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-		
-		return acoustic_embeds, token_num, alphas, cif_peak
-	
-	def get_upsample_timestmap(self, hidden, mask=None, token_num=None):
-		h = hidden
-		b = hidden.shape[0]
-		context = h.transpose(1, 2)
-
-		# generate alphas2
-		_output = context
-		output2 = self.upsample_cnn(_output)
-		output2 = output2.transpose(1, 2)
-		output2, (_, _) = self.blstm(output2)
-		alphas2 = torch.sigmoid(self.cif_output2(output2))
-		alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
-		
-		mask = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
-		mask = mask.unsqueeze(-1)
-		alphas2 = alphas2 * mask
-		alphas2 = alphas2.squeeze(-1)
-		_token_num = alphas2.sum(-1)
-		alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
-		# upsampled alphas and cif_peak
-		us_alphas = alphas2
-		us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
-		return us_alphas, us_cif_peak
-
-	def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
-		b, t, d = hidden.size()
-		tail_threshold = self.tail_threshold
-		
-		zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
-		ones_t = torch.ones_like(zeros_t)
-
-		mask_1 = torch.cat([mask, zeros_t], dim=1)
-		mask_2 = torch.cat([ones_t, mask], dim=1)
-		mask = mask_2 - mask_1
-		tail_threshold = mask * tail_threshold
-		alphas = torch.cat([alphas, zeros_t], dim=1)
-		alphas = torch.add(alphas, tail_threshold)
-
-		zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
-		hidden = torch.cat([hidden, zeros], dim=1)
-		token_num = alphas.sum(dim=-1)
-		token_num_floor = torch.floor(token_num)
-		
-		return hidden, alphas, token_num_floor
-
-
-@torch.jit.script
-def cif_wo_hidden(alphas, threshold: float):
-    batch_size, len_time = alphas.size()
-
-    # loop varss
-    integrate = torch.zeros([batch_size], dtype=alphas.dtype, device=alphas.device)
-    # intermediate vars along time
-    list_fires = []
-
-    for t in range(len_time):
-        alpha = alphas[:, t]
-
-        integrate += alpha
-        list_fires.append(integrate)
-
-        fire_place = integrate >= threshold
-        integrate = torch.where(fire_place,
-                                integrate - torch.ones([batch_size], device=alphas.device)*threshold,
-                                integrate)
-
-    fires = torch.stack(list_fires, 1)
-    return fires
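
For reference, the export-time CIF removed above runs a frame-by-frame integrate-and-fire loop: per-frame weights (alphas) are accumulated until they reach the threshold, at which point one weighted sum of the contributing frames is emitted as an acoustic embedding. A minimal sketch of the call pattern, assuming the cif function above is in scope and using illustrative shapes:

    import torch

    # hypothetical encoder output: batch of 2, 30 frames, 512-dim hidden states
    hidden = torch.randn(2, 30, 512)
    # per-frame firing weights; their running integral decides token boundaries
    alphas = torch.rand(2, 30) * 0.3

    # each time the integral of alphas crosses the threshold (1.0), one token
    # embedding is emitted as the alpha-weighted sum of the frames it covers
    acoustic_embeds, fires = cif(hidden, alphas, 1.0)
    print(acoustic_embeds.shape)  # (2, num_fired_tokens, 512)
    print(fires.shape)            # (2, 30): running integral per frame
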
diff --git a/funasr/export/test/__init__.py b/funasr/export/test/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/test/__init__.py
+++ /dev/null
diff --git a/funasr/export/test/test_onnx.py b/funasr/export/test/test_onnx.py
deleted file mode 100644
index 4351728..0000000
--- a/funasr/export/test/test_onnx.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import onnxruntime
-import numpy as np
-
-
-if __name__ == '__main__':
-    onnx_path = "/mnt/workspace/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.onnx"
-    sess = onnxruntime.InferenceSession(onnx_path)
-    input_name = [nd.name for nd in sess.get_inputs()]
-    output_name = [nd.name for nd in sess.get_outputs()]
-
-    def _get_feed_dict(feats_length):
-        return {'speech': np.zeros((1, feats_length, 560), dtype=np.float32), 'speech_lengths': np.array([feats_length,], dtype=np.int32)}
-
-    def _run(feed_dict):
-        output = sess.run(output_name, input_feed=feed_dict)
-        for name, value in zip(output_name, output):
-            print('{}: {}'.format(name, value.shape))
-
-    _run(_get_feed_dict(100))
-    _run(_get_feed_dict(200))
\ No newline at end of file
diff --git a/funasr/export/test/test_onnx_punc.py b/funasr/export/test/test_onnx_punc.py
deleted file mode 100644
index 39f85f4..0000000
--- a/funasr/export/test/test_onnx_punc.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import onnxruntime
-import numpy as np
-
-
-if __name__ == '__main__':
-    onnx_path = "../damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.onnx"
-    sess = onnxruntime.InferenceSession(onnx_path)
-    input_name = [nd.name for nd in sess.get_inputs()]
-    output_name = [nd.name for nd in sess.get_outputs()]
-
-    def _get_feed_dict(text_length):
-        return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32)}
-
-    def _run(feed_dict):
-        output = sess.run(output_name, input_feed=feed_dict)
-        for name, value in zip(output_name, output):
-            print('{}: {}'.format(name, value))
-    _run(_get_feed_dict(10))
diff --git a/funasr/export/test/test_onnx_punc_vadrealtime.py b/funasr/export/test/test_onnx_punc_vadrealtime.py
deleted file mode 100644
index 507226e..0000000
--- a/funasr/export/test/test_onnx_punc_vadrealtime.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import onnxruntime
-import numpy as np
-
-
-if __name__ == '__main__':
-    onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx"
-    sess = onnxruntime.InferenceSession(onnx_path)
-    input_name = [nd.name for nd in sess.get_inputs()]
-    output_name = [nd.name for nd in sess.get_outputs()]
-
-    def _get_feed_dict(text_length):
-        return {'inputs': np.ones((1, text_length), dtype=np.int64),
-                'text_lengths': np.array([text_length,], dtype=np.int32),
-                'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
-                'sub_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
-                }
-
-    def _run(feed_dict):
-        output = sess.run(output_name, input_feed=feed_dict)
-        for name, value in zip(output_name, output):
-            print('{}: {}'.format(name, value))
-    _run(_get_feed_dict(10))
diff --git a/funasr/export/test/test_onnx_vad.py b/funasr/export/test/test_onnx_vad.py
deleted file mode 100644
index 12f058f..0000000
--- a/funasr/export/test/test_onnx_vad.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import onnxruntime
-import numpy as np
-
-
-if __name__ == '__main__':
-    onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"
-    sess = onnxruntime.InferenceSession(onnx_path)
-    input_name = [nd.name for nd in sess.get_inputs()]
-    output_name = [nd.name for nd in sess.get_outputs()]
-
-    def _get_feed_dict(feats_length):
-        
-        return {'speech': np.random.rand(1, feats_length, 400).astype(np.float32),
-                'in_cache0': np.random.rand(1, 128, 19, 1).astype(np.float32),
-                'in_cache1': np.random.rand(1, 128, 19, 1).astype(np.float32),
-                'in_cache2': np.random.rand(1, 128, 19, 1).astype(np.float32),
-                'in_cache3': np.random.rand(1, 128, 19, 1).astype(np.float32),
-                }
-
-    def _run(feed_dict):
-        output = sess.run(output_name, input_feed=feed_dict)
-        for name, value in zip(output_name, output):
-            print('{}: {}'.format(name, value.shape))
-
-    _run(_get_feed_dict(100))
-    _run(_get_feed_dict(200))
\ No newline at end of file
diff --git a/funasr/export/test/test_torchscripts.py b/funasr/export/test/test_torchscripts.py
deleted file mode 100644
index 9afec74..0000000
--- a/funasr/export/test/test_torchscripts.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import torch
-import numpy as np
-
-if __name__ == '__main__':
-	onnx_path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"
-	loaded = torch.jit.load(onnx_path)
-	
-	x = torch.rand([2, 21, 560])
-	x_len = torch.IntTensor([6, 21])
-	res = loaded(x, x_len)
-	print(res[0].size(), res[1])
-	
-	x = torch.rand([5, 50, 560])
-	x_len = torch.IntTensor([6, 21, 10, 30, 50])
-	res = loaded(x, x_len)
-	print(res[0].size(), res[1])
-	
\ No newline at end of file
diff --git a/funasr/export/utils/__init__.py b/funasr/export/utils/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/export/utils/__init__.py
+++ /dev/null
diff --git a/funasr/export/utils/torch_function.py b/funasr/export/utils/torch_function.py
deleted file mode 100644
index a078a7e..0000000
--- a/funasr/export/utils/torch_function.py
+++ /dev/null
@@ -1,80 +0,0 @@
-from typing import Optional
-
-import torch
-import torch.nn as nn
-
-import numpy as np
-
-
-class MakePadMask(nn.Module):
-    def __init__(self, max_seq_len=512, flip=True):
-        super().__init__()
-        if flip:
-            self.mask_pad = torch.Tensor(1 - np.tri(max_seq_len)).type(torch.bool)
-        else:
-            self.mask_pad = torch.Tensor(np.tri(max_seq_len)).type(torch.bool)
-    
-    def forward(self, lengths, xs=None, length_dim=-1, maxlen=None):
-        """Make mask tensor containing indices of padded part.
-        This implementation creates the same mask tensor with original make_pad_mask,
-        which can be converted into onnx format.
-        Dimension length of xs should be 2 or 3.
-        """
-        if length_dim == 0:
-            raise ValueError("length_dim cannot be 0: {}".format(length_dim))
-
-        if xs is not None and len(xs.shape) == 3:
-            if length_dim == 1:
-                lengths = lengths.unsqueeze(1).expand(
-                    *xs.transpose(1, 2).shape[:2])
-            else:
-                lengths = lengths.unsqueeze(1).expand(*xs.shape[:2])
-
-        if maxlen is not None:
-            m = maxlen
-        elif xs is not None:
-            m = xs.shape[-1]
-        else:
-            m = torch.max(lengths)
-
-        mask = self.mask_pad[lengths - 1][..., :m].type(torch.float32)
-
-        if length_dim == 1:
-            return mask.transpose(1, 2)
-        else:
-            return mask
-
-class sequence_mask(nn.Module):
-    def __init__(self, max_seq_len=512, flip=True):
-        super().__init__()
-    
-    def forward(self, lengths, max_seq_len=None, dtype=torch.float32, device=None):
-        if max_seq_len is None:
-            max_seq_len = lengths.max()
-        row_vector = torch.arange(0, max_seq_len, 1).to(lengths.device)
-        matrix = torch.unsqueeze(lengths, dim=-1)
-        mask = row_vector < matrix
-        
-        return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
-
-def normalize(input: torch.Tensor, p: float = 2.0, dim: int = 1, out: Optional[torch.Tensor] = None) -> torch.Tensor:
-    if out is None:
-        denom = input.norm(p, dim, keepdim=True).expand_as(input)
-        return input / denom
-    else:
-        denom = input.norm(p, dim, keepdim=True).expand_as(input)
-        return torch.div(input, denom, out=out)
-
-def subsequent_mask(size: torch.Tensor):
-    return torch.ones(size, size).tril()
-
-
-def MakePadMask_test():
-    feats_length = torch.tensor([10]).type(torch.long)
-    mask_fn = MakePadMask()
-    mask = mask_fn(feats_length)
-    print(mask)
-
-
-if __name__ == '__main__':
-    MakePadMask_test()
\ No newline at end of file
diff --git a/funasr/export/__init__.py b/funasr/frontends/__init__.py
similarity index 100%
rename from funasr/export/__init__.py
rename to funasr/frontends/__init__.py
diff --git a/funasr/models/frontend/default.py b/funasr/frontends/default.py
similarity index 95%
rename from funasr/models/frontend/default.py
rename to funasr/frontends/default.py
index b4e518a..8ac1ca8 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/frontends/default.py
@@ -6,20 +6,19 @@
 import humanfriendly
 import numpy as np
 import torch
+import torch.nn as nn
 try:
     from torch_complex.tensor import ComplexTensor
 except:
     print("Please install torch_complex firstly")
 
-from funasr.models.frontend.utils.log_mel import LogMel
-from funasr.models.frontend.utils.stft import Stft
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.utils.frontend import Frontend
-from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.frontends.utils.log_mel import LogMel
+from funasr.frontends.utils.stft import Stft
+from funasr.frontends.utils.frontend import Frontend
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 
 
-class DefaultFrontend(AbsFrontend):
+class DefaultFrontend(nn.Module):
     """Conventional frontend structure for ASR.
     Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
     """
@@ -38,7 +37,7 @@
             fmin: int = None,
             fmax: int = None,
             htk: bool = False,
-            frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
+            frontend_conf: Optional[dict] = None,
             apply_stft: bool = True,
             use_channel: int = None,
     ):
@@ -139,7 +138,7 @@
         return input_stft, feats_lens
 
 
-class MultiChannelFrontend(AbsFrontend):
+class MultiChannelFrontend(nn.Module):
     """Conventional frontend structure for ASR.
     Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
     """
@@ -160,7 +159,7 @@
             fmin: int = None,
             fmax: int = None,
             htk: bool = False,
-            frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
+            frontend_conf: Optional[dict] = None,
             apply_stft: bool = True,
             use_channel: int = None,
             lfr_m: int = 1,
diff --git a/funasr/models/frontend/eend_ola_feature.py b/funasr/frontends/eend_ola_feature.py
similarity index 100%
rename from funasr/models/frontend/eend_ola_feature.py
rename to funasr/frontends/eend_ola_feature.py
diff --git a/funasr/models/frontend/fused.py b/funasr/frontends/fused.py
similarity index 95%
rename from funasr/models/frontend/fused.py
rename to funasr/frontends/fused.py
index ff95871..24f73f4 100644
--- a/funasr/models/frontend/fused.py
+++ b/funasr/frontends/fused.py
@@ -1,12 +1,12 @@
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.frontends.default import DefaultFrontend
+from funasr.frontends.s3prl import S3prlFrontend
 import numpy as np
 import torch
+import torch.nn as nn
 from typing import Tuple
 
 
-class FusedFrontends(AbsFrontend):
+class FusedFrontends(nn.Module):
     def __init__(
         self, frontends=None, align_method="linear_projection", proj_dim=100, fs=16000
     ):
diff --git a/funasr/models/frontend/s3prl.py b/funasr/frontends/s3prl.py
similarity index 93%
rename from funasr/models/frontend/s3prl.py
rename to funasr/frontends/s3prl.py
index 0e419bc..ff60592 100644
--- a/funasr/models/frontend/s3prl.py
+++ b/funasr/frontends/s3prl.py
@@ -8,11 +8,10 @@
 
 import humanfriendly
 import torch
+import torch.nn as nn
 
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.utils.frontend import Frontend
+from funasr.frontends.utils.frontend import Frontend
 from funasr.models.transformer.utils.nets_utils import pad_list
-from funasr.utils.get_default_kwargs import get_default_kwargs
 
 
 def base_s3prl_setup(args):
@@ -26,13 +25,13 @@
     return args
 
 
-class S3prlFrontend(AbsFrontend):
+class S3prlFrontend(nn.Module):
     """Speech Pretrained Representation frontend structure for ASR."""
 
     def __init__(
             self,
             fs: Union[int, str] = 16000,
-            frontend_conf: Optional[dict] = get_default_kwargs(Frontend),
+            frontend_conf: Optional[dict] = None,
             download_dir: str = None,
             multilayer_feature: bool = False,
     ):
diff --git a/funasr/models/frontend/utils/__init__.py b/funasr/frontends/utils/__init__.py
similarity index 100%
rename from funasr/models/frontend/utils/__init__.py
rename to funasr/frontends/utils/__init__.py
diff --git a/funasr/models/frontend/utils/beamformer.py b/funasr/frontends/utils/beamformer.py
similarity index 100%
rename from funasr/models/frontend/utils/beamformer.py
rename to funasr/frontends/utils/beamformer.py
diff --git a/funasr/models/frontend/utils/complex_utils.py b/funasr/frontends/utils/complex_utils.py
similarity index 100%
rename from funasr/models/frontend/utils/complex_utils.py
rename to funasr/frontends/utils/complex_utils.py
diff --git a/funasr/models/frontend/utils/dnn_beamformer.py b/funasr/frontends/utils/dnn_beamformer.py
similarity index 94%
rename from funasr/models/frontend/utils/dnn_beamformer.py
rename to funasr/frontends/utils/dnn_beamformer.py
index 0f4e9da..75637d2 100644
--- a/funasr/models/frontend/utils/dnn_beamformer.py
+++ b/funasr/frontends/utils/dnn_beamformer.py
@@ -4,12 +4,12 @@
 import torch
 from torch.nn import functional as F
 
-from funasr.models.frontend.utils.beamformer import apply_beamforming_vector
-from funasr.models.frontend.utils.beamformer import get_mvdr_vector
-from funasr.models.frontend.utils.beamformer import (
+from funasr.frontends.utils.beamformer import apply_beamforming_vector
+from funasr.frontends.utils.beamformer import get_mvdr_vector
+from funasr.frontends.utils.beamformer import (
     get_power_spectral_density_matrix,  # noqa: H301
 )
-from funasr.models.frontend.utils.mask_estimator import MaskEstimator
+from funasr.frontends.utils.mask_estimator import MaskEstimator
 from torch_complex.tensor import ComplexTensor
 
 
diff --git a/funasr/models/frontend/utils/dnn_wpe.py b/funasr/frontends/utils/dnn_wpe.py
similarity index 97%
rename from funasr/models/frontend/utils/dnn_wpe.py
rename to funasr/frontends/utils/dnn_wpe.py
index 6a14c5d..9171339 100644
--- a/funasr/models/frontend/utils/dnn_wpe.py
+++ b/funasr/frontends/utils/dnn_wpe.py
@@ -4,7 +4,7 @@
 import torch
 from torch_complex.tensor import ComplexTensor
 
-from funasr.models.frontend.utils.mask_estimator import MaskEstimator
+from funasr.frontends.utils.mask_estimator import MaskEstimator
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 
 
diff --git a/funasr/models/frontend/utils/feature_transform.py b/funasr/frontends/utils/feature_transform.py
similarity index 100%
rename from funasr/models/frontend/utils/feature_transform.py
rename to funasr/frontends/utils/feature_transform.py
diff --git a/funasr/models/frontend/utils/frontend.py b/funasr/frontends/utils/frontend.py
similarity index 96%
rename from funasr/models/frontend/utils/frontend.py
rename to funasr/frontends/utils/frontend.py
index 8f84e73..a2a6cc1 100644
--- a/funasr/models/frontend/utils/frontend.py
+++ b/funasr/frontends/utils/frontend.py
@@ -8,8 +8,8 @@
 import torch.nn as nn
 from torch_complex.tensor import ComplexTensor
 
-from funasr.models.frontend.utils.dnn_beamformer import DNN_Beamformer
-from funasr.models.frontend.utils.dnn_wpe import DNN_WPE
+from funasr.frontends.utils.dnn_beamformer import DNN_Beamformer
+from funasr.frontends.utils.dnn_wpe import DNN_WPE
 
 
 class Frontend(nn.Module):
diff --git a/funasr/models/frontend/utils/log_mel.py b/funasr/frontends/utils/log_mel.py
similarity index 100%
rename from funasr/models/frontend/utils/log_mel.py
rename to funasr/frontends/utils/log_mel.py
diff --git a/funasr/models/frontend/utils/mask_estimator.py b/funasr/frontends/utils/mask_estimator.py
similarity index 100%
rename from funasr/models/frontend/utils/mask_estimator.py
rename to funasr/frontends/utils/mask_estimator.py
diff --git a/funasr/models/frontend/utils/stft.py b/funasr/frontends/utils/stft.py
similarity index 98%
rename from funasr/models/frontend/utils/stft.py
rename to funasr/frontends/utils/stft.py
index 2a0bc96..00d9ec5 100644
--- a/funasr/models/frontend/utils/stft.py
+++ b/funasr/frontends/utils/stft.py
@@ -10,7 +10,7 @@
 except:
     print("Please install torch_complex firstly")
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.frontend.utils.complex_utils import is_complex
+from funasr.frontends.utils.complex_utils import is_complex
 
 import librosa
 import numpy as np
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/frontends/wav_frontend.py
similarity index 97%
rename from funasr/models/frontend/wav_frontend.py
rename to funasr/frontends/wav_frontend.py
index ac16065..4866fa1 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/frontends/wav_frontend.py
@@ -4,11 +4,13 @@
 
 import numpy as np
 import torch
+import torch.nn as nn
 import torchaudio.compliance.kaldi as kaldi
 from torch.nn.utils.rnn import pad_sequence
 
-import funasr.models.frontend.eend_ola_feature as eend_ola_feature
-from funasr.models.frontend.abs_frontend import AbsFrontend
+import funasr.frontends.eend_ola_feature as eend_ola_feature
+from funasr.utils.register import register_class
+
 
 
 def load_cmvn(cmvn_file):
@@ -73,8 +75,8 @@
     LFR_outputs = torch.vstack(LFR_inputs)
     return LFR_outputs.type(torch.float32)
 
-
-class WavFrontend(AbsFrontend):
+@register_class("frontend_classes", "WavFrontend")
+class WavFrontend(nn.Module):
     """Conventional frontend structure for ASR.
     """
 
@@ -93,6 +95,7 @@
             dither: float = 1.0,
             snip_edges: bool = True,
             upsacle_samples: bool = True,
+            **kwargs,
     ):
         super().__init__()
         self.fs = fs
@@ -208,7 +211,8 @@
         return feats_pad, feats_lens
 
 
-class WavFrontendOnline(AbsFrontend):
+@register_class("frontend_classes", "WavFrontendOnline")
+class WavFrontendOnline(nn.Module):
     """Conventional frontend structure for streaming ASR/VAD.
     """
 
@@ -227,6 +231,7 @@
             dither: float = 1.0,
             snip_edges: bool = True,
             upsacle_samples: bool = True,
+            **kwargs,
     ):
         super().__init__()
         self.fs = fs
@@ -454,7 +459,7 @@
         self.lfr_splice_cache = []
 
 
-class WavFrontendMel23(AbsFrontend):
+class WavFrontendMel23(nn.Module):
     """Conventional frontend structure for ASR.
     """
 
@@ -465,6 +470,7 @@
             frame_shift: int = 10,
             lfr_m: int = 1,
             lfr_n: int = 1,
+            **kwargs,
     ):
         super().__init__()
         self.fs = fs
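
The frontends now register themselves by name instead of inheriting from AbsFrontend, so configs can refer to them as strings. A minimal sketch of the registry pattern these decorators imply, assuming register_class keeps a two-level dict keyed by group and class name (the real helper lives in funasr/utils/register.py; the code below is an illustrative stand-in, not its actual implementation):

    # illustrative stand-in for funasr.utils.register.register_class
    registry = {}

    def register_class(group, name):
        def wrapper(cls):
            registry.setdefault(group, {})[name] = cls
            return cls
        return wrapper

    @register_class("frontend_classes", "WavFrontend")
    class WavFrontend:
        def __init__(self, fs=16000, **kwargs):
            self.fs = fs

    # a config entry such as frontend: WavFrontend can then be resolved at runtime
    frontend_cls = registry["frontend_classes"]["WavFrontend"]
    frontend = frontend_cls(fs=16000)
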
diff --git a/funasr/models/frontend/wav_frontend_kaldifeat.py b/funasr/frontends/wav_frontend_kaldifeat.py
similarity index 100%
rename from funasr/models/frontend/wav_frontend_kaldifeat.py
rename to funasr/frontends/wav_frontend_kaldifeat.py
diff --git a/funasr/models/frontend/windowing.py b/funasr/frontends/windowing.py
similarity index 96%
rename from funasr/models/frontend/windowing.py
rename to funasr/frontends/windowing.py
index 94c9d27..3250550 100644
--- a/funasr/models/frontend/windowing.py
+++ b/funasr/frontends/windowing.py
@@ -4,12 +4,12 @@
 
 """Sliding Window for raw audio input data."""
 
-from funasr.models.frontend.abs_frontend import AbsFrontend
 import torch
+import torch.nn as nn
 from typing import Tuple
 
 
-class SlidingWindow(AbsFrontend):
+class SlidingWindow(nn.Module):
     """Sliding Window.
     Provides a sliding window over a batched continuous raw audio tensor.
     Optionally, provides padding (Currently not implemented).
diff --git a/funasr/metrics/compute_acc.py b/funasr/metrics/compute_acc.py
new file mode 100644
index 0000000..2b45836
--- /dev/null
+++ b/funasr/metrics/compute_acc.py
@@ -0,0 +1,23 @@
+import torch
+
+def th_accuracy(pad_outputs, pad_targets, ignore_label):
+    """Calculate accuracy.
+
+    Args:
+        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
+        pad_targets (LongTensor): Target label tensors (B, Lmax).
+        ignore_label (int): Ignore label id.
+
+    Returns:
+        float: Accuracy value (0.0 - 1.0).
+
+    """
+    pad_pred = pad_outputs.view(
+        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
+    ).argmax(2)
+    mask = pad_targets != ignore_label
+    numerator = torch.sum(
+        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
+    )
+    denominator = torch.sum(mask)
+    return float(numerator) / float(denominator)
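
th_accuracy reshapes the flattened decoder logits back to (B, Lmax, D), takes the argmax per position, and counts matches against the padded targets while skipping ignore_label positions. A quick sanity check of the call pattern (shapes and the ignore id of -1 are illustrative):

    import torch
    from funasr.metrics.compute_acc import th_accuracy

    batch, max_len, vocab = 2, 4, 5
    # decoder logits flattened to (B * Lmax, D), as expected by th_accuracy
    pad_outputs = torch.randn(batch * max_len, vocab)
    # padded targets; -1 marks padding positions excluded from the count
    pad_targets = torch.tensor([[1, 3, 2, -1],
                                [4, 0, -1, -1]])

    acc = th_accuracy(pad_outputs, pad_targets, ignore_label=-1)
    print(acc)  # float in [0.0, 1.0]
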
diff --git a/funasr/models/bat/attention.py b/funasr/models/bat/attention.py
new file mode 100644
index 0000000..11645b3
--- /dev/null
+++ b/funasr/models/bat/attention.py
@@ -0,0 +1,238 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+from typing import Optional, Tuple
+
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+import funasr.models.lora.layers as lora
+
+
+class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
+    """RelPositionMultiHeadedAttention definition.
+    Args:
+        num_heads: Number of attention heads.
+        embed_size: Embedding size.
+        dropout_rate: Dropout rate.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        embed_size: int,
+        dropout_rate: float = 0.0,
+        simplified_attention_score: bool = False,
+    ) -> None:
+        """Construct an MultiHeadedAttention object."""
+        super().__init__()
+
+        self.d_k = embed_size // num_heads
+        self.num_heads = num_heads
+
+        assert self.d_k * num_heads == embed_size, (
+            "embed_size (%d) must be divisible by num_heads (%d)"
+            % (embed_size, num_heads)
+        )
+
+        self.linear_q = torch.nn.Linear(embed_size, embed_size)
+        self.linear_k = torch.nn.Linear(embed_size, embed_size)
+        self.linear_v = torch.nn.Linear(embed_size, embed_size)
+
+        self.linear_out = torch.nn.Linear(embed_size, embed_size)
+
+        if simplified_attention_score:
+            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
+
+            self.compute_att_score = self.compute_simplified_attention_score
+        else:
+            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
+
+            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            torch.nn.init.xavier_uniform_(self.pos_bias_u)
+            torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+            self.compute_att_score = self.compute_attention_score
+
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.attn = None
+
+    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+        """Compute relative positional encoding.
+        Args:
+            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
+            left_context: Number of frames in left context.
+        Returns:
+            x: Output sequence. (B, H, T_1, T_2)
+        """
+        batch_size, n_heads, time1, n = x.shape
+        time2 = time1 + left_context
+
+        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
+
+        return x.as_strided(
+            (batch_size, n_heads, time1, time2),
+            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
+            storage_offset=(n_stride * (time1 - 1)),
+        )
+
+    def compute_simplified_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Simplified attention score computation.
+        Reference: https://github.com/k2-fsa/icefall/pull/458
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+        """
+        pos_enc = self.linear_pos(pos_enc)
+
+        matrix_ac = torch.matmul(query, key.transpose(2, 3))
+
+        matrix_bd = self.rel_shift(
+            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
+            left_context=left_context,
+        )
+
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def compute_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Attention score computation.
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+        """
+        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
+
+        query = query.transpose(1, 2)
+        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
+        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
+
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def forward_qkv(
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Transform query, key and value.
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+        Returns:
+            q: Transformed query tensor. (B, H, T_1, d_k)
+            k: Transformed key tensor. (B, H, T_2, d_k)
+            v: Transformed value tensor. (B, H, T_2, d_k)
+        """
+        n_batch = query.size(0)
+
+        q = (
+            self.linear_q(query)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        k = (
+            self.linear_k(key)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        v = (
+            self.linear_v(value)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+
+        return q, k, v
+
+    def forward_attention(
+        self,
+        value: torch.Tensor,
+        scores: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute attention context vector.
+        Args:
+            value: Transformed value. (B, H, T_2, d_k)
+            scores: Attention score. (B, H, T_1, T_2)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+        Returns:
+           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
+        """
+        batch_size = scores.size(0)
+        mask = mask.unsqueeze(1).unsqueeze(2)
+        if chunk_mask is not None:
+            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
+        scores = scores.masked_fill(mask, float("-inf"))
+        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+
+        attn_output = self.dropout(self.attn)
+        attn_output = torch.matmul(attn_output, value)
+
+        attn_output = self.linear_out(
+            attn_output.transpose(1, 2)
+            .contiguous()
+            .view(batch_size, -1, self.num_heads * self.d_k)
+        )
+
+        return attn_output
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Compute scaled dot product attention with rel. positional encoding.
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+            left_context: Number of frames in left context.
+        Returns:
+            : Output tensor. (B, T_1, H * d_k)
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
+        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
+
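
RelPositionMultiHeadedAttentionChunk follows the Transformer-XL relative-position scheme (the a+c and b+d score terms) and takes a boolean source mask plus an optional chunk mask for streaming. A minimal self-attention forward pass under assumed shapes, with the class above imported from the new module:

    import torch
    from funasr.models.bat.attention import RelPositionMultiHeadedAttentionChunk

    batch, time, size, heads = 2, 8, 256, 4
    mha = RelPositionMultiHeadedAttentionChunk(num_heads=heads, embed_size=size)

    x = torch.randn(batch, time, size)
    # relative positional embeddings cover offsets -(T-1) .. +(T-1)
    pos_enc = torch.randn(batch, 2 * time - 1, size)
    # boolean source mask: True marks padded frames to be ignored
    mask = torch.zeros(batch, time, dtype=torch.bool)
    mask[1, 6:] = True  # pretend the second utterance is only 6 frames long

    out = mha(x, x, x, pos_enc, mask)  # self-attention
    print(out.shape)                   # torch.Size([2, 8, 256])
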
diff --git a/funasr/models/bat/cif_predictor.py b/funasr/models/bat/cif_predictor.py
new file mode 100644
index 0000000..9aa3e33
--- /dev/null
+++ b/funasr/models/bat/cif_predictor.py
@@ -0,0 +1,220 @@
+# import torch
+# from torch import nn
+# from torch import Tensor
+# import logging
+# import numpy as np
+# from funasr.train_utils.device_funcs import to_device
+# from funasr.models.transformer.utils.nets_utils import make_pad_mask
+# from funasr.models.scama.utils import sequence_mask
+# from typing import Optional, Tuple
+#
+# from funasr.utils.register import register_class
+#
+# class mae_loss(nn.Module):
+#
+#     def __init__(self, normalize_length=False):
+#         super(mae_loss, self).__init__()
+#         self.normalize_length = normalize_length
+#         self.criterion = torch.nn.L1Loss(reduction='sum')
+#
+#     def forward(self, token_length, pre_token_length):
+#         loss_token_normalizer = token_length.size(0)
+#         if self.normalize_length:
+#             loss_token_normalizer = token_length.sum().type(torch.float32)
+#         loss = self.criterion(token_length, pre_token_length)
+#         loss = loss / loss_token_normalizer
+#         return loss
+#
+#
+# def cif(hidden, alphas, threshold):
+#     batch_size, len_time, hidden_size = hidden.size()
+#
+#     # loop vars
+#     integrate = torch.zeros([batch_size], device=hidden.device)
+#     frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
+#     # intermediate vars along time
+#     list_fires = []
+#     list_frames = []
+#
+#     for t in range(len_time):
+#         alpha = alphas[:, t]
+#         distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
+#
+#         integrate += alpha
+#         list_fires.append(integrate)
+#
+#         fire_place = integrate >= threshold
+#         integrate = torch.where(fire_place,
+#                                 integrate - torch.ones([batch_size], device=hidden.device),
+#                                 integrate)
+#         cur = torch.where(fire_place,
+#                           distribution_completion,
+#                           alpha)
+#         remainds = alpha - cur
+#
+#         frame += cur[:, None] * hidden[:, t, :]
+#         list_frames.append(frame)
+#         frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
+#                             remainds[:, None] * hidden[:, t, :],
+#                             frame)
+#
+#     fires = torch.stack(list_fires, 1)
+#     frames = torch.stack(list_frames, 1)
+#     list_ls = []
+#     len_labels = torch.round(alphas.sum(-1)).int()
+#     max_label_len = len_labels.max()
+#     for b in range(batch_size):
+#         fire = fires[b, :]
+#         l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
+#         pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
+#         list_ls.append(torch.cat([l, pad_l], 0))
+#     return torch.stack(list_ls, 0), fires
+#
+#
+# def cif_wo_hidden(alphas, threshold):
+#     batch_size, len_time = alphas.size()
+#
+#     # loop vars
+#     integrate = torch.zeros([batch_size], device=alphas.device)
+#     # intermediate vars along time
+#     list_fires = []
+#
+#     for t in range(len_time):
+#         alpha = alphas[:, t]
+#
+#         integrate += alpha
+#         list_fires.append(integrate)
+#
+#         fire_place = integrate >= threshold
+#         integrate = torch.where(fire_place,
+#                                 integrate - torch.ones([batch_size], device=alphas.device)*threshold,
+#                                 integrate)
+#
+#     fires = torch.stack(list_fires, 1)
+#     return fires
+#
+# @register_class("predictor_classes", "BATPredictor")
+# class BATPredictor(nn.Module):
+#     def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False):
+#         super(BATPredictor, self).__init__()
+#
+#         self.pad = nn.ConstantPad1d((l_order, r_order), 0)
+#         self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
+#         self.cif_output = nn.Linear(idim, 1)
+#         self.dropout = torch.nn.Dropout(p=dropout)
+#         self.threshold = threshold
+#         self.smooth_factor = smooth_factor
+#         self.noise_threshold = noise_threshold
+#         self.return_accum = return_accum
+#
+#     def cif(
+#         self,
+#         input: Tensor,
+#         alpha: Tensor,
+#         beta: float = 1.0,
+#         return_accum: bool = False,
+#     ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+#         B, S, C = input.size()
+#         assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}"
+#
+#         dtype = alpha.dtype
+#         alpha = alpha.float()
+#
+#         alpha_sum = alpha.sum(1)
+#         feat_lengths = (alpha_sum / beta).floor().long()
+#         T = feat_lengths.max()
+#
+#         # aggregate and integrate
+#         csum = alpha.cumsum(-1)
+#         with torch.no_grad():
+#             # indices used for scattering
+#             right_idx = (csum / beta).floor().long().clip(max=T)
+#             left_idx = right_idx.roll(1, dims=1)
+#             left_idx[:, 0] = 0
+#
+#             # count # of fires from each source
+#             fire_num = right_idx - left_idx
+#             extra_weights = (fire_num - 1).clip(min=0)
+#             # The extra entry in last dim is for
+#             output = input.new_zeros((B, T + 1, C))
+#             source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input)
+#             zero = alpha.new_zeros((1,))
+#
+#         # right scatter
+#         fire_mask = fire_num > 0
+#         right_weight = torch.where(
+#             fire_mask,
+#             csum - right_idx.type_as(alpha) * beta,
+#             zero
+#         ).type_as(input)
+#         # assert right_weight.ge(0).all(), f"{right_weight} should be non-negative."
+#         output.scatter_add_(
+#             1,
+#             right_idx.unsqueeze(-1).expand(-1, -1, C),
+#             right_weight.unsqueeze(-1) * input
+#         )
+#
+#         # left scatter
+#         left_weight = (
+#             alpha - right_weight - extra_weights.type_as(alpha) * beta
+#         ).type_as(input)
+#         output.scatter_add_(
+#             1,
+#             left_idx.unsqueeze(-1).expand(-1, -1, C),
+#             left_weight.unsqueeze(-1) * input
+#         )
+#
+#          # extra scatters
+#         if extra_weights.ge(0).any():
+#             extra_steps = extra_weights.max().item()
+#             tgt_idx = left_idx
+#             src_feats = input * beta
+#             for _ in range(extra_steps):
+#                 tgt_idx = (tgt_idx + 1).clip(max=T)
+#                 # (B, S, 1)
+#                 src_mask = (extra_weights > 0)
+#                 output.scatter_add_(
+#                     1,
+#                     tgt_idx.unsqueeze(-1).expand(-1, -1, C),
+#                     src_feats * src_mask.unsqueeze(2)
+#                 )
+#                 extra_weights -= 1
+#
+#         output = output[:, :T, :]
+#
+#         if return_accum:
+#             return output, csum
+#         else:
+#             return output, alpha
+#
+#     def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, target_label_length=None):
+#         h = hidden
+#         context = h.transpose(1, 2)
+#         queries = self.pad(context)
+#         memory = self.cif_conv1d(queries)
+#         output = memory + context
+#         output = self.dropout(output)
+#         output = output.transpose(1, 2)
+#         output = torch.relu(output)
+#         output = self.cif_output(output)
+#         alphas = torch.sigmoid(output)
+#         alphas = torch.nn.functional.relu(alphas*self.smooth_factor - self.noise_threshold)
+#         if mask is not None:
+#             alphas = alphas * mask.transpose(-1, -2).float()
+#         if mask_chunk_predictor is not None:
+#             alphas = alphas * mask_chunk_predictor
+#         alphas = alphas.squeeze(-1)
+#         if target_label_length is not None:
+#             target_length = target_label_length
+#         elif target_label is not None:
+#             target_length = (target_label != ignore_id).float().sum(-1)
+#             # logging.info("target_length: {}".format(target_length))
+#         else:
+#             target_length = None
+#         token_num = alphas.sum(-1)
+#         if target_length is not None:
+#             # length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5
+#             # target_length = length_noise + target_length
+#             alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1))
+#         acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum)
+#         return acoustic_embeds, token_num, alphas, cif_peak
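
Although BATPredictor is committed commented out, it documents a vectorized CIF: the cumulative sum of alphas is bucketed with floor(csum / beta), and each frame is scatter-added into the left/right token slots it straddles. A tiny numeric illustration of that index computation, independent of the disabled class:

    import torch

    beta = 1.0
    alphas = torch.tensor([[0.4, 0.7, 0.5, 0.6, 0.9]])  # (B=1, S=5) firing weights

    csum = alphas.cumsum(-1)                  # running integral per frame
    right_idx = (csum / beta).floor().long()  # token index each frame ends in
    left_idx = right_idx.roll(1, dims=1)
    left_idx[:, 0] = 0                        # token index each frame starts in
    fire_num = right_idx - left_idx           # frames where a token boundary fires

    print(csum)       # [[0.4, 1.1, 1.6, 2.2, 3.1]]
    print(right_idx)  # [[0, 1, 1, 2, 3]]
    print(fire_num)   # [[0, 1, 0, 1, 1]] -> three fired tokens
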
diff --git a/funasr/models/bat/conformer_chunk_encoder.py b/funasr/models/bat/conformer_chunk_encoder.py
new file mode 100644
index 0000000..2dc03c3
--- /dev/null
+++ b/funasr/models/bat/conformer_chunk_encoder.py
@@ -0,0 +1,701 @@
+
+"""Conformer encoder definition."""
+
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import torch
+from torch import nn
+
+
+from funasr.models.bat.attention import (
+    RelPositionMultiHeadedAttentionChunk,
+)
+from funasr.models.transformer.embedding import (
+    StreamingRelPositionalEncoding,
+)
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.nets_utils import get_activation
+from funasr.models.transformer.utils.nets_utils import (
+    TooShortUttError,
+    check_short_utt,
+    make_chunk_mask,
+    make_source_mask,
+)
+from funasr.models.transformer.positionwise_feed_forward import (
+    PositionwiseFeedForward,
+)
+from funasr.models.transformer.utils.repeat import repeat, MultiBlocks
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+from funasr.models.transformer.utils.subsampling import StreamingConvInput
+from funasr.utils.register import register_class
+
+
+
+class ChunkEncoderLayer(nn.Module):
+    """Chunk Conformer module definition.
+    Args:
+        block_size: Input/output size.
+        self_att: Self-attention module instance.
+        feed_forward: Feed-forward module instance.
+        feed_forward_macaron: Feed-forward module instance for macaron network.
+        conv_mod: Convolution module instance.
+        norm_class: Normalization module class.
+        norm_args: Normalization module arguments.
+        dropout_rate: Dropout rate.
+    """
+
+    def __init__(
+        self,
+        block_size: int,
+        self_att: torch.nn.Module,
+        feed_forward: torch.nn.Module,
+        feed_forward_macaron: torch.nn.Module,
+        conv_mod: torch.nn.Module,
+        norm_class: torch.nn.Module = LayerNorm,
+        norm_args: Dict = {},
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Construct a Conformer object."""
+        super().__init__()
+
+        self.self_att = self_att
+
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.feed_forward_scale = 0.5
+
+        self.conv_mod = conv_mod
+
+        self.norm_feed_forward = norm_class(block_size, **norm_args)
+        self.norm_self_att = norm_class(block_size, **norm_args)
+
+        self.norm_macaron = norm_class(block_size, **norm_args)
+        self.norm_conv = norm_class(block_size, **norm_args)
+        self.norm_final = norm_class(block_size, **norm_args)
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+        self.block_size = block_size
+        self.cache = None
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset self-attention and convolution modules cache for streaming.
+        Args:
+            left_context: Number of left frames during chunk-by-chunk inference.
+            device: Device to use for cache tensor.
+        """
+        self.cache = [
+            torch.zeros(
+                (1, left_context, self.block_size),
+                device=device,
+            ),
+            torch.zeros(
+                (
+                    1,
+                    self.block_size,
+                    self.conv_mod.kernel_size - 1,
+                ),
+                device=device,
+            ),
+        ]
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: Conformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T)
+            chunk_mask: Chunk mask. (T_2, T_2)
+        Returns:
+            x: Conformer output sequences. (B, T, D_block)
+            mask: Source mask. (B, T)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+        """
+        residual = x
+
+        x = self.norm_macaron(x)
+        x = residual + self.feed_forward_scale * self.dropout(
+            self.feed_forward_macaron(x)
+        )
+
+        residual = x
+        x = self.norm_self_att(x)
+        x_q = x
+        x = residual + self.dropout(
+            self.self_att(
+                x_q,
+                x,
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=chunk_mask,
+            )
+        )
+
+        residual = x
+
+        x = self.norm_conv(x)
+        x, _ = self.conv_mod(x)
+        x = residual + self.dropout(x)
+        residual = x
+
+        x = self.norm_feed_forward(x)
+        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
+
+        x = self.norm_final(x)
+        return x, mask, pos_enc
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 0,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode chunk of input sequence.
+        Args:
+            x: Conformer input sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+            mask: Source mask. (B, T_2)
+            chunk_size: Number of frames per chunk.
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+            x: Conformer output sequences. (B, T, D_block)
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+        """
+        residual = x
+
+        x = self.norm_macaron(x)
+        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
+
+        residual = x
+        x = self.norm_self_att(x)
+        if left_context > 0:
+            key = torch.cat([self.cache[0], x], dim=1)
+        else:
+            key = x
+        val = key
+
+        if right_context > 0:
+            att_cache = key[:, -(left_context + right_context) : -right_context, :]
+        else:
+            att_cache = key[:, -left_context:, :]
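+        # Keep the most recent `left_context` frames (excluding any right-context
+        # look-ahead) as the attention cache for the next chunk.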
+        x = residual + self.self_att(
+            x,
+            key,
+            val,
+            pos_enc,
+            mask,
+            left_context=left_context,
+        )
+
+        residual = x
+        x = self.norm_conv(x)
+        x, conv_cache = self.conv_mod(
+            x, cache=self.cache[1], right_context=right_context
+        )
+        x = residual + x
+        residual = x
+
+        x = self.norm_feed_forward(x)
+        x = residual + self.feed_forward_scale * self.feed_forward(x)
+
+        x = self.norm_final(x)
+        self.cache = [att_cache, conv_cache]
+
+        return x, pos_enc
+
+
+class CausalConvolution(nn.Module):
+    """CausalConvolution module definition.
+    Args:
+        channels: The number of channels.
+        kernel_size: Size of the convolving kernel.
+        activation: Type of activation function.
+        norm_args: Normalization module arguments.
+        causal: Whether to use causal convolution (set to True if streaming).
+    """
+
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+        activation: torch.nn.Module = torch.nn.ReLU(),
+        norm_args: Dict = {},
+        causal: bool = False,
+    ) -> None:
+        """Construct a CausalConvolution object."""
+        super().__init__()
+
+        assert (kernel_size - 1) % 2 == 0
+
+        self.kernel_size = kernel_size
+
+        self.pointwise_conv1 = torch.nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
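+        # Causal mode: the depthwise conv is unpadded and `lorder = kernel_size - 1`
+        # frames of left context are supplied at runtime (zero padding or a cache);
+        # otherwise symmetric "same" padding is applied inside the conv.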
+        if causal:
+            self.lorder = kernel_size - 1
+            padding = 0
+        else:
+            self.lorder = 0
+            padding = (kernel_size - 1) // 2
+
+        self.depthwise_conv = torch.nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=padding,
+            groups=channels,
+        )
+        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
+        self.pointwise_conv2 = torch.nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+        )
+
+        self.activation = activation
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cache: Optional[torch.Tensor] = None,
+        right_context: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute convolution module.
+        Args:
+            x: CausalConvolution input sequences. (B, T, D_hidden)
+            cache: CausalConvolution input cache. (1, D_hidden, kernel_size - 1)
+            right_context: Number of frames in right context.
+        Returns:
+            x: CausalConvolution output sequences. (B, T, D_hidden)
+            cache: CausalConvolution output cache. (1, D_hidden, kernel_size - 1)
+        """
+        x = self.pointwise_conv1(x.transpose(1, 2))
+        x = torch.nn.functional.glu(x, dim=1)
+
+        if self.lorder > 0:
+            if cache is None:
+                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+            else:
+                x = torch.cat([cache, x], dim=2)
+
+                if right_context > 0:
+                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
+                else:
+                    cache = x[:, :, -self.lorder :]
+
+        x = self.depthwise_conv(x)
+        x = self.activation(self.norm(x))
+
+        x = self.pointwise_conv2(x).transpose(1, 2)
+
+        return x, cache
+
+@register_class("encoder_classes", "ConformerChunkEncoder")
+class ConformerChunkEncoder(nn.Module):
+    """Conformer chunk (streaming) encoder module definition.
+    Args:
+        input_size: Input feature dimension.
+        output_size: Encoder output dimension.
+        attention_heads: Number of attention heads.
+        linear_units: Number of units in the position-wise feed-forward layers.
+        num_blocks: Number of encoder blocks.
+        dynamic_chunk_training: Whether to train with randomly sampled chunk sizes.
+        unified_model_training: Whether to jointly train full-utterance and chunked paths.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        embed_vgg_like: bool = False,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 3,
+        macaron_style: bool = False,
+        rel_pos_type: str = "legacy",
+        pos_enc_layer_type: str = "rel_pos",
+        selfattention_layer_type: str = "rel_selfattn",
+        activation_type: str = "swish",
+        use_cnn_module: bool = True,
+        zero_triu: bool = False,
+        norm_type: str = "layer_norm",
+        cnn_module_kernel: int = 31,
+        conv_mod_norm_eps: float = 0.00001,
+        conv_mod_norm_momentum: float = 0.1,
+        simplified_att_score: bool = False,
+        dynamic_chunk_training: bool = False,
+        short_chunk_threshold: float = 0.75,
+        short_chunk_size: int = 25,
+        left_chunk_size: int = 0,
+        time_reduction_factor: int = 1,
+        unified_model_training: bool = False,
+        default_chunk_size: int = 16,
+        jitter_range: int = 4,
+        subsampling_factor: int = 1,
+    ) -> None:
+        """Construct an Encoder object."""
+        super().__init__()
+
+
+        self.embed = StreamingConvInput(
+            input_size,
+            output_size,
+            subsampling_factor,
+            vgg_like=embed_vgg_like,
+            output_size=output_size,
+        )
+
+        self.pos_enc = StreamingRelPositionalEncoding(
+            output_size,
+            positional_dropout_rate,
+        )
+
+        activation = get_activation(activation_type)
+
+        pos_wise_args = (
+            output_size,
+            linear_units,
+            positional_dropout_rate,
+            activation,
+        )
+
+        conv_mod_norm_args = {
+            "eps": conv_mod_norm_eps,
+            "momentum": conv_mod_norm_momentum,
+        }
+
+        conv_mod_args = (
+            output_size,
+            cnn_module_kernel,
+            activation,
+            conv_mod_norm_args,
+            dynamic_chunk_training or unified_model_training,
+        )
+
+        mult_att_args = (
+            attention_heads,
+            output_size,
+            attention_dropout_rate,
+            simplified_att_score,
+        )
+
+
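+        # Build `num_blocks` identical encoder layers; each factory call creates
+        # fresh attention, feed-forward and convolution modules for its layer.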
+        fn_modules = []
+        for _ in range(num_blocks):
+            module = lambda: ChunkEncoderLayer(
+                output_size,
+                RelPositionMultiHeadedAttentionChunk(*mult_att_args),
+                PositionwiseFeedForward(*pos_wise_args),
+                PositionwiseFeedForward(*pos_wise_args),
+                CausalConvolution(*conv_mod_args),
+                dropout_rate=dropout_rate,
+            )
+            fn_modules.append(module)
+
+        self.encoders = MultiBlocks(
+            [fn() for fn in fn_modules],
+            output_size,
+        )
+
+        self._output_size = output_size
+
+        self.dynamic_chunk_training = dynamic_chunk_training
+        self.short_chunk_threshold = short_chunk_threshold
+        self.short_chunk_size = short_chunk_size
+        self.left_chunk_size = left_chunk_size
+
+        self.unified_model_training = unified_model_training
+        self.default_chunk_size = default_chunk_size
+        self.jitter_range = jitter_range
+
+        self.time_reduction_factor = time_reduction_factor
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
+        """Return the number of raw audio samples corresponding to a given chunk size,
+        where size is the number of feature frames after subsampling.
+        Args:
+            size: Number of frames after subsampling.
+            hop_length: Frontend's hop length.
+        Returns:
+            : Number of raw samples.
+        """
+        return self.embed.get_size_before_subsampling(size) * hop_length
+
+    def get_encoder_input_size(self, size: int) -> int:
+        """Return the number of feature frames corresponding to a given chunk size,
+        where size is the number of feature frames after subsampling.
+        Args:
+            size: Number of frames after subsampling.
+        Returns:
+            : Number of feature frames before subsampling.
+        """
+        return self.embed.get_size_before_subsampling(size)
+
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset encoder streaming cache.
+        Args:
+            left_context: Number of frames in left context.
+            device: Device to use for the cache tensors.
+        """
+        return self.encoders.reset_streaming_cache(left_context, device)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+           x_len: Encoder output lengths. (B,)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+
+        if self.unified_model_training:
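+            # Unified training: encode the same features twice, once without a chunk
+            # mask (full-utterance path) and once with it (chunked/streaming path),
+            # and return both outputs so a joint loss can be computed.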
+            if self.training:
+                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+            else:
+                chunk_size = self.default_chunk_size
+            x, mask = self.embed(x, mask, chunk_size)
+            pos_enc = self.pos_enc(x)
+            chunk_mask = make_chunk_mask(
+                x.size(1),
+                chunk_size,
+                left_chunk_size=self.left_chunk_size,
+                device=x.device,
+            )
+            x_utt = self.encoders(
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=None,
+            )
+            x_chunk = self.encoders(
+                x,
+                pos_enc,
+                mask,
+                chunk_mask=chunk_mask,
+            )
+
+            olens = mask.eq(0).sum(1)
+            if self.time_reduction_factor > 1:
+                x_utt = x_utt[:,::self.time_reduction_factor,:]
+                x_chunk = x_chunk[:,::self.time_reduction_factor,:]
+                olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+            return x_utt, x_chunk, olens
+
+        elif self.dynamic_chunk_training:
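+            # Dynamic chunk training: sample a random chunk size per batch; large
+            # samples (above short_chunk_threshold * max_len) fall back to
+            # full-utterance encoding, otherwise a short chunk in [1, short_chunk_size].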
+            max_len = x.size(1)
+            if self.training:
+                chunk_size = torch.randint(1, max_len, (1,)).item()
+
+                if chunk_size > (max_len * self.short_chunk_threshold):
+                    chunk_size = max_len
+                else:
+                    chunk_size = (chunk_size % self.short_chunk_size) + 1
+            else:
+                chunk_size = self.default_chunk_size
+
+            x, mask = self.embed(x, mask, chunk_size)
+            pos_enc = self.pos_enc(x)
+
+            chunk_mask = make_chunk_mask(
+                x.size(1),
+                chunk_size,
+                left_chunk_size=self.left_chunk_size,
+                device=x.device,
+            )
+        else:
+            x, mask = self.embed(x, mask, None)
+            pos_enc = self.pos_enc(x)
+            chunk_mask = None
+        x = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=chunk_mask,
+        )
+
+        olens = mask.eq(0).sum(1)
+        if self.time_reduction_factor > 1:
+            x = x[:,::self.time_reduction_factor,:]
+            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+        return x, olens, None
+
+    def full_utt_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+           x_len: Encoder output lengths. (B,)
+        """
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        pos_enc = self.pos_enc(x)
+        x_utt = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=None,
+        )
+
+        if self.time_reduction_factor > 1:
+            x_utt = x_utt[:,::self.time_reduction_factor,:]
+        return x_utt
+
+    def simu_chunk_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+    ) -> torch.Tensor:
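+        """Encode input sequences with a fixed chunk mask (offline simulation of
+        chunk-wise inference).
+        Args:
+            x: Encoder input features. (B, T_in, F)
+            x_len: Encoder input features lengths. (B,)
+            chunk_size: Number of frames per chunk.
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+            x: Encoder outputs. (B, T_out, D_enc)
+        """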
+        short_status, limit_size = check_short_utt(
+            self.embed.subsampling_factor, x.size(1)
+        )
+
+        if short_status:
+            raise TooShortUttError(
+                f"has {x.size(1)} frames and is too short for subsampling "
+                + f"(it needs more than {limit_size} frames), return empty results",
+                x.size(1),
+                limit_size,
+            )
+
+        mask = make_source_mask(x_len)
+
+        x, mask = self.embed(x, mask, chunk_size)
+        pos_enc = self.pos_enc(x)
+        chunk_mask = make_chunk_mask(
+            x.size(1),
+            chunk_size,
+            left_chunk_size=self.left_chunk_size,
+            device=x.device,
+        )
+
+        x = self.encoders(
+            x,
+            pos_enc,
+            mask,
+            chunk_mask=chunk_mask,
+        )
+        olens = mask.eq(0).sum(1)
+        if self.time_reduction_factor > 1:
+            x = x[:,::self.time_reduction_factor,:]
+
+        return x
+
+    def chunk_forward(
+        self,
+        x: torch.Tensor,
+        x_len: torch.Tensor,
+        processed_frames: torch.Tensor,
+        chunk_size: int = 16,
+        left_context: int = 32,
+        right_context: int = 0,
+    ) -> torch.Tensor:
+        """Encode input sequences as chunks.
+        Args:
+            x: Encoder input features. (1, T_in, F)
+            x_len: Encoder input features lengths. (1,)
+            processed_frames: Number of frames already seen.
+            chunk_size: Number of frames in the current chunk.
+            left_context: Number of frames in left context.
+            right_context: Number of frames in right context.
+        Returns:
+           x: Encoder outputs. (B, T_out, D_enc)
+        """
+        mask = make_source_mask(x_len)
+        x, mask = self.embed(x, mask, None)
+
+        if left_context > 0:
+            processed_mask = (
+                torch.arange(left_context, device=x.device)
+                .view(1, left_context)
+                .flip(1)
+            )
+            processed_mask = processed_mask >= processed_frames
+            mask = torch.cat([processed_mask, mask], dim=1)
+        pos_enc = self.pos_enc(x, left_context=left_context)
+        x = self.encoders.chunk_forward(
+            x,
+            pos_enc,
+            mask,
+            chunk_size=chunk_size,
+            left_context=left_context,
+            right_context=right_context,
+        )
+
+        if right_context > 0:
+            x = x[:, 0:-right_context, :]
+
+        if self.time_reduction_factor > 1:
+            x = x[:,::self.time_reduction_factor,:]
+        return x
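+
+
+# Minimal streaming-usage sketch (illustrative only; `encoder`, `feats`, `feats_len`
+# and `iterate_chunks` are hypothetical names, not part of this module):
+#
+#     encoder.reset_streaming_cache(left_context=32, device=feats.device)
+#     processed = torch.tensor(0)
+#     for chunk, chunk_len in iterate_chunks(feats, feats_len, chunk_size=16):
+#         out = encoder.chunk_forward(chunk, chunk_len, processed,
+#                                     chunk_size=16, left_context=32)
+#         processed = processed + out.size(1)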
diff --git a/funasr/models/bat/model.py b/funasr/models/bat/model.py
index 83ce302..d814e31 100644
--- a/funasr/models/bat/model.py
+++ b/funasr/models/bat/model.py
@@ -5,34 +5,23 @@
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
 from packaging.version import parse as V
 from funasr.losses.label_smoothing_loss import (
     LabelSmoothingLoss,  # noqa: H301
 )
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.decoder.rnnt_decoder import RNNTDecoder
-from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.joint_net.joint_network import JointNetwork
+
 from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.transformer.add_sos_eos import add_sos_eos
-from funasr.layers.abs_normalize import AbsNormalize
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.train_utils.device_funcs import force_gatherable
-from funasr.models.base_model import FunASRModel
 
-if V(torch.__version__) >= V("1.6.0"):
-    from torch.cuda.amp import autocast
-else:
-
-    @contextmanager
-    def autocast(enabled=True):
-        yield
+from torch.cuda.amp import autocast
 
 
-class BATModel(FunASRModel):
+
+class BATModel(nn.Module):
     """BATModel module definition.
 
     Args:
@@ -61,18 +50,7 @@
 
     def __init__(
         self,
-        vocab_size: int,
-        token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
-        encoder: AbsEncoder,
-        decoder: RNNTDecoder,
-        joint_network: JointNetwork,
-        att_decoder: Optional[AbsAttDecoder] = None,
-        predictor = None,
-        transducer_weight: float = 1.0,
-        predictor_weight: float = 1.0,
+        
         cif_weight: float = 1.0,
         fastemit_lambda: float = 0.0,
         auxiliary_ctc_weight: float = 0.0,
@@ -89,6 +67,7 @@
         length_normalized_loss: bool = False,
         r_d: int = 5,
         r_u: int = 5,
+        **kwargs,
     ) -> None:
         """Construct an BATModel object."""
         super().__init__()
diff --git a/funasr/models/bici_paraformer/cif_predictor.py b/funasr/models/bici_paraformer/cif_predictor.py
new file mode 100644
index 0000000..67d801c
--- /dev/null
+++ b/funasr/models/bici_paraformer/cif_predictor.py
@@ -0,0 +1,340 @@
+import torch
+from torch import nn
+from torch import Tensor
+import logging
+import numpy as np
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.scama.utils import sequence_mask
+from typing import Optional, Tuple
+
+from funasr.utils.register import register_class
+
+
+class mae_loss(nn.Module):
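+    """Token-count regression loss.
+
+    An L1 loss between the target token count and the predicted token count,
+    normalized by the batch size (or by the total number of tokens when
+    `normalize_length` is True).
+    """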
+
+    def __init__(self, normalize_length=False):
+        super(mae_loss, self).__init__()
+        self.normalize_length = normalize_length
+        self.criterion = torch.nn.L1Loss(reduction='sum')
+
+    def forward(self, token_length, pre_token_length):
+        loss_token_normalizer = token_length.size(0)
+        if self.normalize_length:
+            loss_token_normalizer = token_length.sum().type(torch.float32)
+        loss = self.criterion(token_length, pre_token_length)
+        loss = loss / loss_token_normalizer
+        return loss
+
+
+def cif(hidden, alphas, threshold):
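+    """Continuous Integrate-and-Fire (CIF).
+
+    The per-frame weights `alphas` are accumulated over time; whenever the running
+    integral reaches `threshold`, the weighted sum of the encoder frames collected
+    so far is emitted ("fired") as one token-level acoustic embedding.
+
+    Args:
+        hidden: Encoder outputs. (B, T, D)
+        alphas: Per-frame CIF weights. (B, T)
+        threshold: Firing threshold (typically 1.0).
+    Returns:
+        Token-level acoustic embeddings (B, N_max, D) and the accumulated
+        weights per frame (B, T).
+    """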
+    batch_size, len_time, hidden_size = hidden.size()
+
+    # loop vars
+    integrate = torch.zeros([batch_size], device=hidden.device)
+    frame = torch.zeros([batch_size, hidden_size], device=hidden.device)
+    # intermediate vars along time
+    list_fires = []
+    list_frames = []
+
+    for t in range(len_time):
+        alpha = alphas[:, t]
+        distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate
+
+        integrate += alpha
+        list_fires.append(integrate)
+
+        fire_place = integrate >= threshold
+        integrate = torch.where(fire_place,
+                                integrate - torch.ones([batch_size], device=hidden.device),
+                                integrate)
+        cur = torch.where(fire_place,
+                          distribution_completion,
+                          alpha)
+        remainds = alpha - cur
+
+        frame += cur[:, None] * hidden[:, t, :]
+        list_frames.append(frame)
+        frame = torch.where(fire_place[:, None].repeat(1, hidden_size),
+                            remainds[:, None] * hidden[:, t, :],
+                            frame)
+
+    fires = torch.stack(list_fires, 1)
+    frames = torch.stack(list_frames, 1)
+    list_ls = []
+    len_labels = torch.round(alphas.sum(-1)).int()
+    max_label_len = len_labels.max()
+    for b in range(batch_size):
+        fire = fires[b, :]
+        l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze())
+        pad_l = torch.zeros([max_label_len - l.size(0), hidden_size], device=hidden.device)
+        list_ls.append(torch.cat([l, pad_l], 0))
+    return torch.stack(list_ls, 0), fires
+
+
+def cif_wo_hidden(alphas, threshold):
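+    """Integrate-and-fire accumulation of `alphas` only (no frame embeddings).
+
+    Returns the running integral per frame, which is used to locate firing
+    peaks, e.g. for timestamp prediction.
+    """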
+    batch_size, len_time = alphas.size()
+
+    # loop vars
+    integrate = torch.zeros([batch_size], device=alphas.device)
+    # intermediate vars along time
+    list_fires = []
+
+    for t in range(len_time):
+        alpha = alphas[:, t]
+
+        integrate += alpha
+        list_fires.append(integrate)
+
+        fire_place = integrate >= threshold
+        integrate = torch.where(fire_place,
+                                integrate - torch.ones([batch_size], device=alphas.device)*threshold,
+                                integrate)
+
+    fires = torch.stack(list_fires, 1)
+    return fires
+
+@register_class("predictor_classes", "CifPredictorV3")
+class CifPredictorV3(nn.Module):
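+    """CIF predictor with a second, upsampled weight head.
+
+    Besides the standard CIF head (`alphas`, used to produce token-level acoustic
+    embeddings), a second head predicts upsampled weights (`alphas2`) that are
+    used for finer-grained timestamp estimation.
+    """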
+    def __init__(self,
+                 idim,
+                 l_order,
+                 r_order,
+                 threshold=1.0,
+                 dropout=0.1,
+                 smooth_factor=1.0,
+                 noise_threshold=0,
+                 tail_threshold=0.0,
+                 tf2torch_tensor_name_prefix_torch="predictor",
+                 tf2torch_tensor_name_prefix_tf="seq2seq/cif",
+                 smooth_factor2=1.0,
+                 noise_threshold2=0,
+                 upsample_times=5,
+                 upsample_type="cnn",
+                 use_cif1_cnn=True,
+                 tail_mask=True,
+                 ):
+        super(CifPredictorV3, self).__init__()
+
+        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
+        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
+        self.cif_output = nn.Linear(idim, 1)
+        self.dropout = torch.nn.Dropout(p=dropout)
+        self.threshold = threshold
+        self.smooth_factor = smooth_factor
+        self.noise_threshold = noise_threshold
+        self.tail_threshold = tail_threshold
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+
+        self.upsample_times = upsample_times
+        self.upsample_type = upsample_type
+        self.use_cif1_cnn = use_cif1_cnn
+        if self.upsample_type == 'cnn':
+            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+            self.cif_output2 = nn.Linear(idim, 1)
+        elif self.upsample_type == 'cnn_blstm':
+            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+            self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
+            self.cif_output2 = nn.Linear(idim*2, 1)
+        elif self.upsample_type == 'cnn_attn':
+            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
+            from funasr.models.transformer.encoder import EncoderLayer as TransformerEncoderLayer
+            from funasr.models.transformer.attention import MultiHeadedAttention
+            from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
+            positionwise_layer_args = (
+                idim,
+                idim*2,
+                0.1,
+            )
+            self.self_attn = TransformerEncoderLayer(
+                idim,
+                MultiHeadedAttention(
+                    4, idim, 0.1
+                ),
+                PositionwiseFeedForward(*positionwise_layer_args),
+                0.1,
+                True, #normalize_before,
+                False, #concat_after,
+            )
+            self.cif_output2 = nn.Linear(idim, 1)
+        self.smooth_factor2 = smooth_factor2
+        self.noise_threshold2 = noise_threshold2
+
+    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
+                target_label_length=None):
+        h = hidden
+        context = h.transpose(1, 2)
+        queries = self.pad(context)
+        output = torch.relu(self.cif_conv1d(queries))
+
+        # alphas2 is an extra head for timestamp prediction
+        if not self.use_cif1_cnn:
+            _output = context
+        else:
+            _output = output
+        if self.upsample_type == 'cnn':
+            output2 = self.upsample_cnn(_output)
+            output2 = output2.transpose(1,2)
+        elif self.upsample_type == 'cnn_blstm':
+            output2 = self.upsample_cnn(_output)
+            output2 = output2.transpose(1,2)
+            output2, (_, _) = self.blstm(output2)
+        elif self.upsample_type == 'cnn_attn':
+            output2 = self.upsample_cnn(_output)
+            output2 = output2.transpose(1,2)
+            output2, _ = self.self_attn(output2, mask)
+        alphas2 = torch.sigmoid(self.cif_output2(output2))
+        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
+        # repeat the mask in the T dimension to match the upsampled length
+        if mask is not None:
+            mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
+            mask2 = mask2.unsqueeze(-1)
+            alphas2 = alphas2 * mask2
+        alphas2 = alphas2.squeeze(-1)
+        token_num2 = alphas2.sum(-1)
+
+        output = output.transpose(1, 2)
+
+        output = self.cif_output(output)
+        alphas = torch.sigmoid(output)
+        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+        if mask is not None:
+            mask = mask.transpose(-1, -2).float()
+            alphas = alphas * mask
+        if mask_chunk_predictor is not None:
+            alphas = alphas * mask_chunk_predictor
+        alphas = alphas.squeeze(-1)
+        if mask is not None:
+            mask = mask.squeeze(-1)
+        if target_label_length is not None:
+            target_length = target_label_length
+        elif target_label is not None:
+            target_length = (target_label != ignore_id).float().sum(-1)
+        else:
+            target_length = None
+        token_num = alphas.sum(-1)
+
+        if target_length is not None:
+            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+        elif self.tail_threshold > 0.0:
+            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
+
+        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+        if target_length is None and self.tail_threshold > 0.0:
+            token_num_int = torch.max(token_num).type(torch.int32).item()
+            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
+        return acoustic_embeds, token_num, alphas, cif_peak, token_num2
+
+    def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
+        h = hidden
+        b = hidden.shape[0]
+        context = h.transpose(1, 2)
+        queries = self.pad(context)
+        output = torch.relu(self.cif_conv1d(queries))
+
+        # alphas2 is an extra head for timestamp prediction
+        if not self.use_cif1_cnn:
+            _output = context
+        else:
+            _output = output
+        if self.upsample_type == 'cnn':
+            output2 = self.upsample_cnn(_output)
+            output2 = output2.transpose(1,2)
+        elif self.upsample_type == 'cnn_blstm':
+            output2 = self.upsample_cnn(_output)
+            output2 = output2.transpose(1,2)
+            output2, (_, _) = self.blstm(output2)
+        elif self.upsample_type == 'cnn_attn':
+            output2 = self.upsample_cnn(_output)
+            output2 = output2.transpose(1,2)
+            output2, _ = self.self_attn(output2, mask)
+        alphas2 = torch.sigmoid(self.cif_output2(output2))
+        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
+        # repeat the mask in the T dimension to match the upsampled length
+        if mask is not None:
+            mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
+            mask2 = mask2.unsqueeze(-1)
+            alphas2 = alphas2 * mask2
+        alphas2 = alphas2.squeeze(-1)
+        _token_num = alphas2.sum(-1)
+        if token_num is not None:
+            alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
+        # re-downsample
+        ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
+        ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
+        # upsampled alphas and cif_peak
+        us_alphas = alphas2
+        us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
+        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+
+    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
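+        """Append a tail frame so that residual CIF weight still fires.
+
+        A `tail_threshold` weight is added after the last valid frame (together
+        with a zero hidden frame), so trailing probability mass below the firing
+        threshold still produces a final token.
+        """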
+        b, t, d = hidden.size()
+        tail_threshold = self.tail_threshold
+        if mask is not None:
+            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
+            ones_t = torch.ones_like(zeros_t)
+            mask_1 = torch.cat([mask, zeros_t], dim=1)
+            mask_2 = torch.cat([ones_t, mask], dim=1)
+            mask = mask_2 - mask_1
+            tail_threshold = mask * tail_threshold
+            alphas = torch.cat([alphas, zeros_t], dim=1)
+            alphas = torch.add(alphas, tail_threshold)
+        else:
+            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
+            tail_threshold = torch.reshape(tail_threshold, (1, 1))
+            alphas = torch.cat([alphas, tail_threshold], dim=1)
+        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
+        hidden = torch.cat([hidden, zeros], dim=1)
+        token_num = alphas.sum(dim=-1)
+        token_num_floor = torch.floor(token_num)
+
+        return hidden, alphas, token_num_floor
+
+    def gen_frame_alignments(self,
+                             alphas: torch.Tensor = None,
+                             encoder_sequence_length: torch.Tensor = None):
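+        """Derive frame-level firing positions from the CIF weights.
+
+        For every token index, the frame at which the cumulative sum of `alphas`
+        first reaches that index is marked; the result is a (B, T_max) alignment
+        of fires and its per-utterance count.
+        """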
+        batch_size, maximum_length = alphas.size()
+        int_type = torch.int32
+
+        is_training = self.training
+        if is_training:
+            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
+        else:
+            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
+
+        max_token_num = torch.max(token_num).item()
+
+        alphas_cumsum = torch.cumsum(alphas, dim=1)
+        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
+        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
+
+        index = torch.ones([batch_size, max_token_num], dtype=int_type)
+        index = torch.cumsum(index, dim=1)
+        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
+
+        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
+        index_div_bool_zeros = index_div.eq(0)
+        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
+        index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
+        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
+        index_div_bool_zeros_count *= token_num_mask
+
+        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
+        ones = torch.ones_like(index_div_bool_zeros_count_tile)
+        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
+        ones = torch.cumsum(ones, dim=2)
+        cond = index_div_bool_zeros_count_tile == ones
+        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
+
+        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
+        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
+        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
+        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
+        predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
+            int_type).to(encoder_sequence_length.device)
+        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
+
+        predictor_alignments = index_div_bool_zeros_count_tile_out
+        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
+        return predictor_alignments.detach(), predictor_alignments_length.detach()
diff --git a/funasr/models/bici_paraformer/model.py b/funasr/models/bici_paraformer/model.py
new file mode 100644
index 0000000..23a6985
--- /dev/null
+++ b/funasr/models/bici_paraformer/model.py
@@ -0,0 +1,328 @@
+
+import logging
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import tempfile
+import codecs
+import requests
+import re
+import copy
+import torch
+import torch.nn as nn
+import random
+import numpy as np
+import time
+
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import force_gatherable
+
+from funasr.models.paraformer.search import Hypothesis
+
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.utils.register import register_class, registry_tables
+from funasr.models.ctc.ctc import CTC
+
+from funasr.models.paraformer.model import Paraformer
+
+@register_class("model_classes", "BiCifParaformer")
+class BiCifParaformer(Paraformer):
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+	https://arxiv.org/abs/2206.08317
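+	BiCifParaformer adds a second, upsampled CIF head on top of Paraformer for
+	frame-level timestamp prediction.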
+	"""
+	
+	def __init__(
+		self,
+		*args,
+		**kwargs,
+	):
+		super().__init__(*args, **kwargs)
+
+
+	def _calc_pre2_loss(
+		self,
+		encoder_out: torch.Tensor,
+		encoder_out_lens: torch.Tensor,
+		ys_pad: torch.Tensor,
+		ys_pad_lens: torch.Tensor,
+	):
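+		"""Compute the token-count (MAE) loss for the second, upsampled CIF head.
+
+		`pre_token_length2` is the token count predicted from the upsampled weights;
+		it is supervised with the same target length as the main CIF head.
+		"""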
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		if self.predictor_bias == 1:
+			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+			ys_pad_lens = ys_pad_lens + self.predictor_bias
+		_, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
+		
+		# loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+		loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
+		
+		return loss_pre2
+	
+	
+	def _calc_att_loss(
+		self,
+		encoder_out: torch.Tensor,
+		encoder_out_lens: torch.Tensor,
+		ys_pad: torch.Tensor,
+		ys_pad_lens: torch.Tensor,
+	):
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		if self.predictor_bias == 1:
+			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+			ys_pad_lens = ys_pad_lens + self.predictor_bias
+		pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
+		                                                                             encoder_out_mask,
+		                                                                             ignore_id=self.ignore_id)
+		
+		# 0. sampler
+		decoder_out_1st = None
+		if self.sampling_ratio > 0.0:
+			sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+			                                               pre_acoustic_embeds)
+		else:
+			sematic_embeds = pre_acoustic_embeds
+		
+		# 1. Forward decoder
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+		)
+		decoder_out, _ = decoder_outs[0], decoder_outs[1]
+		
+		if decoder_out_1st is None:
+			decoder_out_1st = decoder_out
+		# 2. Compute attention loss
+		loss_att = self.criterion_att(decoder_out, ys_pad)
+		acc_att = th_accuracy(
+			decoder_out_1st.view(-1, self.vocab_size),
+			ys_pad,
+			ignore_label=self.ignore_id,
+		)
+		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+		
+		# Compute cer/wer using attention-decoder
+		if self.training or self.error_calculator is None:
+			cer_att, wer_att = None, None
+		else:
+			ys_hat = decoder_out_1st.argmax(dim=-1)
+			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+		
+		return loss_att, acc_att, cer_att, wer_att, loss_pre
+
+
+	def calc_predictor(self, encoder_out, encoder_out_lens):
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
+		                                                                                                  None,
+		                                                                                                  encoder_out_mask,
+		                                                                                                  ignore_id=self.ignore_id)
+		return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+
+	def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
+		                                                                                    encoder_out_mask,
+		                                                                                    token_num)
+		return ds_alphas, ds_cif_peak, us_alphas, us_peaks
+	
+	
+	def forward(
+		self,
+		speech: torch.Tensor,
+		speech_lengths: torch.Tensor,
+		text: torch.Tensor,
+		text_lengths: torch.Tensor,
+		**kwargs,
+	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+		"""Frontend + Encoder + Decoder + Calc loss
+		Args:
+				speech: (Batch, Length, ...)
+				speech_lengths: (Batch, )
+				text: (Batch, Length)
+				text_lengths: (Batch,)
+		"""
+		if len(text_lengths.size()) > 1:
+			text_lengths = text_lengths[:, 0]
+		if len(speech_lengths.size()) > 1:
+			speech_lengths = speech_lengths[:, 0]
+		
+		batch_size = speech.shape[0]
+		
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+
+		loss_ctc, cer_ctc = None, None
+		loss_pre = None
+		stats = dict()
+		
+		# decoder: CTC branch
+		if self.ctc_weight != 0.0:
+			loss_ctc, cer_ctc = self._calc_ctc_loss(
+				encoder_out, encoder_out_lens, text, text_lengths
+			)
+			
+			# Collect CTC branch stats
+			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+			stats["cer_ctc"] = cer_ctc
+
+
+		# decoder: Attention decoder branch
+		loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+			encoder_out, encoder_out_lens, text, text_lengths
+		)
+		
+		loss_pre2 = self._calc_pre2_loss(
+			encoder_out, encoder_out_lens, text, text_lengths
+		)
+		
+		# 3. CTC-Att loss definition
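+		# Total loss = (weighted CTC term, if enabled) + attention loss
+		# + predictor_weight * loss_pre (main CIF head)
+		# + 0.5 * predictor_weight * loss_pre2 (second, upsampled CIF head).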
+		if self.ctc_weight == 0.0:
+			loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+		else:
+			loss = self.ctc_weight * loss_ctc + (
+				1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+		
+		# Collect Attn branch stats
+		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+		stats["acc"] = acc_att
+		stats["cer"] = cer_att
+		stats["wer"] = wer_att
+		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+		stats["loss_pre2"] = loss_pre2.detach().cpu()
+		
+		stats["loss"] = torch.clone(loss.detach())
+		
+		# force_gatherable: to-device and to-tensor if scalar for DataParallel
+		if self.length_normalized_loss:
+			batch_size = int((text_lengths + self.predictor_bias).sum())
+		
+		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+		return loss, stats, weight
+	
+	def generate(self,
+	             data_in: list,
+	             data_lengths: list = None,
+	             key: list = None,
+	             tokenizer=None,
+	             frontend=None,
+	             **kwargs,
+	             ):
+		
+		# init beamsearch
+		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		if self.beam_search is None and (is_use_lm or is_use_ctc):
+			logging.info("enable beam_search")
+			self.init_beam_search(**kwargs)
+			self.nbest = kwargs.get("nbest", 1)
+		
+		meta_data = {}
+		# extract fbank feats
+		time1 = time.perf_counter()
+		audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+		time2 = time.perf_counter()
+		meta_data["load_data"] = f"{time2 - time1:0.3f}"
+		speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+		                                       frontend=self.frontend)
+		time3 = time.perf_counter()
+		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+		meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+		
+		speech = speech.to(device=kwargs["device"])
+		speech_lengths = speech_lengths.to(device=kwargs["device"])
+		
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+		
+		# predictor
+		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+		                                                                predictor_outs[2], predictor_outs[3]
+		pre_token_length = pre_token_length.round().long()
+		if torch.max(pre_token_length) < 1:
+			return [], meta_data
+		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+		                                               pre_token_length)
+		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+		
+		# BiCifParaformer, test no bias cif2
+
+		_, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
+			                                                                    pre_token_length)
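+		# us_alphas / us_peaks are the upsampled frame-level CIF weights and firing
+		# peaks; ts_prediction_lfr6_standard maps each peak to a token timestamp below.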
+		
+		results = []
+		b, n, d = decoder_out.size()
+		for i in range(b):
+			x = encoder_out[i, :encoder_out_lens[i], :]
+			am_scores = decoder_out[i, :pre_token_length[i], :]
+			if self.beam_search is not None:
+				nbest_hyps = self.beam_search(
+					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+					minlenratio=kwargs.get("minlenratio", 0.0)
+				)
+				
+				nbest_hyps = nbest_hyps[: self.nbest]
+			else:
+				
+				yseq = am_scores.argmax(dim=-1)
+				score = am_scores.max(dim=-1)[0]
+				score = torch.sum(score, dim=-1)
+				# prepend sos and append eos so the hypothesis format matches the beam-search output
+				yseq = torch.tensor(
+					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+				)
+				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+			for nbest_idx, hyp in enumerate(nbest_hyps):
+				ibest_writer = None
+				if ibest_writer is None and kwargs.get("output_dir") is not None:
+					writer = DatadirWriter(kwargs.get("output_dir"))
+					ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+				# remove sos/eos and get results
+				last_pos = -1
+				if isinstance(hyp.yseq, list):
+					token_int = hyp.yseq[1:last_pos]
+				else:
+					token_int = hyp.yseq[1:last_pos].tolist()
+				
+				# remove sos/eos and blank symbol ids
+				token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+				
+				# Change integer-ids to tokens
+				token = tokenizer.ids2tokens(token_int)
+				text = tokenizer.tokens2text(token)
+				
+				_, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
+				                                           us_peaks[i][:encoder_out_lens[i] * 3],
+				                                           copy.copy(token),
+				                                           vad_offset=kwargs.get("begin_time", 0))
+				
+				text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token, timestamp)
+				
+				result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
+				            "time_stamp_postprocessed": time_stamp_postprocessed,
+				            "word_lists": word_lists
+				            }
+				results.append(result_i)
+				
+				if ibest_writer is not None:
+					ibest_writer["token"][key[i]] = " ".join(token)
+					ibest_writer["text"][key[i]] = text
+					ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+					
+		
+		return results, meta_data
diff --git a/funasr/models/branchformer/branchformer_encoder.py b/funasr/models/branchformer/encoder.py
similarity index 98%
rename from funasr/models/branchformer/branchformer_encoder.py
rename to funasr/models/branchformer/encoder.py
index 0f037c8..11b6429 100644
--- a/funasr/models/branchformer/branchformer_encoder.py
+++ b/funasr/models/branchformer/encoder.py
@@ -16,8 +16,8 @@
 
 import numpy
 import torch
+import torch.nn as nn
 
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.branchformer.cgmlp import ConvolutionalGatingMLP
 from funasr.models.branchformer.fastformer import FastSelfAttention
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
@@ -33,8 +33,8 @@
     ScaledPositionalEncoding,
 )
 from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import (
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import (
     Conv2dSubsampling,
     Conv2dSubsampling2,
     Conv2dSubsampling6,
@@ -43,6 +43,7 @@
     check_short_utt,
 )
 
+from funasr.utils.register import register_class
 
 class BranchformerEncoderLayer(torch.nn.Module):
     """Branchformer encoder layer module.
@@ -290,8 +291,8 @@
 
         return x, mask
 
-
-class BranchformerEncoder(AbsEncoder):
+@register_class("encoder_classes", "BranchformerEncoder")
+class BranchformerEncoder(nn.Module):
     """Branchformer encoder module."""
 
     def __init__(
diff --git a/funasr/models/branchformer/model.py b/funasr/models/branchformer/model.py
index 5cb2af7..a14b407 100644
--- a/funasr/models/branchformer/model.py
+++ b/funasr/models/branchformer/model.py
@@ -1,57 +1,9 @@
 import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
-	LabelSmoothingLoss,  # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
-
-from funasr.models.model_class_factory import *
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-	from torch.cuda.amp import autocast
-else:
-	# Nothing to do if torch<1.6.0
-	@contextmanager
-	def autocast(enabled=True):
-		yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
 
 from funasr.models.transformer.model import Transformer
+from funasr.utils.register import register_class
 
+@register_class("model_classes", "Branchformer")
 class Branchformer(Transformer):
 	"""CTC-attention hybrid Encoder-Decoder model"""
 
diff --git a/funasr/models/cnn/DTDNN.py b/funasr/models/cnn/DTDNN.py
index 3de0b1e..02fcfdf 100644
--- a/funasr/models/cnn/DTDNN.py
+++ b/funasr/models/cnn/DTDNN.py
@@ -6,7 +6,7 @@
 import torch.nn.functional as F
 from torch import nn
 
-from funasr.modules.cnn.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
+from funasr.models.cnn.layers import DenseLayer, StatsPool, TDNNLayer, CAMDenseTDNNBlock, TransitLayer, \
     BasicResBlock, get_nonlinear
 
 
diff --git a/funasr/models/cnn/ResNet.py b/funasr/models/cnn/ResNet.py
index 54c3901..b846e9e 100644
--- a/funasr/models/cnn/ResNet.py
+++ b/funasr/models/cnn/ResNet.py
@@ -16,7 +16,7 @@
 import torch.nn.functional as F
 
 import funasr.models.sond.pooling.pooling_layers as pooling_layers
-from funasr.modules.cnn.fusion import AFF
+from funasr.models.cnn.fusion import AFF
 
 
 class ReLU(nn.Hardtanh):
diff --git a/funasr/models/cnn/ResNet_aug.py b/funasr/models/cnn/ResNet_aug.py
index 6b03c67..95416ef 100644
--- a/funasr/models/cnn/ResNet_aug.py
+++ b/funasr/models/cnn/ResNet_aug.py
@@ -16,7 +16,7 @@
 import torch.nn.functional as F
 
 import funasr.models.sond.pooling.pooling_layers as pooling_layers
-from funasr.modules.cnn.fusion import AFF
+from funasr.models.cnn.fusion import AFF
 
 
 class ReLU(nn.Hardtanh):
diff --git a/funasr/models/conformer/conformer_encoder.py b/funasr/models/conformer/conformer_encoder.py
deleted file mode 100644
index 4b8cfe6..0000000
--- a/funasr/models/conformer/conformer_encoder.py
+++ /dev/null
@@ -1,1280 +0,0 @@
-# Copyright 2020 Tomoki Hayashi
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Conformer encoder definition."""
-
-import logging
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-from typing import Dict
-
-import torch
-from torch import nn
-
-from funasr.models.ctc import CTC
-from funasr.models.transformer.attention import (
-    MultiHeadedAttention,  # noqa: H301
-    RelPositionMultiHeadedAttention,  # noqa: H301
-    RelPositionMultiHeadedAttentionChunk,
-    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
-)
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.transformer.embedding import (
-    PositionalEncoding,  # noqa: H301
-    ScaledPositionalEncoding,  # noqa: H301
-    RelPositionalEncoding,  # noqa: H301
-    LegacyRelPositionalEncoding,  # noqa: H301
-    StreamingRelPositionalEncoding,
-)
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
-from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
-from funasr.models.transformer.utils.nets_utils import get_activation
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.transformer.utils.nets_utils import (
-    TooShortUttError,
-    check_short_utt,
-    make_chunk_mask,
-    make_source_mask,
-)
-from funasr.models.transformer.positionwise_feed_forward import (
-    PositionwiseFeedForward,  # noqa: H301
-)
-from funasr.models.transformer.repeat import repeat, MultiBlocks
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
-from funasr.models.transformer.subsampling import Conv2dSubsamplingPad
-from funasr.models.transformer.subsampling import StreamingConvInput
-
-class ConvolutionModule(nn.Module):
-    """ConvolutionModule in Conformer model.
-
-    Args:
-        channels (int): The number of channels of conv layers.
-        kernel_size (int): Kernerl size of conv layers.
-
-    """
-
-    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
-        """Construct an ConvolutionModule object."""
-        super(ConvolutionModule, self).__init__()
-        # kernerl_size should be a odd number for 'SAME' padding
-        assert (kernel_size - 1) % 2 == 0
-
-        self.pointwise_conv1 = nn.Conv1d(
-            channels,
-            2 * channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            bias=bias,
-        )
-        self.depthwise_conv = nn.Conv1d(
-            channels,
-            channels,
-            kernel_size,
-            stride=1,
-            padding=(kernel_size - 1) // 2,
-            groups=channels,
-            bias=bias,
-        )
-        self.norm = nn.BatchNorm1d(channels)
-        self.pointwise_conv2 = nn.Conv1d(
-            channels,
-            channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-            bias=bias,
-        )
-        self.activation = activation
-
-    def forward(self, x):
-        """Compute convolution module.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, channels).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, channels).
-
-        """
-        # exchange the temporal dimension and the feature dimension
-        x = x.transpose(1, 2)
-
-        # GLU mechanism
-        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
-        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
-
-        # 1D Depthwise Conv
-        x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
-
-        x = self.pointwise_conv2(x)
-
-        return x.transpose(1, 2)
-
-
-class EncoderLayer(nn.Module):
-    """Encoder layer module.
-
-    Args:
-        size (int): Input dimension.
-        self_attn (torch.nn.Module): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
-            can be used as the argument.
-        feed_forward (torch.nn.Module): Feed-forward module instance.
-            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
-            can be used as the argument.
-        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
-            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
-            can be used as the argument.
-        conv_module (torch.nn.Module): Convolution module instance.
-            `ConvlutionModule` instance can be used as the argument.
-        dropout_rate (float): Dropout rate.
-        normalize_before (bool): Whether to use layer_norm before the first block.
-        concat_after (bool): Whether to concat attention layer's input and output.
-            if True, additional linear will be applied.
-            i.e. x -> x + linear(concat(x, att(x)))
-            if False, no additional linear will be applied. i.e. x -> x + att(x)
-        stochastic_depth_rate (float): Proability to skip this layer.
-            During training, the layer may skip residual computation and return input
-            as-is with given probability.
-    """
-
-    def __init__(
-            self,
-            size,
-            self_attn,
-            feed_forward,
-            feed_forward_macaron,
-            conv_module,
-            dropout_rate,
-            normalize_before=True,
-            concat_after=False,
-            stochastic_depth_rate=0.0,
-    ):
-        """Construct an EncoderLayer object."""
-        super(EncoderLayer, self).__init__()
-        self.self_attn = self_attn
-        self.feed_forward = feed_forward
-        self.feed_forward_macaron = feed_forward_macaron
-        self.conv_module = conv_module
-        self.norm_ff = LayerNorm(size)  # for the FNN module
-        self.norm_mha = LayerNorm(size)  # for the MHA module
-        if feed_forward_macaron is not None:
-            self.norm_ff_macaron = LayerNorm(size)
-            self.ff_scale = 0.5
-        else:
-            self.ff_scale = 1.0
-        if self.conv_module is not None:
-            self.norm_conv = LayerNorm(size)  # for the CNN module
-            self.norm_final = LayerNorm(size)  # for the final output of the block
-        self.dropout = nn.Dropout(dropout_rate)
-        self.size = size
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        if self.concat_after:
-            self.concat_linear = nn.Linear(size + size, size)
-        self.stochastic_depth_rate = stochastic_depth_rate
-
-    def forward(self, x_input, mask, cache=None):
-        """Compute encoded features.
-
-        Args:
-            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
-                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
-                - w/o pos emb: Tensor (#batch, time, size).
-            mask (torch.Tensor): Mask tensor for the input (#batch, time).
-            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, size).
-            torch.Tensor: Mask tensor (#batch, time).
-
-        """
-        if isinstance(x_input, tuple):
-            x, pos_emb = x_input[0], x_input[1]
-        else:
-            x, pos_emb = x_input, None
-
-        skip_layer = False
-        # with stochastic depth, residual connection `x + f(x)` becomes
-        # `x <- x + 1 / (1 - p) * f(x)` at training time.
-        stoch_layer_coeff = 1.0
-        if self.training and self.stochastic_depth_rate > 0:
-            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
-            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
-        if skip_layer:
-            if cache is not None:
-                x = torch.cat([cache, x], dim=1)
-            if pos_emb is not None:
-                return (x, pos_emb), mask
-            return x, mask
-
-        # whether to use macaron style
-        if self.feed_forward_macaron is not None:
-            residual = x
-            if self.normalize_before:
-                x = self.norm_ff_macaron(x)
-            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
-                self.feed_forward_macaron(x)
-            )
-            if not self.normalize_before:
-                x = self.norm_ff_macaron(x)
-
-        # multi-headed self-attention module
-        residual = x
-        if self.normalize_before:
-            x = self.norm_mha(x)
-
-        if cache is None:
-            x_q = x
-        else:
-            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
-            x_q = x[:, -1:, :]
-            residual = residual[:, -1:, :]
-            mask = None if mask is None else mask[:, -1:, :]
-
-        if pos_emb is not None:
-            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
-        else:
-            x_att = self.self_attn(x_q, x, x, mask)
-
-        if self.concat_after:
-            x_concat = torch.cat((x, x_att), dim=-1)
-            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
-        else:
-            x = residual + stoch_layer_coeff * self.dropout(x_att)
-        if not self.normalize_before:
-            x = self.norm_mha(x)
-
-        # convolution module
-        if self.conv_module is not None:
-            residual = x
-            if self.normalize_before:
-                x = self.norm_conv(x)
-            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
-            if not self.normalize_before:
-                x = self.norm_conv(x)
-
-        # feed forward module
-        residual = x
-        if self.normalize_before:
-            x = self.norm_ff(x)
-        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
-            self.feed_forward(x)
-        )
-        if not self.normalize_before:
-            x = self.norm_ff(x)
-
-        if self.conv_module is not None:
-            x = self.norm_final(x)
-
-        if cache is not None:
-            x = torch.cat([cache, x], dim=1)
-
-        if pos_emb is not None:
-            return (x, pos_emb), mask
-
-        return x, mask
-
-class ChunkEncoderLayer(torch.nn.Module):
-    """Chunk Conformer module definition.
-    Args:
-        block_size: Input/output size.
-        self_att: Self-attention module instance.
-        feed_forward: Feed-forward module instance.
-        feed_forward_macaron: Feed-forward module instance for macaron network.
-        conv_mod: Convolution module instance.
-        norm_class: Normalization module class.
-        norm_args: Normalization module arguments.
-        dropout_rate: Dropout rate.
-    """
-
-    def __init__(
-        self,
-        block_size: int,
-        self_att: torch.nn.Module,
-        feed_forward: torch.nn.Module,
-        feed_forward_macaron: torch.nn.Module,
-        conv_mod: torch.nn.Module,
-        norm_class: torch.nn.Module = LayerNorm,
-        norm_args: Dict = {},
-        dropout_rate: float = 0.0,
-    ) -> None:
-        """Construct a Conformer object."""
-        super().__init__()
-
-        self.self_att = self_att
-
-        self.feed_forward = feed_forward
-        self.feed_forward_macaron = feed_forward_macaron
-        self.feed_forward_scale = 0.5
-
-        self.conv_mod = conv_mod
-
-        self.norm_feed_forward = norm_class(block_size, **norm_args)
-        self.norm_self_att = norm_class(block_size, **norm_args)
-
-        self.norm_macaron = norm_class(block_size, **norm_args)
-        self.norm_conv = norm_class(block_size, **norm_args)
-        self.norm_final = norm_class(block_size, **norm_args)
-
-        self.dropout = torch.nn.Dropout(dropout_rate)
-
-        self.block_size = block_size
-        self.cache = None
-
-    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
-        """Initialize/Reset self-attention and convolution modules cache for streaming.
-        Args:
-            left_context: Number of left frames during chunk-by-chunk inference.
-            device: Device to use for cache tensor.
-        """
-        self.cache = [
-            torch.zeros(
-                (1, left_context, self.block_size),
-                device=device,
-            ),
-            torch.zeros(
-                (
-                    1,
-                    self.block_size,
-                    self.conv_mod.kernel_size - 1,
-                ),
-                device=device,
-            ),
-        ]
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Encode input sequences.
-        Args:
-            x: Conformer input sequences. (B, T, D_block)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-            mask: Source mask. (B, T)
-            chunk_mask: Chunk mask. (T_2, T_2)
-        Returns:
-            x: Conformer output sequences. (B, T, D_block)
-            mask: Source mask. (B, T)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-        """
-        residual = x
-
-        x = self.norm_macaron(x)
-        x = residual + self.feed_forward_scale * self.dropout(
-            self.feed_forward_macaron(x)
-        )
-
-        residual = x
-        x = self.norm_self_att(x)
-        x_q = x
-        x = residual + self.dropout(
-            self.self_att(
-                x_q,
-                x,
-                x,
-                pos_enc,
-                mask,
-                chunk_mask=chunk_mask,
-            )
-        )
-
-        residual = x
-
-        x = self.norm_conv(x)
-        x, _ = self.conv_mod(x)
-        x = residual + self.dropout(x)
-        residual = x
-
-        x = self.norm_feed_forward(x)
-        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
-
-        x = self.norm_final(x)
-        return x, mask, pos_enc
-
-    def chunk_forward(
-        self,
-        x: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_size: int = 16,
-        left_context: int = 0,
-        right_context: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Encode chunk of input sequence.
-        Args:
-            x: Conformer input sequences. (B, T, D_block)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-            mask: Source mask. (B, T_2)
-            left_context: Number of frames in left context.
-            right_context: Number of frames in right context.
-        Returns:
-            x: Conformer output sequences. (B, T, D_block)
-            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
-        """
-        residual = x
-
-        x = self.norm_macaron(x)
-        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
-
-        residual = x
-        x = self.norm_self_att(x)
-        if left_context > 0:
-            key = torch.cat([self.cache[0], x], dim=1)
-        else:
-            key = x
-        val = key
-
-        if right_context > 0:
-            att_cache = key[:, -(left_context + right_context) : -right_context, :]
-        else:
-            att_cache = key[:, -left_context:, :]
-        x = residual + self.self_att(
-            x,
-            key,
-            val,
-            pos_enc,
-            mask,
-            left_context=left_context,
-        )
-
-        residual = x
-        x = self.norm_conv(x)
-        x, conv_cache = self.conv_mod(
-            x, cache=self.cache[1], right_context=right_context
-        )
-        x = residual + x
-        residual = x
-
-        x = self.norm_feed_forward(x)
-        x = residual + self.feed_forward_scale * self.feed_forward(x)
-
-        x = self.norm_final(x)
-        self.cache = [att_cache, conv_cache]
-
-        return x, pos_enc
-
-
-class ConformerEncoder(AbsEncoder):
-    """Conformer encoder module.
-
-    Args:
-        input_size (int): Input dimension.
-        output_size (int): Dimension of attention.
-        attention_heads (int): The number of heads of multi head attention.
-        linear_units (int): The number of units of position-wise feed forward.
-        num_blocks (int): The number of decoder blocks.
-        dropout_rate (float): Dropout rate.
-        attention_dropout_rate (float): Dropout rate in attention.
-        positional_dropout_rate (float): Dropout rate after adding positional encoding.
-        input_layer (Union[str, torch.nn.Module]): Input layer type.
-        normalize_before (bool): Whether to use layer_norm before the first block.
-        concat_after (bool): Whether to concat attention layer's input and output.
-            If True, additional linear will be applied.
-            i.e. x -> x + linear(concat(x, att(x)))
-            If False, no additional linear will be applied. i.e. x -> x + att(x)
-        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
-        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
-        rel_pos_type (str): Whether to use the latest relative positional encoding or
-            the legacy one. The legacy relative positional encoding will be deprecated
-            in the future. More Details can be found in
-            https://github.com/espnet/espnet/pull/2816.
-        encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
-        encoder_attn_layer_type (str): Encoder attention layer type.
-        activation_type (str): Encoder activation function type.
-        macaron_style (bool): Whether to use macaron style for positionwise layer.
-        use_cnn_module (bool): Whether to use convolution module.
-        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
-        cnn_module_kernel (int): Kernerl size of convolution module.
-        padding_idx (int): Padding idx for input_layer=embed.
-
-    """
-
-    def __init__(
-            self,
-            input_size: int,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.0,
-            input_layer: str = "conv2d",
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            positionwise_layer_type: str = "linear",
-            positionwise_conv_kernel_size: int = 3,
-            macaron_style: bool = False,
-            rel_pos_type: str = "legacy",
-            pos_enc_layer_type: str = "rel_pos",
-            selfattention_layer_type: str = "rel_selfattn",
-            activation_type: str = "swish",
-            use_cnn_module: bool = True,
-            zero_triu: bool = False,
-            cnn_module_kernel: int = 31,
-            padding_idx: int = -1,
-            interctc_layer_idx: List[int] = [],
-            interctc_use_conditioning: bool = False,
-            stochastic_depth_rate: Union[float, List[float]] = 0.0,
-    ):
-        super().__init__()
-        self._output_size = output_size
-
-        if rel_pos_type == "legacy":
-            if pos_enc_layer_type == "rel_pos":
-                pos_enc_layer_type = "legacy_rel_pos"
-            if selfattention_layer_type == "rel_selfattn":
-                selfattention_layer_type = "legacy_rel_selfattn"
-        elif rel_pos_type == "latest":
-            assert selfattention_layer_type != "legacy_rel_selfattn"
-            assert pos_enc_layer_type != "legacy_rel_pos"
-        else:
-            raise ValueError("unknown rel_pos_type: " + rel_pos_type)
-
-        activation = get_activation(activation_type)
-        if pos_enc_layer_type == "abs_pos":
-            pos_enc_class = PositionalEncoding
-        elif pos_enc_layer_type == "scaled_abs_pos":
-            pos_enc_class = ScaledPositionalEncoding
-        elif pos_enc_layer_type == "rel_pos":
-            assert selfattention_layer_type == "rel_selfattn"
-            pos_enc_class = RelPositionalEncoding
-        elif pos_enc_layer_type == "legacy_rel_pos":
-            assert selfattention_layer_type == "legacy_rel_selfattn"
-            pos_enc_class = LegacyRelPositionalEncoding
-            logging.warning(
-                "Using legacy_rel_pos and it will be deprecated in the future."
-            )
-        else:
-            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
-
-        if input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(input_size, output_size),
-                torch.nn.LayerNorm(output_size),
-                torch.nn.Dropout(dropout_rate),
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2d":
-            self.embed = Conv2dSubsampling(
-                input_size,
-                output_size,
-                dropout_rate,
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2dpad":
-            self.embed = Conv2dSubsamplingPad(
-                input_size,
-                output_size,
-                dropout_rate,
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2d2":
-            self.embed = Conv2dSubsampling2(
-                input_size,
-                output_size,
-                dropout_rate,
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2d6":
-            self.embed = Conv2dSubsampling6(
-                input_size,
-                output_size,
-                dropout_rate,
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2d8":
-            self.embed = Conv2dSubsampling8(
-                input_size,
-                output_size,
-                dropout_rate,
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif isinstance(input_layer, torch.nn.Module):
-            self.embed = torch.nn.Sequential(
-                input_layer,
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer is None:
-            self.embed = torch.nn.Sequential(
-                pos_enc_class(output_size, positional_dropout_rate)
-            )
-        else:
-            raise ValueError("unknown input_layer: " + input_layer)
-        self.normalize_before = normalize_before
-        if positionwise_layer_type == "linear":
-            positionwise_layer = PositionwiseFeedForward
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                dropout_rate,
-                activation,
-            )
-        elif positionwise_layer_type == "conv1d":
-            positionwise_layer = MultiLayeredConv1d
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d-linear":
-            positionwise_layer = Conv1dLinear
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        else:
-            raise NotImplementedError("Support only linear or conv1d.")
-
-        if selfattention_layer_type == "selfattn":
-            encoder_selfattn_layer = MultiHeadedAttention
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                attention_dropout_rate,
-            )
-        elif selfattention_layer_type == "legacy_rel_selfattn":
-            assert pos_enc_layer_type == "legacy_rel_pos"
-            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                attention_dropout_rate,
-            )
-            logging.warning(
-                "Using legacy_rel_selfattn and it will be deprecated in the future."
-            )
-        elif selfattention_layer_type == "rel_selfattn":
-            assert pos_enc_layer_type == "rel_pos"
-            encoder_selfattn_layer = RelPositionMultiHeadedAttention
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                attention_dropout_rate,
-                zero_triu,
-            )
-        else:
-            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
-
-        convolution_layer = ConvolutionModule
-        convolution_layer_args = (output_size, cnn_module_kernel, activation)
-
-        if isinstance(stochastic_depth_rate, float):
-            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
-
-        if len(stochastic_depth_rate) != num_blocks:
-            raise ValueError(
-                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
-                f"should be equal to num_blocks ({num_blocks})"
-            )
-
-        self.encoders = repeat(
-            num_blocks,
-            lambda lnum: EncoderLayer(
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args),
-                positionwise_layer(*positionwise_layer_args),
-                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
-                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
-                dropout_rate,
-                normalize_before,
-                concat_after,
-                stochastic_depth_rate[lnum],
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-
-        self.interctc_layer_idx = interctc_layer_idx
-        if len(interctc_layer_idx) > 0:
-            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
-        self.interctc_use_conditioning = interctc_use_conditioning
-        self.conditioning_layer = None
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-            prev_states: torch.Tensor = None,
-            ctc: CTC = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Calculate forward propagation.
-
-        Args:
-            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
-            ilens (torch.Tensor): Input length (#batch).
-            prev_states (torch.Tensor): Not to be used now.
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, L, output_size).
-            torch.Tensor: Output length (#batch).
-            torch.Tensor: Not to be used now.
-
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-
-        if (
-                isinstance(self.embed, Conv2dSubsampling)
-                or isinstance(self.embed, Conv2dSubsampling2)
-                or isinstance(self.embed, Conv2dSubsampling6)
-                or isinstance(self.embed, Conv2dSubsampling8)
-                or isinstance(self.embed, Conv2dSubsamplingPad)
-        ):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
-                    + f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-
-        intermediate_outs = []
-        if len(self.interctc_layer_idx) == 0:
-            xs_pad, masks = self.encoders(xs_pad, masks)
-        else:
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                xs_pad, masks = encoder_layer(xs_pad, masks)
-
-                if layer_idx + 1 in self.interctc_layer_idx:
-                    encoder_out = xs_pad
-                    if isinstance(encoder_out, tuple):
-                        encoder_out = encoder_out[0]
-
-                    # intermediate outputs are also normalized
-                    if self.normalize_before:
-                        encoder_out = self.after_norm(encoder_out)
-
-                    intermediate_outs.append((layer_idx + 1, encoder_out))
-
-                    if self.interctc_use_conditioning:
-                        ctc_out = ctc.softmax(encoder_out)
-
-                        if isinstance(xs_pad, tuple):
-                            x, pos_emb = xs_pad
-                            x = x + self.conditioning_layer(ctc_out)
-                            xs_pad = (x, pos_emb)
-                        else:
-                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
-        if isinstance(xs_pad, tuple):
-            xs_pad = xs_pad[0]
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        olens = masks.squeeze(1).sum(1)
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), olens, None
-        return xs_pad, olens, None
-
-
-class CausalConvolution(torch.nn.Module):
-    """ConformerConvolution module definition.
-    Args:
-        channels: The number of channels.
-        kernel_size: Size of the convolving kernel.
-        activation: Type of activation function.
-        norm_args: Normalization module arguments.
-        causal: Whether to use causal convolution (set to True if streaming).
-    """
-
-    def __init__(
-        self,
-        channels: int,
-        kernel_size: int,
-        activation: torch.nn.Module = torch.nn.ReLU(),
-        norm_args: Dict = {},
-        causal: bool = False,
-    ) -> None:
-        """Construct an ConformerConvolution object."""
-        super().__init__()
-
-        assert (kernel_size - 1) % 2 == 0
-
-        self.kernel_size = kernel_size
-
-        self.pointwise_conv1 = torch.nn.Conv1d(
-            channels,
-            2 * channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-        )
-
-        if causal:
-            self.lorder = kernel_size - 1
-            padding = 0
-        else:
-            self.lorder = 0
-            padding = (kernel_size - 1) // 2
-
-        self.depthwise_conv = torch.nn.Conv1d(
-            channels,
-            channels,
-            kernel_size,
-            stride=1,
-            padding=padding,
-            groups=channels,
-        )
-        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
-        self.pointwise_conv2 = torch.nn.Conv1d(
-            channels,
-            channels,
-            kernel_size=1,
-            stride=1,
-            padding=0,
-        )
-
-        self.activation = activation
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        cache: Optional[torch.Tensor] = None,
-        right_context: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Compute convolution module.
-        Args:
-            x: ConformerConvolution input sequences. (B, T, D_hidden)
-            cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
-            right_context: Number of frames in right context.
-        Returns:
-            x: ConformerConvolution output sequences. (B, T, D_hidden)
-            cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
-        """
-        x = self.pointwise_conv1(x.transpose(1, 2))
-        x = torch.nn.functional.glu(x, dim=1)
-
-        if self.lorder > 0:
-            if cache is None:
-                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
-            else:
-                x = torch.cat([cache, x], dim=2)
-
-                if right_context > 0:
-                    cache = x[:, :, -(self.lorder + right_context) : -right_context]
-                else:
-                    cache = x[:, :, -self.lorder :]
-
-        x = self.depthwise_conv(x)
-        x = self.activation(self.norm(x))
-
-        x = self.pointwise_conv2(x).transpose(1, 2)
-
-        return x, cache
-
-class ConformerChunkEncoder(AbsEncoder):
-    """Encoder module definition.
-    Args:
-        input_size: Input size.
-        body_conf: Encoder body configuration.
-        input_conf: Encoder input configuration.
-        main_conf: Encoder main configuration.
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        embed_vgg_like: bool = False,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        positionwise_layer_type: str = "linear",
-        positionwise_conv_kernel_size: int = 3,
-        macaron_style: bool = False,
-        rel_pos_type: str = "legacy",
-        pos_enc_layer_type: str = "rel_pos",
-        selfattention_layer_type: str = "rel_selfattn",
-        activation_type: str = "swish",
-        use_cnn_module: bool = True,
-        zero_triu: bool = False,
-        norm_type: str = "layer_norm",
-        cnn_module_kernel: int = 31,
-        conv_mod_norm_eps: float = 0.00001,
-        conv_mod_norm_momentum: float = 0.1,
-        simplified_att_score: bool = False,
-        dynamic_chunk_training: bool = False,
-        short_chunk_threshold: float = 0.75,
-        short_chunk_size: int = 25,
-        left_chunk_size: int = 0,
-        time_reduction_factor: int = 1,
-        unified_model_training: bool = False,
-        default_chunk_size: int = 16,
-        jitter_range: int = 4,
-        subsampling_factor: int = 1,
-    ) -> None:
-        """Construct an Encoder object."""
-        super().__init__()
-
-
-        self.embed = StreamingConvInput(
-            input_size,
-            output_size,
-            subsampling_factor,
-            vgg_like=embed_vgg_like,
-            output_size=output_size,
-        )
-
-        self.pos_enc = StreamingRelPositionalEncoding(
-            output_size,
-            positional_dropout_rate,
-        )
-
-        activation = get_activation(
-            activation_type
-       )        
-
-        pos_wise_args = (
-            output_size,
-            linear_units,
-            positional_dropout_rate,
-            activation,
-        )
-
-        conv_mod_norm_args = {
-            "eps": conv_mod_norm_eps,
-            "momentum": conv_mod_norm_momentum,
-        }
-
-        conv_mod_args = (
-            output_size,
-            cnn_module_kernel,
-            activation,
-            conv_mod_norm_args,
-            dynamic_chunk_training or unified_model_training,
-        )
-
-        mult_att_args = (
-            attention_heads,
-            output_size,
-            attention_dropout_rate,
-            simplified_att_score,
-        )
-
-
-        fn_modules = []
-        for _ in range(num_blocks):
-            module = lambda: ChunkEncoderLayer(
-                output_size,
-                RelPositionMultiHeadedAttentionChunk(*mult_att_args),
-                PositionwiseFeedForward(*pos_wise_args),
-                PositionwiseFeedForward(*pos_wise_args),
-                CausalConvolution(*conv_mod_args),
-                dropout_rate=dropout_rate,
-            )
-            fn_modules.append(module)        
-
-        self.encoders = MultiBlocks(
-            [fn() for fn in fn_modules],
-            output_size,
-        )
-
-        self._output_size = output_size
-
-        self.dynamic_chunk_training = dynamic_chunk_training
-        self.short_chunk_threshold = short_chunk_threshold
-        self.short_chunk_size = short_chunk_size
-        self.left_chunk_size = left_chunk_size
-
-        self.unified_model_training = unified_model_training
-        self.default_chunk_size = default_chunk_size
-        self.jitter_range = jitter_range
-
-        self.time_reduction_factor = time_reduction_factor
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
-        """Return the corresponding number of sample for a given chunk size, in frames.
-        Where size is the number of features frames after applying subsampling.
-        Args:
-            size: Number of frames after subsampling.
-            hop_length: Frontend's hop length
-        Returns:
-            : Number of raw samples
-        """
-        return self.embed.get_size_before_subsampling(size) * hop_length
-
-    def get_encoder_input_size(self, size: int) -> int:
-        """Return the corresponding number of sample for a given chunk size, in frames.
-        Where size is the number of features frames after applying subsampling.
-        Args:
-            size: Number of frames after subsampling.
-        Returns:
-            : Number of raw samples
-        """
-        return self.embed.get_size_before_subsampling(size)
-
-
-    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
-        """Initialize/Reset encoder streaming cache.
-        Args:
-            left_context: Number of frames in left context.
-            device: Device ID.
-        """
-        return self.encoders.reset_streaming_cache(left_context, device)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        x_len: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Encode input sequences.
-        Args:
-            x: Encoder input features. (B, T_in, F)
-            x_len: Encoder input features lengths. (B,)
-        Returns:
-           x: Encoder outputs. (B, T_out, D_enc)
-           x_len: Encoder outputs lenghts. (B,)
-        """
-        short_status, limit_size = check_short_utt(
-            self.embed.subsampling_factor, x.size(1)
-        )
-
-        if short_status:
-            raise TooShortUttError(
-                f"has {x.size(1)} frames and is too short for subsampling "
-                + f"(it needs more than {limit_size} frames), return empty results",
-                x.size(1),
-                limit_size,
-            )
-
-        mask = make_source_mask(x_len).to(x.device)
-
-        if self.unified_model_training:
-            if self.training:
-                chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
-            else:
-                chunk_size = self.default_chunk_size
-            x, mask = self.embed(x, mask, chunk_size)
-            pos_enc = self.pos_enc(x)
-            chunk_mask = make_chunk_mask(
-                x.size(1),
-                chunk_size,
-                left_chunk_size=self.left_chunk_size,
-                device=x.device,
-            )
-            x_utt = self.encoders(
-                x,
-                pos_enc,
-                mask,
-                chunk_mask=None,
-            )
-            x_chunk = self.encoders(
-                x,
-                pos_enc,
-                mask,
-                chunk_mask=chunk_mask,
-            )
-
-            olens = mask.eq(0).sum(1)
-            if self.time_reduction_factor > 1:
-                x_utt = x_utt[:,::self.time_reduction_factor,:]
-                x_chunk = x_chunk[:,::self.time_reduction_factor,:]
-                olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
-
-            return x_utt, x_chunk, olens
-
-        elif self.dynamic_chunk_training:
-            max_len = x.size(1)
-            if self.training:
-                chunk_size = torch.randint(1, max_len, (1,)).item()
-
-                if chunk_size > (max_len * self.short_chunk_threshold):
-                    chunk_size = max_len
-                else:
-                    chunk_size = (chunk_size % self.short_chunk_size) + 1
-            else:
-                chunk_size = self.default_chunk_size
-
-            x, mask = self.embed(x, mask, chunk_size)
-            pos_enc = self.pos_enc(x)
-
-            chunk_mask = make_chunk_mask(
-                x.size(1),
-                chunk_size,
-                left_chunk_size=self.left_chunk_size,
-                device=x.device,
-            )
-        else:
-            x, mask = self.embed(x, mask, None)
-            pos_enc = self.pos_enc(x)
-            chunk_mask = None
-        x = self.encoders(
-            x,
-            pos_enc,
-            mask,
-            chunk_mask=chunk_mask,
-        )
-
-        olens = mask.eq(0).sum(1)
-        if self.time_reduction_factor > 1:
-            x = x[:,::self.time_reduction_factor,:]
-            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
-
-        return x, olens, None
-
-    def full_utt_forward(
-        self,
-        x: torch.Tensor,
-        x_len: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Encode input sequences.
-        Args:
-            x: Encoder input features. (B, T_in, F)
-            x_len: Encoder input features lengths. (B,)
-        Returns:
-           x: Encoder outputs. (B, T_out, D_enc)
-           x_len: Encoder outputs lenghts. (B,)
-        """
-        short_status, limit_size = check_short_utt(
-            self.embed.subsampling_factor, x.size(1)
-        )
-
-        if short_status:
-            raise TooShortUttError(
-                f"has {x.size(1)} frames and is too short for subsampling "
-                + f"(it needs more than {limit_size} frames), return empty results",
-                x.size(1),
-                limit_size,
-            )
-
-        mask = make_source_mask(x_len).to(x.device)
-        x, mask = self.embed(x, mask, None)
-        pos_enc = self.pos_enc(x)
-        x_utt = self.encoders(
-            x,
-            pos_enc,
-            mask,
-            chunk_mask=None,
-        )
-
-        if self.time_reduction_factor > 1:
-            x_utt = x_utt[:,::self.time_reduction_factor,:]
-        return x_utt
-
-    def simu_chunk_forward(
-        self,
-        x: torch.Tensor,
-        x_len: torch.Tensor,
-        chunk_size: int = 16,
-        left_context: int = 32,
-        right_context: int = 0,
-    ) -> torch.Tensor:
-        short_status, limit_size = check_short_utt(
-            self.embed.subsampling_factor, x.size(1)
-        )
-
-        if short_status:
-            raise TooShortUttError(
-                f"has {x.size(1)} frames and is too short for subsampling "
-                + f"(it needs more than {limit_size} frames), return empty results",
-                x.size(1),
-                limit_size,
-            )
-
-        mask = make_source_mask(x_len)
-
-        x, mask = self.embed(x, mask, chunk_size)
-        pos_enc = self.pos_enc(x)
-        chunk_mask = make_chunk_mask(
-            x.size(1),
-            chunk_size,
-            left_chunk_size=self.left_chunk_size,
-            device=x.device,
-        )
-
-        x = self.encoders(
-            x,
-            pos_enc,
-            mask,
-            chunk_mask=chunk_mask,
-        )
-        olens = mask.eq(0).sum(1)
-        if self.time_reduction_factor > 1:
-            x = x[:,::self.time_reduction_factor,:]
-
-        return x
-
-    def chunk_forward(
-        self,
-        x: torch.Tensor,
-        x_len: torch.Tensor,
-        processed_frames: torch.tensor,
-        chunk_size: int = 16,
-        left_context: int = 32,
-        right_context: int = 0,
-    ) -> torch.Tensor:
-        """Encode input sequences as chunks.
-        Args:
-            x: Encoder input features. (1, T_in, F)
-            x_len: Encoder input features lengths. (1,)
-            processed_frames: Number of frames already seen.
-            left_context: Number of frames in left context.
-            right_context: Number of frames in right context.
-        Returns:
-           x: Encoder outputs. (B, T_out, D_enc)
-        """
-        mask = make_source_mask(x_len)
-        x, mask = self.embed(x, mask, None)
-
-        if left_context > 0:
-            processed_mask = (
-                torch.arange(left_context, device=x.device)
-                .view(1, left_context)
-                .flip(1)
-            )
-            processed_mask = processed_mask >= processed_frames
-            mask = torch.cat([processed_mask, mask], dim=1)
-        pos_enc = self.pos_enc(x, left_context=left_context)
-        x = self.encoders.chunk_forward(
-            x,
-            pos_enc,
-            mask,
-            chunk_size=chunk_size,
-            left_context=left_context,
-            right_context=right_context,
-        )
-
-        if right_context > 0:
-            x = x[:, 0:-right_context, :]
-
-        if self.time_reduction_factor > 1:
-            x = x[:,::self.time_reduction_factor,:]
-        return x
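
Both the removed conformer_encoder.py and the new encoder.py below apply stochastic depth inside EncoderLayer: during training a whole block is skipped with probability stochastic_depth_rate, and when it is kept its residual branch is rescaled by 1 / (1 - p) so the expected activation matches inference, where the block always runs with coefficient 1.0. A minimal standalone sketch of that residual rule, assuming a generic sublayer (the class and argument names here are illustrative, not FunASR API):

    # Sketch of the stochastic-depth residual used by EncoderLayer (illustrative only).
    import torch
    from torch import nn

    class StochasticDepthResidual(nn.Module):
        def __init__(self, sublayer: nn.Module, skip_prob: float = 0.1):
            super().__init__()
            self.sublayer = sublayer      # e.g. an attention, conv, or feed-forward block
            self.skip_prob = skip_prob    # probability of skipping the block during training

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            if self.training and self.skip_prob > 0:
                # Skip the whole sublayer with probability p, returning the input as-is.
                if torch.rand(1).item() < self.skip_prob:
                    return x
                # Otherwise rescale the residual branch by 1 / (1 - p), so the
                # expected output matches inference, where the block always runs.
                return x + (1.0 / (1.0 - self.skip_prob)) * self.sublayer(x)
            return x + self.sublayer(x)
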
diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py
new file mode 100644
index 0000000..709e10e
--- /dev/null
+++ b/funasr/models/conformer/encoder.py
@@ -0,0 +1,613 @@
+# Copyright 2020 Tomoki Hayashi
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Conformer encoder definition."""
+
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import torch
+from torch import nn
+
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.attention import (
+    MultiHeadedAttention,  # noqa: H301
+    RelPositionMultiHeadedAttention,  # noqa: H301
+    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
+)
+from funasr.models.transformer.embedding import (
+    PositionalEncoding,  # noqa: H301
+    ScaledPositionalEncoding,  # noqa: H301
+    RelPositionalEncoding,  # noqa: H301
+    LegacyRelPositionalEncoding,  # noqa: H301
+    StreamingRelPositionalEncoding,
+)
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.utils.nets_utils import get_activation
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.utils.nets_utils import (
+    TooShortUttError,
+    check_short_utt,
+    make_chunk_mask,
+    make_source_mask,
+)
+from funasr.models.transformer.positionwise_feed_forward import (
+    PositionwiseFeedForward,  # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat, MultiBlocks
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
+from funasr.models.transformer.utils.subsampling import StreamingConvInput
+from funasr.utils.register import register_class
+
+
+class ConvolutionModule(nn.Module):
+    """ConvolutionModule in Conformer model.
+
+    Args:
+        channels (int): The number of channels of conv layers.
+        kernel_size (int): Kernel size of conv layers.
+
+    """
+
+    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
+        """Construct an ConvolutionModule object."""
+        super(ConvolutionModule, self).__init__()
+        # kernerl_size should be a odd number for 'SAME' padding
+        assert (kernel_size - 1) % 2 == 0
+
+        self.pointwise_conv1 = nn.Conv1d(
+            channels,
+            2 * channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.depthwise_conv = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size,
+            stride=1,
+            padding=(kernel_size - 1) // 2,
+            groups=channels,
+            bias=bias,
+        )
+        self.norm = nn.BatchNorm1d(channels)
+        self.pointwise_conv2 = nn.Conv1d(
+            channels,
+            channels,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            bias=bias,
+        )
+        self.activation = activation
+
+    def forward(self, x):
+        """Compute convolution module.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, channels).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, channels).
+
+        """
+        # exchange the temporal dimension and the feature dimension
+        x = x.transpose(1, 2)
+
+        # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channel, dim)
+        x = nn.functional.glu(x, dim=1)  # (batch, channel, dim)
+
+        # 1D Depthwise Conv
+        x = self.depthwise_conv(x)
+        x = self.activation(self.norm(x))
+
+        x = self.pointwise_conv2(x)
+
+        return x.transpose(1, 2)
+
+
+class EncoderLayer(nn.Module):
+    """Encoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+            can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        conv_module (torch.nn.Module): Convolution module instance.
+            `ConvolutionModule` instance can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+        stochastic_depth_rate (float): Probability to skip this layer.
+            During training, the layer may skip residual computation and return input
+            as-is with given probability.
+    """
+
+    def __init__(
+            self,
+            size,
+            self_attn,
+            feed_forward,
+            feed_forward_macaron,
+            conv_module,
+            dropout_rate,
+            normalize_before=True,
+            concat_after=False,
+            stochastic_depth_rate=0.0,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayer, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.feed_forward_macaron = feed_forward_macaron
+        self.conv_module = conv_module
+        self.norm_ff = LayerNorm(size)  # for the FNN module
+        self.norm_mha = LayerNorm(size)  # for the MHA module
+        if feed_forward_macaron is not None:
+            self.norm_ff_macaron = LayerNorm(size)
+            self.ff_scale = 0.5
+        else:
+            self.ff_scale = 1.0
+        if self.conv_module is not None:
+            self.norm_conv = LayerNorm(size)  # for the CNN module
+            self.norm_final = LayerNorm(size)  # for the final output of the block
+        self.dropout = nn.Dropout(dropout_rate)
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+        self.stochastic_depth_rate = stochastic_depth_rate
+
+    def forward(self, x_input, mask, cache=None):
+        """Compute encoded features.
+
+        Args:
+            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+                - w/o pos emb: Tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+        if isinstance(x_input, tuple):
+            x, pos_emb = x_input[0], x_input[1]
+        else:
+            x, pos_emb = x_input, None
+
+        skip_layer = False
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        stoch_layer_coeff = 1.0
+        if self.training and self.stochastic_depth_rate > 0:
+            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+        if skip_layer:
+            if cache is not None:
+                x = torch.cat([cache, x], dim=1)
+            if pos_emb is not None:
+                return (x, pos_emb), mask
+            return x, mask
+
+        # whether to use macaron style
+        if self.feed_forward_macaron is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_ff_macaron(x)
+            x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
+                self.feed_forward_macaron(x)
+            )
+            if not self.normalize_before:
+                x = self.norm_ff_macaron(x)
+
+        # multi-headed self-attention module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_mha(x)
+
+        if cache is None:
+            x_q = x
+        else:
+            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+            x_q = x[:, -1:, :]
+            residual = residual[:, -1:, :]
+            mask = None if mask is None else mask[:, -1:, :]
+
+        if pos_emb is not None:
+            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+        else:
+            x_att = self.self_attn(x_q, x, x, mask)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, x_att), dim=-1)
+            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+        else:
+            x = residual + stoch_layer_coeff * self.dropout(x_att)
+        if not self.normalize_before:
+            x = self.norm_mha(x)
+
+        # convolution module
+        if self.conv_module is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm_conv(x)
+            x = residual + stoch_layer_coeff * self.dropout(self.conv_module(x))
+            if not self.normalize_before:
+                x = self.norm_conv(x)
+
+        # feed forward module
+        residual = x
+        if self.normalize_before:
+            x = self.norm_ff(x)
+        x = residual + stoch_layer_coeff * self.ff_scale * self.dropout(
+            self.feed_forward(x)
+        )
+        if not self.normalize_before:
+            x = self.norm_ff(x)
+
+        if self.conv_module is not None:
+            x = self.norm_final(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        if pos_emb is not None:
+            return (x, pos_emb), mask
+
+        return x, mask
+
+
+@register_class("encoder_classes", "ConformerEncoder")
+class ConformerEncoder(nn.Module):
+    """Conformer encoder module.
+
+    Args:
+        input_size (int): Input dimension.
+        output_size (int): Dimension of attention.
+        attention_heads (int): The number of heads of multi head attention.
+        linear_units (int): The number of units of position-wise feed forward.
+        num_blocks (int): The number of encoder blocks.
+        dropout_rate (float): Dropout rate.
+        attention_dropout_rate (float): Dropout rate in attention.
+        positional_dropout_rate (float): Dropout rate after adding positional encoding.
+        input_layer (Union[str, torch.nn.Module]): Input layer type.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            If True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            If False, no additional linear will be applied. i.e. x -> x + att(x)
+        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
+        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
+        rel_pos_type (str): Whether to use the latest relative positional encoding or
+            the legacy one. The legacy relative positional encoding will be deprecated
+            in the future. More details can be found in
+            https://github.com/espnet/espnet/pull/2816.
+        encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
+        encoder_attn_layer_type (str): Encoder attention layer type.
+        activation_type (str): Encoder activation function type.
+        macaron_style (bool): Whether to use macaron style for positionwise layer.
+        use_cnn_module (bool): Whether to use convolution module.
+        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+        cnn_module_kernel (int): Kernel size of convolution module.
+        padding_idx (int): Padding idx for input_layer=embed.
+
+    """
+
+    def __init__(
+            self,
+            input_size: int,
+            output_size: int = 256,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            attention_dropout_rate: float = 0.0,
+            input_layer: str = "conv2d",
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            positionwise_layer_type: str = "linear",
+            positionwise_conv_kernel_size: int = 3,
+            macaron_style: bool = False,
+            rel_pos_type: str = "legacy",
+            pos_enc_layer_type: str = "rel_pos",
+            selfattention_layer_type: str = "rel_selfattn",
+            activation_type: str = "swish",
+            use_cnn_module: bool = True,
+            zero_triu: bool = False,
+            cnn_module_kernel: int = 31,
+            padding_idx: int = -1,
+            interctc_layer_idx: List[int] = [],
+            interctc_use_conditioning: bool = False,
+            stochastic_depth_rate: Union[float, List[float]] = 0.0,
+    ):
+        super().__init__()
+        self._output_size = output_size
+
+        if rel_pos_type == "legacy":
+            if pos_enc_layer_type == "rel_pos":
+                pos_enc_layer_type = "legacy_rel_pos"
+            if selfattention_layer_type == "rel_selfattn":
+                selfattention_layer_type = "legacy_rel_selfattn"
+        elif rel_pos_type == "latest":
+            assert selfattention_layer_type != "legacy_rel_selfattn"
+            assert pos_enc_layer_type != "legacy_rel_pos"
+        else:
+            raise ValueError("unknown rel_pos_type: " + rel_pos_type)
+
+        activation = get_activation(activation_type)
+        if pos_enc_layer_type == "abs_pos":
+            pos_enc_class = PositionalEncoding
+        elif pos_enc_layer_type == "scaled_abs_pos":
+            pos_enc_class = ScaledPositionalEncoding
+        elif pos_enc_layer_type == "rel_pos":
+            assert selfattention_layer_type == "rel_selfattn"
+            pos_enc_class = RelPositionalEncoding
+        elif pos_enc_layer_type == "legacy_rel_pos":
+            assert selfattention_layer_type == "legacy_rel_selfattn"
+            pos_enc_class = LegacyRelPositionalEncoding
+            logging.warning(
+                "Using legacy_rel_pos and it will be deprecated in the future."
+            )
+        else:
+            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2dpad":
+            self.embed = Conv2dSubsamplingPad(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif isinstance(input_layer, torch.nn.Module):
+            self.embed = torch.nn.Sequential(
+                input_layer,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer is None:
+            self.embed = torch.nn.Sequential(
+                pos_enc_class(output_size, positional_dropout_rate)
+            )
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+                activation,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+        elif selfattention_layer_type == "legacy_rel_selfattn":
+            assert pos_enc_layer_type == "legacy_rel_pos"
+            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+            logging.warning(
+                "Using legacy_rel_selfattn and it will be deprecated in the future."
+            )
+        elif selfattention_layer_type == "rel_selfattn":
+            assert pos_enc_layer_type == "rel_pos"
+            encoder_selfattn_layer = RelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+                zero_triu,
+            )
+        else:
+            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
+
+        convolution_layer = ConvolutionModule
+        convolution_layer_args = (output_size, cnn_module_kernel, activation)
+
+        if isinstance(stochastic_depth_rate, float):
+            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
+
+        if len(stochastic_depth_rate) != num_blocks:
+            raise ValueError(
+                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
+                f"should be equal to num_blocks ({num_blocks})"
+            )
+
+        self.encoders = repeat(
+            num_blocks,
+            lambda lnum: EncoderLayer(
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
+                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
+                dropout_rate,
+                normalize_before,
+                concat_after,
+                stochastic_depth_rate[lnum],
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+            prev_states: torch.Tensor = None,
+            ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Calculate forward propagation.
+
+        Args:
+            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
+            ilens (torch.Tensor): Input length (#batch).
+            prev_states (torch.Tensor): Not used for now.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, L, output_size).
+            torch.Tensor: Output length (#batch).
+            torch.Tensor: Not used for now.
+
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+
+        if (
+                isinstance(self.embed, Conv2dSubsampling)
+                or isinstance(self.embed, Conv2dSubsampling2)
+                or isinstance(self.embed, Conv2dSubsampling6)
+                or isinstance(self.embed, Conv2dSubsampling8)
+                or isinstance(self.embed, Conv2dSubsamplingPad)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            xs_pad, masks = self.encoders(xs_pad, masks)
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                xs_pad, masks = encoder_layer(xs_pad, masks)
+
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+                    if isinstance(encoder_out, tuple):
+                        encoder_out = encoder_out[0]
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+
+                        if isinstance(xs_pad, tuple):
+                            x, pos_emb = xs_pad
+                            x = x + self.conditioning_layer(ctc_out)
+                            xs_pad = (x, pos_emb)
+                        else:
+                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if isinstance(xs_pad, tuple):
+            xs_pad = xs_pad[0]
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None
+
diff --git a/funasr/models/conformer/model.py b/funasr/models/conformer/model.py
index 48f04e4..5319a73 100644
--- a/funasr/models/conformer/model.py
+++ b/funasr/models/conformer/model.py
@@ -1,57 +1,11 @@
 import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
+
 import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
-	LabelSmoothingLoss,  # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
-
-from funasr.models.model_class_factory import *
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-	from torch.cuda.amp import autocast
-else:
-	# Nothing to do if torch<1.6.0
-	@contextmanager
-	def autocast(enabled=True):
-		yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
 
 from funasr.models.transformer.model import Transformer
+from funasr.utils.register import register_class, registry_tables
 
+@register_class("model_classes", "Conformer")
 class Conformer(Transformer):
 	"""CTC-attention hybrid Encoder-Decoder model"""
 
diff --git a/funasr/models/ct_transformer/attention.py b/funasr/models/ct_transformer/attention.py
new file mode 100644
index 0000000..a35ddee
--- /dev/null
+++ b/funasr/models/ct_transformer/attention.py
@@ -0,0 +1,1091 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+from typing import Optional, Tuple
+
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+import funasr.models.lora.layers as lora
+
+class MultiHeadedAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
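+    # Minimal usage sketch (shapes are illustrative assumptions):
+    #     mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.1)
+    #     q = k = v = torch.randn(2, 50, 256)
+    #     mask = torch.ones(2, 1, 50, dtype=torch.bool)
+    #     out = mha(q, k, v, mask)  # (2, 50, 256)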
+    def __init__(self, n_head, n_feat, dropout_rate):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadedAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, query, key, value):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        n_batch = query.size(0)
+        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q, k, v
+
+    def forward_attention(self, value, scores, mask):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, query, key, value, mask):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask)
+
+
+class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with relative position encoding (old version).
+
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+    Paper: https://arxiv.org/abs/1901.02860
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
+        """Construct an RelPositionMultiHeadedAttention object."""
+        super().__init__(n_head, n_feat, dropout_rate)
+        self.zero_triu = zero_triu
+        # linear transformation for positional encoding
+        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        torch.nn.init.xavier_uniform_(self.pos_bias_u)
+        torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+    def rel_shift(self, x):
+        """Compute relative positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, head, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor.
+
+        """
+        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+
+        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+        x = x_padded[:, :, 1:].view_as(x)
+
+        if self.zero_triu:
+            ones = torch.ones((x.size(2), x.size(3)))
+            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+        return x
+
+    def forward(self, query, key, value, pos_emb, mask):
+        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = p.transpose(1, 2)  # (batch, head, time1, d_k)
+
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        # compute matrix b and matrix d
+        # (batch, head, time1, time1)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        matrix_bd = self.rel_shift(matrix_bd)
+
+        scores = (matrix_ac + matrix_bd) / math.sqrt(
+            self.d_k
+        )  # (batch, head, time1, time2)
+
+        return self.forward_attention(v, scores, mask)
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with relative position encoding (new implementation).
+
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+    Paper: https://arxiv.org/abs/1901.02860
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
+        """Construct an RelPositionMultiHeadedAttention object."""
+        super().__init__(n_head, n_feat, dropout_rate)
+        self.zero_triu = zero_triu
+        # linear transformation for positional encoding
+        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        torch.nn.init.xavier_uniform_(self.pos_bias_u)
+        torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+    def rel_shift(self, x):
+        """Compute relative positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+            time1 means the length of query vector.
+
+        Returns:
+            torch.Tensor: Output tensor.
+
+        """
+        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+
+        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+        x = x_padded[:, :, 1:].view_as(x)[
+            :, :, :, : x.size(-1) // 2 + 1
+            ]  # only keep the positions from 0 to time2
+
+        if self.zero_triu:
+            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
+            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+        return x
+
+    def forward(self, query, key, value, pos_emb, mask):
+        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            pos_emb (torch.Tensor): Positional embedding tensor
+                (#batch, 2*time1-1, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
+
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        # compute matrix b and matrix d
+        # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        matrix_bd = self.rel_shift(matrix_bd)
+
+        scores = (matrix_ac + matrix_bd) / math.sqrt(
+            self.d_k
+        )  # (batch, head, time1, time2)
+
+        return self.forward_attention(v, scores, mask)
+
+
+class MultiHeadedAttentionSANM(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
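+    # The FSMN memory block (a depthwise Conv1d over the value stream) is added to
+    # the attention output in forward(). Minimal usage sketch (shapes are
+    # illustrative assumptions):
+    #     attn = MultiHeadedAttentionSANM(4, 256, 256, 0.1, kernel_size=11)
+    #     x = torch.randn(2, 100, 256)
+    #     mask = torch.ones(2, 1, 100)
+    #     out = attn(x, mask)  # (2, 100, 256)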
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadedAttentionSANM, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        # self.linear_q = nn.Linear(n_feat, n_feat)
+        # self.linear_k = nn.Linear(n_feat, n_feat)
+        # self.linear_v = nn.Linear(n_feat, n_feat)
+        if lora_list is not None:
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+            if lora_qkv_list == [False, False, False]:
+                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+            else:
+                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+        else:
+            self.linear_out = nn.Linear(n_feat, n_feat)
+            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+        # padding
+        left_padding = (kernel_size - 1) // 2
+        if sanm_shfit > 0:
+            left_padding = left_padding + sanm_shfit
+        right_padding = kernel_size - 1 - left_padding
+        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+
+    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
+        b, t, d = inputs.size()
+        if mask is not None:
+            mask = torch.reshape(mask, (b, -1, 1))
+            if mask_shfit_chunk is not None:
+                mask = mask * mask_shfit_chunk
+            inputs = inputs * mask
+
+        x = inputs.transpose(1, 2)
+        x = self.pad_fn(x)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        x += inputs
+        x = self.dropout(x)
+        if mask is not None:
+            x = x * mask
+        return x
+
+    def forward_qkv(self, x):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        b, t, d = x.size()
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q_h, k_h, v_h, v
+
+    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            if mask_att_chunk_encoder is not None:
+                mask = mask * mask_att_chunk_encoder
+
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        return att_outs + fsmn_memory
+
+    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        if chunk_size is not None and look_back > 0 or look_back == -1:
+            if cache is not None:
+                k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
+                v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
+                k_h = torch.cat((cache["k"], k_h), dim=2)
+                v_h = torch.cat((cache["v"], v_h), dim=2)
+
+                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
+                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
+                if look_back != -1:
+                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
+                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
+            else:
+                cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
+                             "v": v_h[:, :, :-(chunk_size[2]), :]}
+                cache = cache_tmp
+        fsmn_memory = self.forward_fsmn(v, None)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, None)
+        return att_outs + fsmn_memory, cache
+
+
+class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
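+    """MultiHeadedAttentionSANM variant that takes a pair of masks: mask[0] is
+    applied to the FSMN memory block and mask[1] to the attention scores.
+    """
+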
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
+        return att_outs + fsmn_memory
+
+class MultiHeadedAttentionSANMDecoder(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadedAttentionSANMDecoder, self).__init__()
+
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+        self.fsmn_block = nn.Conv1d(n_feat, n_feat,
+                                    kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+        # padding
+        left_padding = (kernel_size - 1) // 2
+        if sanm_shfit > 0:
+            left_padding = left_padding + sanm_shfit
+        right_padding = kernel_size - 1 - left_padding
+        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+        self.kernel_size = kernel_size
+
+    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
+        """Apply the FSMN memory block.
+
+        Args:
+            inputs (torch.Tensor): Input tensor (#batch, time1, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time1).
+            cache (torch.Tensor): Cached padded input from the previous chunk.
+            mask_shfit_chunk (torch.Tensor): Optional chunk mask for streaming.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, size).
+            torch.Tensor: Updated cache.
+        """
+        # print("in fsmn, inputs", inputs.size())
+        b, t, d = inputs.size()
+        # logging.info(
+        #     "mask: {}".format(mask.size()))
+        if mask is not None:
+            mask = torch.reshape(mask, (b ,-1, 1))
+            # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
+            if mask_shfit_chunk is not None:
+                # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
+                mask = mask * mask_shfit_chunk
+            # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
+            # print("in fsmn, mask", mask.size())
+            # print("in fsmn, inputs", inputs.size())
+            inputs = inputs * mask
+
+        x = inputs.transpose(1, 2)
+        b, d, t = x.size()
+        if cache is None:
+            x = self.pad_fn(x)
+            if not self.training:
+                cache = x
+        else:
+            # drop the oldest cached frame, append the new frames, and keep only
+            # the receptive field of the FSMN kernel
+            x = torch.cat((cache[:, :, 1:], x), dim=2)
+            x = x[:, :, -(self.kernel_size + t - 1):]
+            cache = x
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        if x.size(1) != inputs.size(1):
+            inputs = inputs[:, -1, :]
+
+        x = x + inputs
+        x = self.dropout(x)
+        if mask is not None:
+            x = x * mask
+        return x, cache
+
+class MultiHeadedAttentionCrossAtt(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
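+    # When lora_list is given (e.g. ["q", "k", "v", "o"]), the matching projections
+    # are built from funasr.models.lora.layers so that only the low-rank adapters
+    # need updating during fine-tuning; otherwise plain nn.Linear layers are used.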
+    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadedAttentionCrossAtt, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        if lora_list is not None:
+            if "q" in lora_list:
+                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_q = nn.Linear(n_feat, n_feat)
+            lora_kv_list = ["k" in lora_list, "v" in lora_list]
+            if lora_kv_list == [False, False]:
+                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            else:
+                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, 
+                                      r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+        else:
+            self.linear_q = nn.Linear(n_feat, n_feat)
+            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            self.linear_out = nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, x, memory):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+
+        # print("in forward_qkv, x", x.size())
+        b = x.size(0)
+        q = self.linear_q(x)
+        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time1, d_k)
+
+        k_v = self.linear_k_v(memory)
+        k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
+        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
+
+
+        return q_h, k_h, v_h
+
+    def forward_attention(self, value, scores, mask):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            # logging.info(
+            #     "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, memory, memory_mask):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h = self.forward_qkv(x, memory)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        return self.forward_attention(v_h, scores, memory_mask)
+
+    def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h = self.forward_qkv(x, memory)
+        if chunk_size is not None and look_back > 0:
+            if cache is not None:
+                k_h = torch.cat((cache["k"], k_h), dim=2)
+                v_h = torch.cat((cache["v"], v_h), dim=2)
+                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
+                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
+            else:
+                cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
+                             "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
+                cache = cache_tmp
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        return self.forward_attention(v_h, scores, None), cache
+
+
+class MultiHeadSelfAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadSelfAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, x):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        b, t, d = x.size()
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q_h, k_h, v_h, v
+
+    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            if mask_att_chunk_encoder is not None:
+                mask = mask * mask_att_chunk_encoder
+
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, mask, mask_att_chunk_encoder=None):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        return att_outs
+
+class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
+    """RelPositionMultiHeadedAttention definition.
+    Args:
+        num_heads: Number of attention heads.
+        embed_size: Embedding size.
+        dropout_rate: Dropout rate.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        embed_size: int,
+        dropout_rate: float = 0.0,
+        simplified_attention_score: bool = False,
+    ) -> None:
+        """Construct an MultiHeadedAttention object."""
+        super().__init__()
+
+        self.d_k = embed_size // num_heads
+        self.num_heads = num_heads
+
+        assert self.d_k * num_heads == embed_size, (
+            "embed_size (%d) must be divisible by num_heads (%d)"
+            % (embed_size, num_heads)
+        )
+
+        self.linear_q = torch.nn.Linear(embed_size, embed_size)
+        self.linear_k = torch.nn.Linear(embed_size, embed_size)
+        self.linear_v = torch.nn.Linear(embed_size, embed_size)
+
+        self.linear_out = torch.nn.Linear(embed_size, embed_size)
+
+        if simplified_attention_score:
+            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
+
+            self.compute_att_score = self.compute_simplified_attention_score
+        else:
+            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
+
+            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            torch.nn.init.xavier_uniform_(self.pos_bias_u)
+            torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+            self.compute_att_score = self.compute_attention_score
+
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.attn = None
+
+    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+        """Compute relative positional encoding.
+        Args:
+            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
+            left_context: Number of frames in left context.
+        Returns:
+            x: Output sequence. (B, H, T_1, T_2)
+        """
+        batch_size, n_heads, time1, n = x.shape
+        time2 = time1 + left_context
+
+        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
+
+        return x.as_strided(
+            (batch_size, n_heads, time1, time2),
+            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
+            storage_offset=(n_stride * (time1 - 1)),
+        )
+
+    def compute_simplified_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Simplified attention score computation.
+        Reference: https://github.com/k2-fsa/icefall/pull/458
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+        """
+        pos_enc = self.linear_pos(pos_enc)
+
+        matrix_ac = torch.matmul(query, key.transpose(2, 3))
+
+        matrix_bd = self.rel_shift(
+            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
+            left_context=left_context,
+        )
+
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def compute_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Attention score computation.
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+        """
+        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
+
+        query = query.transpose(1, 2)
+        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
+        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
+
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def forward_qkv(
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Transform query, key and value.
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+        Returns:
+            q: Transformed query tensor. (B, H, T_1, d_k)
+            k: Transformed key tensor. (B, H, T_2, d_k)
+            v: Transformed value tensor. (B, H, T_2, d_k)
+        """
+        n_batch = query.size(0)
+
+        q = (
+            self.linear_q(query)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        k = (
+            self.linear_k(key)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        v = (
+            self.linear_v(value)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+
+        return q, k, v
+
+    def forward_attention(
+        self,
+        value: torch.Tensor,
+        scores: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute attention context vector.
+        Args:
+            value: Transformed value. (B, H, T_2, d_k)
+            scores: Attention score. (B, H, T_1, T_2)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+        Returns:
+           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
+        """
+        batch_size = scores.size(0)
+        mask = mask.unsqueeze(1).unsqueeze(2)
+        if chunk_mask is not None:
+            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
+        scores = scores.masked_fill(mask, float("-inf"))
+        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+
+        attn_output = self.dropout(self.attn)
+        attn_output = torch.matmul(attn_output, value)
+
+        attn_output = self.linear_out(
+            attn_output.transpose(1, 2)
+            .contiguous()
+            .view(batch_size, -1, self.num_heads * self.d_k)
+        )
+
+        return attn_output
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Compute scaled dot product attention with rel. positional encoding.
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+            left_context: Number of frames in left context.
+        Returns:
+            : Output tensor. (B, T_1, H * d_k)
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
+        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
+
+
+class CosineDistanceAttention(nn.Module):
+    """ Compute Cosine Distance between spk decoder output and speaker profile 
+    Args:
+        profile_path: speaker profile file path (.npy file)
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, spk_decoder_out, profile, profile_lens=None):
+        """
+        Args:
+            spk_decoder_out(torch.Tensor):(B, L, D)
+            spk_profiles(torch.Tensor):(B, N, D)
+        """
+        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
+        if profile_lens is not None:
+            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
+            )
+            weights_not_softmax = F.cosine_similarity(
+                x, profile.unsqueeze(1), dim=-1
+            ).masked_fill(mask, min_value)
+            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
+        else:
+            x = x[:, -1:, :, :]
+            weights_not_softmax = F.cosine_similarity(
+                x, profile.unsqueeze(1).to(x.device), dim=-1
+            )
+            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
+        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
+
+        return spk_embedding, weights
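+
+
+# Minimal usage sketch (shapes assumed from the docstrings above, not a documented API):
+#   att = CosineDistanceAttention()
+#   spk_embedding, weights = att(spk_decoder_out, profile)                # (B, 1, D), (B, 1, N)
+#   spk_embedding, weights = att(spk_decoder_out, profile, profile_lens)  # (B, L, D), (B, L, N)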
diff --git a/funasr/models/ct_transformer/sanm_encoder.py b/funasr/models/ct_transformer/sanm_encoder.py
new file mode 100644
index 0000000..1bdf5d5
--- /dev/null
+++ b/funasr/models/ct_transformer/sanm_encoder.py
@@ -0,0 +1,383 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.scama.chunk_utilis import overlap_chunk
+import numpy as np
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.sanm.attention import MultiHeadedAttention
+from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask
+from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.positionwise_feed_forward import (
+    PositionwiseFeedForward,  # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
+
+from funasr.models.ctc.ctc import CTC
+
+from funasr.utils.register import register_class
+
+class EncoderLayerSANM(nn.Module):
+    def __init__(
+        self,
+        in_size,
+        size,
+        self_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        stochastic_depth_rate=0.0,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayerSANM, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(in_size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.in_size = in_size
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+        self.stochastic_depth_rate = stochastic_depth_rate
+        self.dropout_rate = dropout_rate
+
+    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        """Compute encoded features.
+
+        Args:
+            x_input (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+        skip_layer = False
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        stoch_layer_coeff = 1.0
+        if self.training and self.stochastic_depth_rate > 0:
+            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+        if skip_layer:
+            if cache is not None:
+                x = torch.cat([cache, x], dim=1)
+            return x, mask
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+            else:
+                x = stoch_layer_coeff * self.concat_linear(x_concat)
+        else:
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+            else:
+                x = stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
+    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+        """Compute encoded features.
+
+        Args:
+            x_input (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.in_size == self.size:
+            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+            x = residual + attn
+        else:
+            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.feed_forward(x)
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        return x, cache
+
+
+@register_class("encoder_classes", "SANMVadEncoder")
+class SANMVadEncoder(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
+        kernel_size: int = 11,
+        sanm_shfit: int = 0,
+        selfattention_layer_type: str = "sanm",
+    ):
+        super().__init__()
+        self._output_size = output_size
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                SinusoidalPositionEncoder(),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        if selfattention_layer_type == "selfattn":
+            self.encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+
+        elif selfattention_layer_type == "sanm":
+            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.encoders = repeat(
+            num_blocks-1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        vad_indexes: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
+        no_future_masks = masks & sub_masks
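+        # no_future_masks combines the padding mask with a lower-triangular (causal)
+        # mask, so every frame attends only to itself and earlier frames.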
+        xs_pad *= self.output_size()**0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
+              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
+                    f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        # xs_pad = self.dropout(xs_pad)
+        mask_tup0 = [masks, no_future_masks]
+        encoder_outs = self.encoders0(xs_pad, mask_tup0)
+        xs_pad, _ = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+
+
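+        # All layers except the last attend causally via no_future_masks; the last layer
+        # additionally applies a per-utterance vad_mask built from vad_indexes.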
+        for layer_idx, encoder_layer in enumerate(self.encoders):
+            if layer_idx + 1 == len(self.encoders):
+                # This is the last layer.
+                corner_mask = torch.ones(masks.size(0),
+                                         masks.size(-1),
+                                         masks.size(-1),
+                                         device=xs_pad.device,
+                                         dtype=torch.bool)
+                for word_index, length in enumerate(ilens):
+                    corner_mask[word_index, :, :] = vad_mask(masks.size(-1),
+                                                             vad_indexes[word_index],
+                                                             device=xs_pad.device)
+                layer_mask = masks & corner_mask
+            else:
+                layer_mask = no_future_masks
+            mask_tup1 = [masks, layer_mask]
+            encoder_outs = encoder_layer(xs_pad, mask_tup1)
+            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None
diff --git a/funasr/models/ct_transformer/target_delay_transformer.py b/funasr/models/ct_transformer/target_delay_transformer.py
index f3bc772..59884a3 100644
--- a/funasr/models/ct_transformer/target_delay_transformer.py
+++ b/funasr/models/ct_transformer/target_delay_transformer.py
@@ -6,7 +6,7 @@
 import torch.nn as nn
 
 from funasr.models.transformer.embedding import SinusoidalPositionEncoder
-from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
+from funasr.models.sanm.encoder import SANMEncoder as Encoder
 
 
 class TargetDelayTransformer(torch.nn.Module):
diff --git a/funasr/models/fsmn_vad/vad_realtime_transformer.py b/funasr/models/ct_transformer/vad_realtime_transformer.py
similarity index 97%
rename from funasr/models/fsmn_vad/vad_realtime_transformer.py
rename to funasr/models/ct_transformer/vad_realtime_transformer.py
index e02f624..155057c 100644
--- a/funasr/models/fsmn_vad/vad_realtime_transformer.py
+++ b/funasr/models/ct_transformer/vad_realtime_transformer.py
@@ -6,7 +6,7 @@
 import torch.nn as nn
 
 from funasr.models.transformer.embedding import SinusoidalPositionEncoder
-from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
+from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder as Encoder
 
 
 class VadRealtimeTransformer(torch.nn.Module):
diff --git a/funasr/models/data2vec/data2vec.py b/funasr/models/data2vec/data2vec.py
index 19c5612..c77cedf 100644
--- a/funasr/models/data2vec/data2vec.py
+++ b/funasr/models/data2vec/data2vec.py
@@ -10,13 +10,14 @@
 from typing import Tuple
 
 import torch
+import torch.nn as nn
 
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.models.base_model import FunASRModel
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
+# from funasr.layers.abs_normalize import AbsNormalize
+# from funasr.models.base_model import FunASRModel
+# from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.frontends.abs_frontend import AbsFrontend
+# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+# from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.train_utils.device_funcs import force_gatherable
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -28,16 +29,16 @@
         yield
 
 
-class Data2VecPretrainModel(FunASRModel):
+class Data2VecPretrainModel(nn.Module):
     """Data2Vec Pretrain model"""
 
     def __init__(
             self,
-            frontend: Optional[AbsFrontend],
-            specaug: Optional[AbsSpecAug],
-            normalize: Optional[AbsNormalize],
-            encoder: AbsEncoder,
-            preencoder: Optional[AbsPreEncoder] = None,
+            frontend = None,
+            specaug = None,
+            normalize = None,
+            encoder = None,
+            preencoder = None,
     ):
 
         super().__init__()
diff --git a/funasr/models/data2vec/data2vec_encoder.py b/funasr/models/data2vec/data2vec_encoder.py
index 1bcb639..4689e20 100644
--- a/funasr/models/data2vec/data2vec_encoder.py
+++ b/funasr/models/data2vec/data2vec_encoder.py
@@ -11,7 +11,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.data2vec.data_utils import compute_mask_indices
 from funasr.models.data2vec.ema_module import EMAModule
 from funasr.models.data2vec.grad_multiply import GradMultiply
@@ -28,7 +27,7 @@
     return end - r * pct_remaining
 
 
-class Data2VecEncoder(AbsEncoder):
+class Data2VecEncoder(nn.Module):
     def __init__(
             self,
             # for ConvFeatureExtractionModel
diff --git a/funasr/models/e_branchformer/e_branchformer_encoder.py b/funasr/models/e_branchformer/encoder.py
similarity index 97%
rename from funasr/models/e_branchformer/e_branchformer_encoder.py
rename to funasr/models/e_branchformer/encoder.py
index a7b4fde..5604c9f 100644
--- a/funasr/models/e_branchformer/e_branchformer_encoder.py
+++ b/funasr/models/e_branchformer/encoder.py
@@ -13,9 +13,8 @@
 from typing import List, Optional, Tuple
 
 import torch
-
-from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
+import torch.nn as nn
+from funasr.models.ctc.ctc import CTC
 from funasr.models.branchformer.cgmlp import ConvolutionalGatingMLP
 from funasr.models.branchformer.fastformer import FastSelfAttention
 from funasr.models.transformer.utils.nets_utils import get_activation, make_pad_mask
@@ -34,8 +33,8 @@
 from funasr.models.transformer.positionwise_feed_forward import (
     PositionwiseFeedForward,
 )
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import (
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import (
     Conv2dSubsampling,
     Conv2dSubsampling2,
     Conv2dSubsampling6,
@@ -43,7 +42,7 @@
     TooShortUttError,
     check_short_utt,
 )
-
+from funasr.utils.register import register_class
 
 class EBranchformerEncoderLayer(torch.nn.Module):
     """E-Branchformer encoder layer module.
@@ -175,8 +174,8 @@
 
         return x, mask
 
-
-class EBranchformerEncoder(AbsEncoder):
+@register_class("encoder_classes", "EBranchformerEncoder")
+class EBranchformerEncoder(nn.Module):
     """E-Branchformer encoder module."""
 
     def __init__(
diff --git a/funasr/models/e_branchformer/model.py b/funasr/models/e_branchformer/model.py
index ca3e0f5..ccf1320 100644
--- a/funasr/models/e_branchformer/model.py
+++ b/funasr/models/e_branchformer/model.py
@@ -1,57 +1,9 @@
 import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
-import torch
-import torch.nn as nn
-import random
-import numpy as np
-import time
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
-	LabelSmoothingLoss,  # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
-
-from funasr.models.model_class_factory import *
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-	from torch.cuda.amp import autocast
-else:
-	# Nothing to do if torch<1.6.0
-	@contextmanager
-	def autocast(enabled=True):
-		yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
 
 from funasr.models.transformer.model import Transformer
+from funasr.utils.register import register_class
 
+@register_class("model_classes", "EBranchformer")
 class EBranchformer(Transformer):
 	"""CTC-attention hybrid Encoder-Decoder model"""
 
diff --git a/funasr/models/eend/e2e_diar_eend_ola.py b/funasr/models/eend/e2e_diar_eend_ola.py
index 28aa223..ee80836 100644
--- a/funasr/models/eend/e2e_diar_eend_ola.py
+++ b/funasr/models/eend/e2e_diar_eend_ola.py
@@ -7,7 +7,7 @@
 import torch.nn as  nn
 import torch.nn.functional as F
 
-from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.frontends.wav_frontend import WavFrontendMel23
 from funasr.models.eend.encoder import EENDOLATransformerEncoder
 from funasr.models.eend.encoder_decoder_attractor import EncoderDecoderAttractor
 from funasr.models.eend.utils.losses import standard_loss, cal_power_loss, fast_batch_pit_n_speaker_loss
diff --git a/funasr/models/eend/encoder.py b/funasr/models/eend/encoder.py
index 3065884..430f0d0 100644
--- a/funasr/models/eend/encoder.py
+++ b/funasr/models/eend/encoder.py
@@ -7,7 +7,7 @@
 
 class MultiHeadSelfAttention(nn.Module):
     def __init__(self, n_units, h=8, dropout_rate=0.1):
-        super(MultiHeadSelfAttention, self).__init__()
+        super().__init__()
         self.linearQ = nn.Linear(n_units, n_units)
         self.linearK = nn.Linear(n_units, n_units)
         self.linearV = nn.Linear(n_units, n_units)
diff --git a/funasr/models/frontend/__init__.py b/funasr/models/frontend/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models/frontend/__init__.py
+++ /dev/null
diff --git a/funasr/models/frontend/abs_frontend.py b/funasr/models/frontend/abs_frontend.py
deleted file mode 100644
index 6049a01..0000000
--- a/funasr/models/frontend/abs_frontend.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Tuple
-
-import torch
-
-
-class AbsFrontend(torch.nn.Module, ABC):
-    @abstractmethod
-    def output_size(self) -> int:
-        raise NotImplementedError
-
-    @abstractmethod
-    def forward(
-        self, input: torch.Tensor, input_lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index b930e0c..cc3c87e 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -27,7 +27,7 @@
                                                self.vad_opts.speech_to_sil_time_thres,
                                                self.vad_opts.frame_in_ms)
         
-        encoder_class = encoder_choices.get_class(encoder)
+        encoder_class = encoder_classes.get_class(encoder)
         encoder = encoder_class(**encoder_conf)
         self.encoder = encoder
         # init variables
diff --git a/funasr/models/language_model/rnn/decoders.py b/funasr/models/language_model/rnn/decoders.py
index a426b51..fd8bc29 100644
--- a/funasr/models/language_model/rnn/decoders.py
+++ b/funasr/models/language_model/rnn/decoders.py
@@ -15,7 +15,7 @@
 from funasr.metrics import end_detect
 from funasr.models.transformer.utils.nets_utils import mask_by_length
 from funasr.models.transformer.utils.nets_utils import pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
 from funasr.models.transformer.utils.nets_utils import to_device
 from funasr.models.language_model.rnn.attentions import att_to_numpy
 
diff --git a/funasr/models/language_model/rnn/encoders.py b/funasr/models/language_model/rnn/encoders.py
index 819585b..d047a8e 100644
--- a/funasr/models/language_model/rnn/encoders.py
+++ b/funasr/models/language_model/rnn/encoders.py
@@ -7,7 +7,7 @@
 from torch.nn.utils.rnn import pack_padded_sequence
 from torch.nn.utils.rnn import pad_packed_sequence
 
-from funasr.metrics import get_vgg2l_odim
+from funasr.metrics.common import get_vgg2l_odim
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.transformer.utils.nets_utils import to_device
 
diff --git a/funasr/models/transformer/transformer_encoder.py b/funasr/models/language_model/transformer_encoder.py
similarity index 66%
rename from funasr/models/transformer/transformer_encoder.py
rename to funasr/models/language_model/transformer_encoder.py
index 1126da0..21f3548 100644
--- a/funasr/models/transformer/transformer_encoder.py
+++ b/funasr/models/language_model/transformer_encoder.py
@@ -11,8 +11,6 @@
 from torch import nn
 import logging
 
-from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.transformer.attention import MultiHeadedAttention
 from funasr.models.transformer.embedding import PositionalEncoding
 from funasr.models.transformer.layer_norm import LayerNorm
@@ -22,18 +20,17 @@
 from funasr.models.transformer.positionwise_feed_forward import (
     PositionwiseFeedForward,  # noqa: H301
 )
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.utils.nets_utils import rename_state_dict
+from funasr.models.transformer.utils.repeat import repeat
 from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
 from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
 from funasr.models.transformer.utils.lightconv import LightweightConvolution
 from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
 
 
 class EncoderLayer(nn.Module):
@@ -143,216 +140,7 @@
         return x, mask
 
 
-class TransformerEncoder(AbsEncoder):
-    """Transformer encoder module.
-
-    Args:
-        input_size: input dim
-        output_size: dimension of attention
-        attention_heads: the number of heads of multi head attention
-        linear_units: the number of units of position-wise feed forward
-        num_blocks: the number of decoder blocks
-        dropout_rate: dropout rate
-        attention_dropout_rate: dropout rate in attention
-        positional_dropout_rate: dropout rate after adding positional encoding
-        input_layer: input layer type
-        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
-        normalize_before: whether to use layer_norm before the first block
-        concat_after: whether to concat attention layer's input and output
-            if True, additional linear will be applied.
-            i.e. x -> x + linear(concat(x, att(x)))
-            if False, no additional linear will be applied.
-            i.e. x -> x + att(x)
-        positionwise_layer_type: linear of conv1d
-        positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
-        padding_idx: padding_idx for input_layer=embed
-    """
-
-    def __init__(
-            self,
-            input_size: int,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.0,
-            input_layer: Optional[str] = "conv2d",
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            positionwise_layer_type: str = "linear",
-            positionwise_conv_kernel_size: int = 1,
-            padding_idx: int = -1,
-            interctc_layer_idx: List[int] = [],
-            interctc_use_conditioning: bool = False,
-    ):
-        super().__init__()
-        self._output_size = output_size
-
-        if input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(input_size, output_size),
-                torch.nn.LayerNorm(output_size),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2d":
-            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d2":
-            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d6":
-            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d8":
-            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
-        elif input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer is None:
-            if input_size == output_size:
-                self.embed = None
-            else:
-                self.embed = torch.nn.Linear(input_size, output_size)
-        else:
-            raise ValueError("unknown input_layer: " + input_layer)
-        self.normalize_before = normalize_before
-        if positionwise_layer_type == "linear":
-            positionwise_layer = PositionwiseFeedForward
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d":
-            positionwise_layer = MultiLayeredConv1d
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d-linear":
-            positionwise_layer = Conv1dLinear
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        else:
-            raise NotImplementedError("Support only linear or conv1d.")
-        self.encoders = repeat(
-            num_blocks,
-            lambda lnum: EncoderLayer(
-                output_size,
-                MultiHeadedAttention(
-                    attention_heads, output_size, attention_dropout_rate
-                ),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-
-        self.interctc_layer_idx = interctc_layer_idx
-        if len(interctc_layer_idx) > 0:
-            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
-        self.interctc_use_conditioning = interctc_use_conditioning
-        self.conditioning_layer = None
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-            prev_states: torch.Tensor = None,
-            ctc: CTC = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Embed positions in tensor.
-
-        Args:
-            xs_pad: input tensor (B, L, D)
-            ilens: input length (B)
-            prev_states: Not to be used now.
-        Returns:
-            position embedded tensor and mask
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-
-        if self.embed is None:
-            xs_pad = xs_pad
-        elif (
-                isinstance(self.embed, Conv2dSubsampling)
-                or isinstance(self.embed, Conv2dSubsampling2)
-                or isinstance(self.embed, Conv2dSubsampling6)
-                or isinstance(self.embed, Conv2dSubsampling8)
-        ):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
-                    + f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-
-        intermediate_outs = []
-        if len(self.interctc_layer_idx) == 0:
-            xs_pad, masks = self.encoders(xs_pad, masks)
-        else:
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                xs_pad, masks = encoder_layer(xs_pad, masks)
-
-                if layer_idx + 1 in self.interctc_layer_idx:
-                    encoder_out = xs_pad
-
-                    # intermediate outputs are also normalized
-                    if self.normalize_before:
-                        encoder_out = self.after_norm(encoder_out)
-
-                    intermediate_outs.append((layer_idx + 1, encoder_out))
-
-                    if self.interctc_use_conditioning:
-                        ctc_out = ctc.softmax(encoder_out)
-                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        olens = masks.squeeze(1).sum(1)
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), olens, None
-        return xs_pad, olens, None
-
-
-def _pre_hook(
-    state_dict,
-    prefix,
-    local_metadata,
-    strict,
-    missing_keys,
-    unexpected_keys,
-    error_msgs,
-):
-    # https://github.com/espnet/espnet/commit/21d70286c354c66c0350e65dc098d2ee236faccc#diff-bffb1396f038b317b2b64dd96e6d3563
-    rename_state_dict(prefix + "input_layer.", prefix + "embed.", state_dict)
-    # https://github.com/espnet/espnet/commit/3d422f6de8d4f03673b89e1caef698745ec749ea#diff-bffb1396f038b317b2b64dd96e6d3563
-    rename_state_dict(prefix + "norm.", prefix + "after_norm.", state_dict)
-
-
-class TransformerEncoder_s0(torch.nn.Module):
+class TransformerEncoder_lm(nn.Module):
     """Transformer encoder module.
 
     Args:
@@ -418,8 +206,7 @@
         conditioning_layer_dim=None,
     ):
         """Construct an Encoder object."""
-        super(TransformerEncoder_s0, self).__init__()
-        self._register_load_state_dict_pre_hook(_pre_hook)
+        super().__init__()
 
         self.conv_subsampling_factor = 1
         if input_layer == "linear":
diff --git a/funasr/models/mfcca/e2e_asr_mfcca.py b/funasr/models/mfcca/e2e_asr_mfcca.py
index 5ec9e94..48534dd 100644
--- a/funasr/models/mfcca/e2e_asr_mfcca.py
+++ b/funasr/models/mfcca/e2e_asr_mfcca.py
@@ -9,15 +9,15 @@
 import torch
 
 from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.losses.label_smoothing_loss import (
     LabelSmoothingLoss,  # noqa: H301
 )
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.layers.abs_normalize import AbsNormalize
diff --git a/funasr/models/mfcca/mfcca_encoder.py b/funasr/models/mfcca/mfcca_encoder.py
index 92dd6e7..52da33d 100644
--- a/funasr/models/mfcca/mfcca_encoder.py
+++ b/funasr/models/mfcca/mfcca_encoder.py
@@ -26,13 +26,13 @@
 from funasr.models.transformer.positionwise_feed_forward import (
     PositionwiseFeedForward,  # noqa: H301
 )
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
 from funasr.models.encoder.abs_encoder import AbsEncoder
 import pdb
 import math
diff --git a/funasr/models/model_class_factory.py b/funasr/models/model_class_factory.py
deleted file mode 100644
index 23d2304..0000000
--- a/funasr/models/model_class_factory.py
+++ /dev/null
@@ -1,162 +0,0 @@
-
-from funasr.models.normalize.global_mvn import GlobalMVN
-from funasr.models.normalize.utterance_mvn import UtteranceMVN
-from funasr.models.ctc.ctc import CTC
-
-from funasr.models.transducer.rnn_decoder import RNNDecoder
-from funasr.models.sanm.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
-from funasr.models.transformer.transformer_decoder import (
-    DynamicConvolution2DTransformerDecoder,  # noqa: H301
-)
-from funasr.models.transformer.transformer_decoder import DynamicConvolutionTransformerDecoder
-from funasr.models.transformer.transformer_decoder import (
-    LightweightConvolution2DTransformerDecoder,  # noqa: H301
-)
-from funasr.models.transformer.transformer_decoder import (
-    LightweightConvolutionTransformerDecoder,  # noqa: H301
-)
-from funasr.models.transformer.transformer_decoder import ParaformerDecoderSAN
-from funasr.models.transformer.transformer_decoder import TransformerDecoder
-from funasr.models.paraformer.contextual_decoder import ContextualParaformerDecoder
-from funasr.models.transformer.transformer_decoder import SAAsrTransformerDecoder
-
-from funasr.models.transducer.rnnt_decoder import RNNTDecoder
-from funasr.models.transducer.joint_network import JointNetwork
-
-
-from funasr.models.conformer.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
-from funasr.models.data2vec.data2vec_encoder import Data2VecEncoder
-from funasr.models.transducer.rnn_encoder import RNNEncoder
-from funasr.models.sanm.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
-from funasr.models.transformer.transformer_encoder import TransformerEncoder
-from funasr.models.branchformer.branchformer_encoder import BranchformerEncoder
-from funasr.models.e_branchformer.e_branchformer_encoder import EBranchformerEncoder
-from funasr.models.mfcca.mfcca_encoder import MFCCAEncoder
-from funasr.models.sond.encoder.resnet34_encoder import ResNet34Diar
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.default import MultiChannelFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-
-
-from funasr.models.paraformer.cif_predictor import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
-from funasr.models.specaug.specaug import SpecAug
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.models.transformer.subsampling import Conv1dSubsampling
-from funasr.utils.class_choices import ClassChoices
-from funasr.models.fsmn_vad.fsmn_encoder import FSMN
-
-from funasr.models.sond.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
-from funasr.models.sond.encoder.conv_encoder import ConvEncoder
-from funasr.models.sond.encoder.fsmn_encoder import FsmnEncoder
-from funasr.models.sond.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
-
-from funasr.models.sond.encoder.conv_encoder import ConvEncoder
-from funasr.models.sond.encoder.fsmn_encoder import FsmnEncoder
-from funasr.models.eend.encoder_decoder_attractor import EncoderDecoderAttractor
-from funasr.models.eend.encoder import EENDOLATransformerEncoder
-
-frontend_choices = ClassChoices(
-    name="frontend",
-    classes=dict(
-        default=DefaultFrontend,
-        sliding_window=SlidingWindow,
-        s3prl=S3prlFrontend,
-        fused=FusedFrontends,
-        wav_frontend=WavFrontend,
-        multichannelfrontend=MultiChannelFrontend,
-    ),
-    default="default",
-)
-specaug_choices = ClassChoices(
-    name="specaug",
-    classes=dict(
-        specaug=SpecAug,
-        specaug_lfr=SpecAugLFR,
-    ),
-    default=None,
-)
-normalize_choices = ClassChoices(
-    "normalize",
-    classes=dict(
-        global_mvn=GlobalMVN,
-        utterance_mvn=UtteranceMVN,
-    ),
-    default=None,
-)
-
-encoder_choices = ClassChoices(
-    "encoder",
-    classes=dict(
-        conformer=ConformerEncoder,
-        transformer=TransformerEncoder,
-        rnn=RNNEncoder,
-        sanm=SANMEncoder,
-        sanm_chunk_opt=SANMEncoderChunkOpt,
-        data2vec_encoder=Data2VecEncoder,
-        mfcca_enc=MFCCAEncoder,
-        chunk_conformer=ConformerChunkEncoder,
-        fsmn=FSMN,
-        branchformer=BranchformerEncoder,
-        e_branchformer=EBranchformerEncoder,
-        resnet34=ResNet34Diar,
-        resnet34_sp_l2reg=ResNet34SpL2RegDiar,
-        ecapa_tdnn=ECAPA_TDNN,
-        eend_ola_transformer=EENDOLATransformerEncoder,
-        conv=ConvEncoder,
-        resnet34_diar=ResNet34Diar,
-    ),
-    default="rnn",
-)
-
-
-decoder_choices = ClassChoices(
-    "decoder",
-    classes=dict(
-        transformer=TransformerDecoder,
-        lightweight_conv=LightweightConvolutionTransformerDecoder,
-        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
-        dynamic_conv=DynamicConvolutionTransformerDecoder,
-        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
-        rnn=RNNDecoder,
-        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
-        paraformer_decoder_sanm=ParaformerSANMDecoder,
-        paraformer_decoder_san=ParaformerDecoderSAN,
-        contextual_paraformer_decoder=ContextualParaformerDecoder,
-        sa_decoder=SAAsrTransformerDecoder,
-        rnnt=RNNTDecoder,
-    ),
-    default="transformer",
-)
-
-
-joint_network_choices = ClassChoices(
-    name="joint_network",
-    classes=dict(
-        joint_network=JointNetwork,
-    ),
-    default="joint_network",
-)
-
-predictor_choices = ClassChoices(
-    name="predictor",
-    classes=dict(
-        cif_predictor=CifPredictor,
-        ctc_predictor=None,
-        cif_predictor_v2=CifPredictorV2,
-        cif_predictor_v3=CifPredictorV3,
-        bat_predictor=BATPredictor,
-    ),
-    default="cif_predictor",
-)
-
-stride_conv_choices = ClassChoices(
-    name="stride_conv",
-    classes=dict(
-        stride_conv1d=Conv1dSubsampling
-    ),
-    default="stride_conv1d",
-)
\ No newline at end of file
diff --git a/funasr/models/paraformer/contextual_decoder.py b/funasr/models/neat_contextual_paraformer/decoder.py
similarity index 98%
rename from funasr/models/paraformer/contextual_decoder.py
rename to funasr/models/neat_contextual_paraformer/decoder.py
index 626cdef..ca689d3 100644
--- a/funasr/models/paraformer/contextual_decoder.py
+++ b/funasr/models/neat_contextual_paraformer/decoder.py
@@ -6,15 +6,15 @@
 import numpy as np
 
 from funasr.models.scama import utils as myutils
-from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
 
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
 from funasr.models.transformer.embedding import PositionalEncoding
 from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.transformer.repeat import repeat
-from funasr.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.paraformer.decoder import DecoderLayerSANM, ParaformerSANMDecoder
 
+from funasr.utils.register import register_class, registry_tables
 
 class ContextualDecoderLayer(nn.Module):
     def __init__(
@@ -98,7 +98,7 @@
             x =  self.dropout(self.src_attn(x, memory, memory_mask))
         return x, tgt_mask, memory, memory_mask, cache
 
-
+@register_class("decoder_classes", "ContextualParaformerDecoder")
 class ContextualParaformerDecoder(ParaformerSANMDecoder):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
diff --git a/funasr/models/neat_contextual_paraformer/model.py b/funasr/models/neat_contextual_paraformer/model.py
new file mode 100644
index 0000000..7891307
--- /dev/null
+++ b/funasr/models/neat_contextual_paraformer/model.py
@@ -0,0 +1,533 @@
+import os
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import tempfile
+import codecs
+import requests
+import re
+import copy
+import torch
+import torch.nn as nn
+import random
+import numpy as np
+import time
+# from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+	LabelSmoothingLoss,  # noqa: H301
+)
+# from funasr.models.ctc import CTC
+# from funasr.models.decoder.abs_decoder import AbsDecoder
+# from funasr.models.e2e_asr_common import ErrorCalculator
+# from funasr.models.encoder.abs_encoder import AbsEncoder
+# from funasr.frontends.abs_frontend import AbsFrontend
+# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.paraformer.cif_predictor import mae_loss
+# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+# from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import force_gatherable
+# from funasr.models.base_model import FunASRModel
+# from funasr.models.paraformer.cif_predictor import CifPredictorV3
+from funasr.models.paraformer.search import Hypothesis
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+	from torch.cuda.amp import autocast
+else:
+	# Nothing to do if torch<1.6.0
+	@contextmanager
+	def autocast(enabled=True):
+		yield
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+
+
+from funasr.models.paraformer.model import Paraformer
+
+from funasr.utils.register import register_class, registry_tables
+
+@register_class("model_classes", "NeatContextualParaformer")
+class NeatContextualParaformer(Paraformer):
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+	https://arxiv.org/abs/2206.08317
+	"""
+	
+	def __init__(
+		self,
+		*args,
+		**kwargs,
+	):
+		super().__init__(*args, **kwargs)
+		
+		self.target_buffer_length = kwargs.get("target_buffer_length", -1)
+		inner_dim = kwargs.get("inner_dim", 256)
+		bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
+		use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
+		crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
+		crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
+		bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
+
+
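+		# Bias encoder: turns padded hotword token ids into contextual embeddings.
+		# 'lstm' embeds the tokens and runs a single-layer LSTM over them; 'mean' only
+		# builds the embedding table (any pooling is assumed to happen elsewhere).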
+		if bias_encoder_type == 'lstm':
+			logging.warning("enable bias encoder sampling and contextual training")
+			self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
+			self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
+		elif bias_encoder_type == 'mean':
+			logging.warning("enable bias encoder sampling and contextual training")
+			self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
+		else:
+			logging.error("Unsupported bias encoder type: {}".format(bias_encoder_type))
+		
+		if self.target_buffer_length > 0:
+			self.hotword_buffer = None
+			self.length_record = []
+			self.current_buffer_length = 0
+		self.use_decoder_embedding = use_decoder_embedding
+		self.crit_attn_weight = crit_attn_weight
+		if self.crit_attn_weight > 0:
+			self.attn_loss = torch.nn.L1Loss()
+		self.crit_attn_smooth = crit_attn_smooth
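+		# crit_attn_weight / crit_attn_smooth configure an optional L1 loss between the
+		# decoder attention and an ideal hotword attention; that term is currently
+		# disabled in _calc_att_clas_loss (loss_ideal stays None).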
+
+
+	def forward(
+		self,
+		speech: torch.Tensor,
+		speech_lengths: torch.Tensor,
+		text: torch.Tensor,
+		text_lengths: torch.Tensor,
+		**kwargs,
+	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+		"""Frontend + Encoder + Decoder + Calc loss
+	
+		Args:
+				speech: (Batch, Length, ...)
+				speech_lengths: (Batch, )
+				text: (Batch, Length)
+				text_lengths: (Batch,)
+		"""
+		if len(text_lengths.size()) > 1:
+			text_lengths = text_lengths[:, 0]
+		if len(speech_lengths.size()) > 1:
+			speech_lengths = speech_lengths[:, 0]
+		
+		batch_size = speech.shape[0]
+
+		hotword_pad = kwargs.get("hotword_pad")
+		hotword_lengths = kwargs.get("hotword_lengths")
+		dha_pad = kwargs.get("dha_pad")
+		
+		# 1. Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+		
+		loss_ctc, cer_ctc = None, None
+		
+		stats = dict()
+		
+		# 2a. CTC branch
+		if self.ctc_weight != 0.0:
+			loss_ctc, cer_ctc = self._calc_ctc_loss(
+				encoder_out, encoder_out_lens, text, text_lengths
+			)
+			
+			# Collect CTC branch stats
+			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+			stats["cer_ctc"] = cer_ctc
+		
+
+		# 2b. Attention decoder branch
+		loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
+			encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
+		)
+		
+		# 3. CTC-Att loss definition
+		if self.ctc_weight == 0.0:
+			loss = loss_att + loss_pre * self.predictor_weight
+		else:
+			loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+		
+		if loss_ideal is not None:
+			loss = loss + loss_ideal * self.crit_attn_weight
+			stats["loss_ideal"] = loss_ideal.detach().cpu()
+		
+		# Collect Attn branch stats
+		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+		stats["acc"] = acc_att
+		stats["cer"] = cer_att
+		stats["wer"] = wer_att
+		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+		
+		stats["loss"] = torch.clone(loss.detach())
+		# force_gatherable: to-device and to-tensor if scalar for DataParallel
+		if self.length_normalized_loss:
+			batch_size = int((text_lengths + self.predictor_bias).sum())
+		
+		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+		return loss, stats, weight
+	
+	
+	def _calc_att_clas_loss(
+		self,
+		encoder_out: torch.Tensor,
+		encoder_out_lens: torch.Tensor,
+		ys_pad: torch.Tensor,
+		ys_pad_lens: torch.Tensor,
+		hotword_pad: torch.Tensor,
+		hotword_lengths: torch.Tensor,
+	):
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		if self.predictor_bias == 1:
+			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+			ys_pad_lens = ys_pad_lens + self.predictor_bias
+		pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+		                                                             ignore_id=self.ignore_id)
+		
+		# -1. bias encoder
+		if self.use_decoder_embedding:
+			hw_embed = self.decoder.embed(hotword_pad)
+		else:
+			hw_embed = self.bias_embed(hotword_pad)
+		hw_embed, (_, _) = self.bias_encoder(hw_embed)
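+		# take the LSTM output at the last valid position of each hotword as its embedding,
+		# then broadcast the hotword bank to every utterance in the batch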
+		_ind = np.arange(0, hotword_pad.shape[0]).tolist()
+		selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
+		contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
+		
+		# 0. sampler
+		decoder_out_1st = None
+		if self.sampling_ratio > 0.0:
+			if self.step_cur < 2:
+				logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+			sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+			                                               pre_acoustic_embeds, contextual_info)
+		else:
+			if self.step_cur < 2:
+				logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+			sematic_embeds = pre_acoustic_embeds
+		
+		# 1. Forward decoder
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
+		)
+		decoder_out, _ = decoder_outs[0], decoder_outs[1]
+		'''
+		if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
+			ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
+			attn_non_blank = attn[:,:,:,:-1]
+			ideal_attn_non_blank = ideal_attn[:,:,:-1]
+			loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
+		else:
+			loss_ideal = None
+		'''
+		loss_ideal = None
+		
+		if decoder_out_1st is None:
+			decoder_out_1st = decoder_out
+		# 2. Compute attention loss
+		loss_att = self.criterion_att(decoder_out, ys_pad)
+		acc_att = th_accuracy(
+			decoder_out_1st.view(-1, self.vocab_size),
+			ys_pad,
+			ignore_label=self.ignore_id,
+		)
+		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+		
+		# Compute cer/wer using attention-decoder
+		if self.training or self.error_calculator is None:
+			cer_att, wer_att = None, None
+		else:
+			ys_hat = decoder_out_1st.argmax(dim=-1)
+			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+		
+		return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal
+	
+	
+	def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
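+		# Glancing-style sampler: decode once from the CIF acoustic embeddings without gradients,
+		# then randomly replace a fraction of positions (scaled by sampling_ratio and the number of
+		# prediction errors) with ground-truth token embeddings for the gradient-carrying pass.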
+		tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+		ys_pad = ys_pad * tgt_mask[:, :, 0]
+		if self.share_embedding:
+			ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
+		else:
+			ys_pad_embed = self.decoder.embed(ys_pad)
+		with torch.no_grad():
+			decoder_outs = self.decoder(
+				encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
+			)
+			decoder_out, _ = decoder_outs[0], decoder_outs[1]
+			pred_tokens = decoder_out.argmax(-1)
+			nonpad_positions = ys_pad.ne(self.ignore_id)
+			seq_lens = (nonpad_positions).sum(1)
+			same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+			input_mask = torch.ones_like(nonpad_positions)
+			bsz, seq_len = ys_pad.size()
+			for li in range(bsz):
+				target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+				if target_num > 0:
+					input_mask[li].scatter_(dim=0,
+					                        index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device),
+					                        value=0)
+			input_mask = input_mask.eq(1)
+			input_mask = input_mask.masked_fill(~nonpad_positions, False)
+			input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+		
+		sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+			input_mask_expand_dim, 0)
+		return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+	
+	
+	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
+	                               clas_scale=1.0):
+		if hw_list is None:
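+			# no hotwords given: fall back to a single placeholder entry so the bias encoder
+			# still produces a contextual vector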
+			hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
+			hw_list_pad = pad_list(hw_list, 0)
+			if self.use_decoder_embedding:
+				hw_embed = self.decoder.embed(hw_list_pad)
+			else:
+				hw_embed = self.bias_embed(hw_list_pad)
+			hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
+			hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
+		else:
+			hw_lengths = [len(i) for i in hw_list]
+			hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
+			if self.use_decoder_embedding:
+				hw_embed = self.decoder.embed(hw_list_pad)
+			else:
+				hw_embed = self.bias_embed(hw_list_pad)
+			hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
+			                                                   enforce_sorted=False)
+			_, (h_n, _) = self.bias_encoder(hw_embed)
+			hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
+		
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
+		)
+		decoder_out = decoder_outs[0]
+		decoder_out = torch.log_softmax(decoder_out, dim=-1)
+		return decoder_out, ys_pad_lens
+		
+	def generate(self,
+	             data_in,
+	             data_lengths=None,
+	             key: list = None,
+	             tokenizer=None,
+	             frontend=None,
+	             **kwargs,
+	             ):
+		
+		# init beamsearch
+		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		if self.beam_search is None and (is_use_lm or is_use_ctc):
+			logging.info("enable beam_search")
+			self.init_beam_search(**kwargs)
+			self.nbest = kwargs.get("nbest", 1)
+		
+		meta_data = {}
+		
+		# extract fbank feats
+		time1 = time.perf_counter()
+		audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+		time2 = time.perf_counter()
+		meta_data["load_data"] = f"{time2 - time1:0.3f}"
+		speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+		                                       frontend=frontend)
+		time3 = time.perf_counter()
+		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+		meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+		
+		speech = speech.to(device=kwargs["device"])
+		speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+		# hotword
+		self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
+		
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+		
+		# predictor
+		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+		                                                                predictor_outs[2], predictor_outs[3]
+		pre_token_length = pre_token_length.round().long()
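+		# the predictor found no tokens in any utterance; skip decoding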
+		if torch.max(pre_token_length) < 1:
+			return [], meta_data
+
+
+		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
+		                                                         pre_acoustic_embeds,
+		                                                         pre_token_length,
+		                                                         hw_list=self.hotword_list,
+		                                                         clas_scale=kwargs.get("clas_scale", 1.0))
+		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+		
+		results = []
+		b, n, d = decoder_out.size()
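+		# decode each utterance: run beam search over the decoder scores when it is enabled,
+		# otherwise take the greedy argmax over the predictor-determined token positions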
+		for i in range(b):
+			x = encoder_out[i, :encoder_out_lens[i], :]
+			am_scores = decoder_out[i, :pre_token_length[i], :]
+			if self.beam_search is not None:
+				nbest_hyps = self.beam_search(
+					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+					minlenratio=kwargs.get("minlenratio", 0.0)
+				)
+				
+				nbest_hyps = nbest_hyps[: self.nbest]
+			else:
+				
+				yseq = am_scores.argmax(dim=-1)
+				score = am_scores.max(dim=-1)[0]
+				score = torch.sum(score, dim=-1)
+				# wrap the greedy 1-best sequence with sos/eos so it matches the beam-search hypothesis format
+				yseq = torch.tensor(
+					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+				)
+				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+			for nbest_idx, hyp in enumerate(nbest_hyps):
+				ibest_writer = None
+				if kwargs.get("output_dir") is not None:
+					writer = DatadirWriter(kwargs.get("output_dir"))
+					ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+				# remove sos/eos and get results
+				last_pos = -1
+				if isinstance(hyp.yseq, list):
+					token_int = hyp.yseq[1:last_pos]
+				else:
+					token_int = hyp.yseq[1:last_pos].tolist()
+				
+				# remove blank symbol id, which is assumed to be 0
+				token_int = list(
+					filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+				
+				if tokenizer is not None:
+					# Change integer-ids to tokens
+					token = tokenizer.ids2tokens(token_int)
+					text = tokenizer.tokens2text(token)
+					
+					text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+					result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+					
+					if ibest_writer is not None:
+						ibest_writer["token"][key[i]] = " ".join(token)
+						ibest_writer["text"][key[i]] = text
+						ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+				else:
+					result_i = {"key": key[i], "token_int": token_int}
+				results.append(result_i)
+		
+		return results, meta_data
+
+
+	def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None, frontend=None):
+		def load_seg_dict(seg_dict_file):
+			seg_dict = {}
+			assert isinstance(seg_dict_file, str)
+			with open(seg_dict_file, "r", encoding="utf8") as f:
+				lines = f.readlines()
+				for line in lines:
+					s = line.strip().split()
+					key = s[0]
+					value = s[1:]
+					seg_dict[key] = " ".join(value)
+			return seg_dict
+		
+		def seg_tokenize(txt, seg_dict):
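+			# map each word to its entry in seg_dict; pure CJK/digit words fall back to a
+			# per-character lookup, and anything else becomes <unk>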
+			pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+			out_txt = ""
+			for word in txt:
+				word = word.lower()
+				if word in seg_dict:
+					out_txt += seg_dict[word] + " "
+				else:
+					if pattern.match(word):
+						for char in word:
+							if char in seg_dict:
+								out_txt += seg_dict[char] + " "
+							else:
+								out_txt += "<unk>" + " "
+					else:
+						out_txt += "<unk>" + " "
+			return out_txt.strip().split()
+		
+		seg_dict = None
+		if frontend.cmvn_file is not None:
+			model_dir = os.path.dirname(frontend.cmvn_file)
+			seg_dict_file = os.path.join(model_dir, 'seg_dict')
+			if os.path.exists(seg_dict_file):
+				seg_dict = load_seg_dict(seg_dict_file)
+			else:
+				seg_dict = None
+		# for None
+		if hotword_list_or_file is None:
+			hotword_list = None
+		# for local txt inputs
+		elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+			logging.info("Attempting to parse hotwords from local txt...")
+			hotword_list = []
+			hotword_str_list = []
+			with codecs.open(hotword_list_or_file, 'r') as fin:
+				for line in fin.readlines():
+					hw = line.strip()
+					hw_list = hw.split()
+					if seg_dict is not None:
+						hw_list = seg_tokenize(hw_list, seg_dict)
+					hotword_str_list.append(hw)
+					hotword_list.append(tokenizer.tokens2ids(hw_list))
+				hotword_list.append([self.sos])
+				hotword_str_list.append('<s>')
+			logging.info("Initialized hotword list from file: {}, hotword list: {}."
+			             .format(hotword_list_or_file, hotword_str_list))
+		# for url, download and generate txt
+		elif hotword_list_or_file.startswith('http'):
+			logging.info("Attempting to parse hotwords from url...")
+			# mkdtemp creates the directory and leaves it in place until the caller removes it
+			work_dir = tempfile.mkdtemp()
+			text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+			response = requests.get(hotword_list_or_file)
+			with open(text_file_path, "wb") as fout:
+				fout.write(response.content)
+			hotword_list_or_file = text_file_path
+			hotword_list = []
+			hotword_str_list = []
+			with codecs.open(hotword_list_or_file, 'r') as fin:
+				for line in fin.readlines():
+					hw = line.strip()
+					hw_list = hw.split()
+					if seg_dict is not None:
+						hw_list = seg_tokenize(hw_list, seg_dict)
+					hotword_str_list.append(hw)
+					hotword_list.append(tokenizer.tokens2ids(hw_list))
+				hotword_list.append([self.sos])
+				hotword_str_list.append('<s>')
+			logging.info("Initialized hotword list from file: {}, hotword list: {}."
+			             .format(hotword_list_or_file, hotword_str_list))
+		# for text str input
+		elif not hotword_list_or_file.endswith('.txt'):
+			logging.info("Attempting to parse hotwords as str...")
+			hotword_list = []
+			hotword_str_list = []
+			for hw in hotword_list_or_file.strip().split():
+				hotword_str_list.append(hw)
+				hw_list = hw.strip().split()
+				if seg_dict is not None:
+					hw_list = seg_tokenize(hw_list, seg_dict)
+				hotword_list.append(tokenizer.tokens2ids(hw_list))
+			hotword_list.append([self.sos])
+			hotword_str_list.append('<s>')
+			logging.info("Hotword list: {}.".format(hotword_str_list))
+		else:
+			hotword_list = None
+		return hotword_list
+
diff --git a/funasr/models/normalize/global_mvn.py b/funasr/models/normalize/global_mvn.py
index a9b7935..eea84dc 100644
--- a/funasr/models/normalize/global_mvn.py
+++ b/funasr/models/normalize/global_mvn.py
@@ -6,8 +6,9 @@
 import torch
 
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.register import register_class, registry_tables
 
-
+@register_class("normalize_classes", "GlobalMVN")
 class GlobalMVN(torch.nn.Module):
     """Apply global mean and variance normalization
     TODO(kamo): Make this class portable somehow
diff --git a/funasr/models/normalize/utterance_mvn.py b/funasr/models/normalize/utterance_mvn.py
index 609c7aa..60703fb 100644
--- a/funasr/models/normalize/utterance_mvn.py
+++ b/funasr/models/normalize/utterance_mvn.py
@@ -3,8 +3,9 @@
 import torch
 
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.register import register_class, registry_tables
 
-
+@register_class("normalize_classes", "UtteranceMVN")
 class UtteranceMVN(torch.nn.Module):
     def __init__(
         self,
diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py
index 7cc088f..c1b7d7a 100644
--- a/funasr/models/paraformer/cif_predictor.py
+++ b/funasr/models/paraformer/cif_predictor.py
@@ -8,9 +8,12 @@
 from funasr.models.scama.utils import sequence_mask
 from typing import Optional, Tuple
 
+from funasr.utils.register import register_class, registry_tables
+
+@register_class("predictor_classes", "CifPredictor")
 class CifPredictor(nn.Module):
     def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, tail_threshold=0.45):
-        super(CifPredictor, self).__init__()
+        super().__init__()
 
         self.pad = nn.ConstantPad1d((l_order, r_order), 0)
         self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
@@ -133,7 +136,7 @@
         predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
         return predictor_alignments.detach(), predictor_alignments_length.detach()
 
-
+@register_class("predictor_classes", "CifPredictorV2")
 class CifPredictorV2(nn.Module):
     def __init__(self,
                  idim,
@@ -505,372 +508,3 @@
     fires = torch.stack(list_fires, 1)
     return fires
 
-
-class CifPredictorV3(nn.Module):
-    def __init__(self,
-                 idim,
-                 l_order,
-                 r_order,
-                 threshold=1.0,
-                 dropout=0.1,
-                 smooth_factor=1.0,
-                 noise_threshold=0,
-                 tail_threshold=0.0,
-                 tf2torch_tensor_name_prefix_torch="predictor",
-                 tf2torch_tensor_name_prefix_tf="seq2seq/cif",
-                 smooth_factor2=1.0,
-                 noise_threshold2=0,
-                 upsample_times=5,
-                 upsample_type="cnn",
-                 use_cif1_cnn=True,
-                 tail_mask=True,
-                 ):
-        super(CifPredictorV3, self).__init__()
-
-        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
-        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1)
-        self.cif_output = nn.Linear(idim, 1)
-        self.dropout = torch.nn.Dropout(p=dropout)
-        self.threshold = threshold
-        self.smooth_factor = smooth_factor
-        self.noise_threshold = noise_threshold
-        self.tail_threshold = tail_threshold
-        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
-        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-
-        self.upsample_times = upsample_times
-        self.upsample_type = upsample_type
-        self.use_cif1_cnn = use_cif1_cnn
-        if self.upsample_type == 'cnn':
-            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
-            self.cif_output2 = nn.Linear(idim, 1)
-        elif self.upsample_type == 'cnn_blstm':
-            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
-            self.blstm = nn.LSTM(idim, idim, 1, bias=True, batch_first=True, dropout=0.0, bidirectional=True)
-            self.cif_output2 = nn.Linear(idim*2, 1)
-        elif self.upsample_type == 'cnn_attn':
-            self.upsample_cnn = nn.ConvTranspose1d(idim, idim, self.upsample_times, self.upsample_times)
-            from funasr.models.encoder.transformer_encoder import EncoderLayer as TransformerEncoderLayer
-            from funasr.models.transformer.attention import MultiHeadedAttention
-            from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
-            positionwise_layer_args = (
-                idim,
-                idim*2,
-                0.1,
-            )
-            self.self_attn = TransformerEncoderLayer(
-                idim,
-                MultiHeadedAttention(
-                    4, idim, 0.1
-                ),
-                PositionwiseFeedForward(*positionwise_layer_args),
-                0.1,
-                True, #normalize_before,
-                False, #concat_after,
-            )
-            self.cif_output2 = nn.Linear(idim, 1)
-        self.smooth_factor2 = smooth_factor2
-        self.noise_threshold2 = noise_threshold2
-
-    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
-                target_label_length=None):
-        h = hidden
-        context = h.transpose(1, 2)
-        queries = self.pad(context)
-        output = torch.relu(self.cif_conv1d(queries))
-
-        # alphas2 is an extra head for timestamp prediction
-        if not self.use_cif1_cnn:
-            _output = context
-        else:
-            _output = output
-        if self.upsample_type == 'cnn':
-            output2 = self.upsample_cnn(_output)
-            output2 = output2.transpose(1,2)
-        elif self.upsample_type == 'cnn_blstm':
-            output2 = self.upsample_cnn(_output)
-            output2 = output2.transpose(1,2)
-            output2, (_, _) = self.blstm(output2)
-        elif self.upsample_type == 'cnn_attn':
-            output2 = self.upsample_cnn(_output)
-            output2 = output2.transpose(1,2)
-            output2, _ = self.self_attn(output2, mask)
-        # import pdb; pdb.set_trace()
-        alphas2 = torch.sigmoid(self.cif_output2(output2))
-        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
-        # repeat the mask in T demension to match the upsampled length
-        if mask is not None:
-            mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
-            mask2 = mask2.unsqueeze(-1)
-            alphas2 = alphas2 * mask2
-        alphas2 = alphas2.squeeze(-1)
-        token_num2 = alphas2.sum(-1)
-
-        output = output.transpose(1, 2)
-
-        output = self.cif_output(output)
-        alphas = torch.sigmoid(output)
-        alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
-        if mask is not None:
-            mask = mask.transpose(-1, -2).float()
-            alphas = alphas * mask
-        if mask_chunk_predictor is not None:
-            alphas = alphas * mask_chunk_predictor
-        alphas = alphas.squeeze(-1)
-        mask = mask.squeeze(-1)
-        if target_label_length is not None:
-            target_length = target_label_length
-        elif target_label is not None:
-            target_length = (target_label != ignore_id).float().sum(-1)
-        else:
-            target_length = None
-        token_num = alphas.sum(-1)
-
-        if target_length is not None:
-            alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
-        elif self.tail_threshold > 0.0:
-            hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, token_num, mask=mask)
-
-        acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-        if target_length is None and self.tail_threshold > 0.0:
-            token_num_int = torch.max(token_num).type(torch.int32).item()
-            acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
-        return acoustic_embeds, token_num, alphas, cif_peak, token_num2
-
-    def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
-        h = hidden
-        b = hidden.shape[0]
-        context = h.transpose(1, 2)
-        queries = self.pad(context)
-        output = torch.relu(self.cif_conv1d(queries))
-
-        # alphas2 is an extra head for timestamp prediction
-        if not self.use_cif1_cnn:
-            _output = context
-        else:
-            _output = output
-        if self.upsample_type == 'cnn':
-            output2 = self.upsample_cnn(_output)
-            output2 = output2.transpose(1,2)
-        elif self.upsample_type == 'cnn_blstm':
-            output2 = self.upsample_cnn(_output)
-            output2 = output2.transpose(1,2)
-            output2, (_, _) = self.blstm(output2)
-        elif self.upsample_type == 'cnn_attn':
-            output2 = self.upsample_cnn(_output)
-            output2 = output2.transpose(1,2)
-            output2, _ = self.self_attn(output2, mask)
-        alphas2 = torch.sigmoid(self.cif_output2(output2))
-        alphas2 = torch.nn.functional.relu(alphas2 * self.smooth_factor2 - self.noise_threshold2)
-        # repeat the mask in T demension to match the upsampled length
-        if mask is not None:
-            mask2 = mask.repeat(1, self.upsample_times, 1).transpose(-1, -2).reshape(alphas2.shape[0], -1)
-            mask2 = mask2.unsqueeze(-1)
-            alphas2 = alphas2 * mask2
-        alphas2 = alphas2.squeeze(-1)
-        _token_num = alphas2.sum(-1)
-        if token_num is not None:
-            alphas2 *= (token_num / _token_num)[:, None].repeat(1, alphas2.size(1))
-        # re-downsample
-        ds_alphas = alphas2.reshape(b, -1, self.upsample_times).sum(-1)
-        ds_cif_peak = cif_wo_hidden(ds_alphas, self.threshold - 1e-4)
-        # upsampled alphas and cif_peak
-        us_alphas = alphas2
-        us_cif_peak = cif_wo_hidden(us_alphas, self.threshold - 1e-4)
-        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
-
-    def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
-        b, t, d = hidden.size()
-        tail_threshold = self.tail_threshold
-        if mask is not None:
-            zeros_t = torch.zeros((b, 1), dtype=torch.float32, device=alphas.device)
-            ones_t = torch.ones_like(zeros_t)
-            mask_1 = torch.cat([mask, zeros_t], dim=1)
-            mask_2 = torch.cat([ones_t, mask], dim=1)
-            mask = mask_2 - mask_1
-            tail_threshold = mask * tail_threshold
-            alphas = torch.cat([alphas, zeros_t], dim=1)
-            alphas = torch.add(alphas, tail_threshold)
-        else:
-            tail_threshold = torch.tensor([tail_threshold], dtype=alphas.dtype).to(alphas.device)
-            tail_threshold = torch.reshape(tail_threshold, (1, 1))
-            alphas = torch.cat([alphas, tail_threshold], dim=1)
-        zeros = torch.zeros((b, 1, d), dtype=hidden.dtype).to(hidden.device)
-        hidden = torch.cat([hidden, zeros], dim=1)
-        token_num = alphas.sum(dim=-1)
-        token_num_floor = torch.floor(token_num)
-
-        return hidden, alphas, token_num_floor
-
-    def gen_frame_alignments(self,
-                             alphas: torch.Tensor = None,
-                             encoder_sequence_length: torch.Tensor = None):
-        batch_size, maximum_length = alphas.size()
-        int_type = torch.int32
-
-        is_training = self.training
-        if is_training:
-            token_num = torch.round(torch.sum(alphas, dim=1)).type(int_type)
-        else:
-            token_num = torch.floor(torch.sum(alphas, dim=1)).type(int_type)
-
-        max_token_num = torch.max(token_num).item()
-
-        alphas_cumsum = torch.cumsum(alphas, dim=1)
-        alphas_cumsum = torch.floor(alphas_cumsum).type(int_type)
-        alphas_cumsum = alphas_cumsum[:, None, :].repeat(1, max_token_num, 1)
-
-        index = torch.ones([batch_size, max_token_num], dtype=int_type)
-        index = torch.cumsum(index, dim=1)
-        index = index[:, :, None].repeat(1, 1, maximum_length).to(alphas_cumsum.device)
-
-        index_div = torch.floor(torch.true_divide(alphas_cumsum, index)).type(int_type)
-        index_div_bool_zeros = index_div.eq(0)
-        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros, dim=-1) + 1
-        index_div_bool_zeros_count = torch.clamp(index_div_bool_zeros_count, 0, encoder_sequence_length.max())
-        token_num_mask = (~make_pad_mask(token_num, maxlen=max_token_num)).to(token_num.device)
-        index_div_bool_zeros_count *= token_num_mask
-
-        index_div_bool_zeros_count_tile = index_div_bool_zeros_count[:, :, None].repeat(1, 1, maximum_length)
-        ones = torch.ones_like(index_div_bool_zeros_count_tile)
-        zeros = torch.zeros_like(index_div_bool_zeros_count_tile)
-        ones = torch.cumsum(ones, dim=2)
-        cond = index_div_bool_zeros_count_tile == ones
-        index_div_bool_zeros_count_tile = torch.where(cond, zeros, ones)
-
-        index_div_bool_zeros_count_tile_bool = index_div_bool_zeros_count_tile.type(torch.bool)
-        index_div_bool_zeros_count_tile = 1 - index_div_bool_zeros_count_tile_bool.type(int_type)
-        index_div_bool_zeros_count_tile_out = torch.sum(index_div_bool_zeros_count_tile, dim=1)
-        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out.type(int_type)
-        predictor_mask = (~make_pad_mask(encoder_sequence_length, maxlen=encoder_sequence_length.max())).type(
-            int_type).to(encoder_sequence_length.device)
-        index_div_bool_zeros_count_tile_out = index_div_bool_zeros_count_tile_out * predictor_mask
-
-        predictor_alignments = index_div_bool_zeros_count_tile_out
-        predictor_alignments_length = predictor_alignments.sum(-1).type(encoder_sequence_length.dtype)
-        return predictor_alignments.detach(), predictor_alignments_length.detach()
-
-class BATPredictor(nn.Module):
-    def __init__(self, idim, l_order, r_order, threshold=1.0, dropout=0.1, smooth_factor=1.0, noise_threshold=0, return_accum=False):
-        super(BATPredictor, self).__init__()
-
-        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
-        self.cif_conv1d = nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
-        self.cif_output = nn.Linear(idim, 1)
-        self.dropout = torch.nn.Dropout(p=dropout)
-        self.threshold = threshold
-        self.smooth_factor = smooth_factor
-        self.noise_threshold = noise_threshold
-        self.return_accum = return_accum
-
-    def cif(
-        self,
-        input: Tensor,
-        alpha: Tensor,
-        beta: float = 1.0,
-        return_accum: bool = False,
-    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
-        B, S, C = input.size()
-        assert tuple(alpha.size()) == (B, S), f"{alpha.size()} != {(B, S)}"
-
-        dtype = alpha.dtype
-        alpha = alpha.float()
-
-        alpha_sum = alpha.sum(1)
-        feat_lengths = (alpha_sum / beta).floor().long()
-        T = feat_lengths.max()
-
-        # aggregate and integrate
-        csum = alpha.cumsum(-1)
-        with torch.no_grad():
-            # indices used for scattering
-            right_idx = (csum / beta).floor().long().clip(max=T)
-            left_idx = right_idx.roll(1, dims=1)
-            left_idx[:, 0] = 0
-
-            # count # of fires from each source
-            fire_num = right_idx - left_idx
-            extra_weights = (fire_num - 1).clip(min=0)
-            # The extra entry in last dim is for
-            output = input.new_zeros((B, T + 1, C))
-            source_range = torch.arange(1, 1 + S).unsqueeze(0).type_as(input)
-            zero = alpha.new_zeros((1,))
-
-        # right scatter
-        fire_mask = fire_num > 0
-        right_weight = torch.where(
-            fire_mask,
-            csum - right_idx.type_as(alpha) * beta,
-            zero
-        ).type_as(input)
-        # assert right_weight.ge(0).all(), f"{right_weight} should be non-negative."
-        output.scatter_add_(
-            1,
-            right_idx.unsqueeze(-1).expand(-1, -1, C),
-            right_weight.unsqueeze(-1) * input
-        )
-
-        # left scatter
-        left_weight = (
-            alpha - right_weight - extra_weights.type_as(alpha) * beta
-        ).type_as(input)
-        output.scatter_add_(
-            1,
-            left_idx.unsqueeze(-1).expand(-1, -1, C),
-            left_weight.unsqueeze(-1) * input
-        )
-
-         # extra scatters
-        if extra_weights.ge(0).any():
-            extra_steps = extra_weights.max().item()
-            tgt_idx = left_idx
-            src_feats = input * beta
-            for _ in range(extra_steps):
-                tgt_idx = (tgt_idx + 1).clip(max=T)
-                # (B, S, 1)
-                src_mask = (extra_weights > 0)
-                output.scatter_add_(
-                    1,
-                    tgt_idx.unsqueeze(-1).expand(-1, -1, C),
-                    src_feats * src_mask.unsqueeze(2)
-                )
-                extra_weights -= 1
-
-        output = output[:, :T, :]
-
-        if return_accum:
-            return output, csum
-        else:
-            return output, alpha
-
-    def forward(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None, target_label_length=None):
-        h = hidden
-        context = h.transpose(1, 2)
-        queries = self.pad(context)
-        memory = self.cif_conv1d(queries)
-        output = memory + context
-        output = self.dropout(output)
-        output = output.transpose(1, 2)
-        output = torch.relu(output)
-        output = self.cif_output(output)
-        alphas = torch.sigmoid(output)
-        alphas = torch.nn.functional.relu(alphas*self.smooth_factor - self.noise_threshold)
-        if mask is not None:
-            alphas = alphas * mask.transpose(-1, -2).float()
-        if mask_chunk_predictor is not None:
-            alphas = alphas * mask_chunk_predictor
-        alphas = alphas.squeeze(-1)
-        if target_label_length is not None:
-            target_length = target_label_length
-        elif target_label is not None:
-            target_length = (target_label != ignore_id).float().sum(-1)
-            # logging.info("target_length: {}".format(target_length))
-        else:
-            target_length = None
-        token_num = alphas.sum(-1)
-        if target_length is not None:
-            # length_noise = torch.rand(alphas.size(0), device=alphas.device) - 0.5
-            # target_length = length_noise + target_length
-            alphas *= ((target_length + 1e-4) / token_num)[:, None].repeat(1, alphas.size(1))
-        acoustic_embeds, cif_peak = self.cif(hidden, alphas, self.threshold, self.return_accum)
-        return acoustic_embeds, token_num, alphas, cif_peak
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
new file mode 100644
index 0000000..3fe9d19
--- /dev/null
+++ b/funasr/models/paraformer/decoder.py
@@ -0,0 +1,625 @@
+from typing import List
+from typing import Tuple
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.models.scama import utils as myutils
+from funasr.models.transformer.decoder import BaseTransformerDecoder
+
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.decoder import DecoderLayer
+from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from funasr.utils.register import register_class, registry_tables
+
+class DecoderLayerSANM(nn.Module):
+    """Single decoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Source-attention (encoder-decoder attention) module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct a DecoderLayerSANM object."""
+        super(DecoderLayerSANM, self).__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        if self_attn is not None:
+            self.norm2 = LayerNorm(size)
+        if src_attn is not None:
+            self.norm3 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, _ = self.self_attn(tgt, tgt_mask)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            if self.training:
+                cache = None
+            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
+            x = residual + x
+
+        return x, memory, fsmn_cache, opt_cache
+
+
+@register_class("decoder_classes", "ParaformerSANMDecoder")
+class ParaformerSANMDecoder(BaseTransformerDecoder):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2006.01713
+    """
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        att_layer_num: int = 6,
+        kernel_size: int = 21,
+        sanm_shfit: int = 0,
+        lora_list: List[str] = None,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
+        chunk_multiply_factor: tuple = (1,),
+        tf2torch_tensor_name_prefix_torch: str = "decoder",
+        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+    ):
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+
+        if input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(vocab_size, attention_dim),
+                # pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        elif input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(vocab_size, attention_dim),
+                torch.nn.LayerNorm(attention_dim),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        else:
+            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+        if use_output_layer:
+            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.output_layer = None
+
+        self.att_layer_num = att_layer_num
+        self.num_blocks = num_blocks
+        if sanm_shfit is None:
+            sanm_shfit = (kernel_size - 1) // 2
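+        # three decoder stacks: `decoders` combines FSMN self-attention with cross-attention to
+        # the encoder, `decoders2` keeps only the FSMN self-attention, and `decoders3` is a
+        # final feed-forward-only layer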
+        self.decoders = repeat(
+            att_layer_num,
+            lambda lnum: DecoderLayerSANM(
+                attention_dim,
+                MultiHeadedAttentionSANMDecoder(
+                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                ),
+                MultiHeadedAttentionCrossAtt(
+                    attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
+                ),
+                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if num_blocks - att_layer_num <= 0:
+            self.decoders2 = None
+        else:
+            self.decoders2 = repeat(
+                num_blocks - att_layer_num,
+                lambda lnum: DecoderLayerSANM(
+                    attention_dim,
+                    MultiHeadedAttentionSANMDecoder(
+                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+                    ),
+                    None,
+                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                ),
+            )
+
+        self.decoders3 = repeat(
+            1,
+            lambda lnum: DecoderLayerSANM(
+                attention_dim,
+                None,
+                None,
+                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+        self.chunk_multiply_factor = chunk_multiply_factor
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        chunk_mask: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        tgt = ys_in_pad
+        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+        
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        if chunk_mask is not None:
+            memory_mask = memory_mask * chunk_mask
+            if tgt_mask.size(1) != memory_mask.size(1):
+                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
+
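+        # in Paraformer the decoder input is the CIF acoustic/semantic embedding sequence,
+        # so it is consumed directly without a token-embedding lookup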
+        x = tgt
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+        if self.decoders2 is not None:
+            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
+                x, tgt_mask, memory, memory_mask
+            )
+        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
+            x, tgt_mask, memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        return x, olens
+
+    def score(self, ys, state, x):
+        """Score."""
+        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+        logp, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
+        )
+        return logp.squeeze(0), state
+
+    def forward_chunk(
+        self,
+        memory: torch.Tensor,
+        tgt: torch.Tensor,
+        cache: dict = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
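+        # `cache` carries the streaming state: per-layer FSMN caches under "decode_fsmn",
+        # optional cross-attention caches under "opt", plus "chunk_size" and
+        # "decoder_chunk_look_back"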
+        x = tgt
+        if cache["decode_fsmn"] is None:
+            cache_layer_num = len(self.decoders)
+            if self.decoders2 is not None:
+                cache_layer_num += len(self.decoders2)
+            fsmn_cache = [None] * cache_layer_num
+        else:
+            fsmn_cache = cache["decode_fsmn"]
+
+        if cache["opt"] is None:
+            cache_layer_num = len(self.decoders)
+            opt_cache = [None] * cache_layer_num
+        else:
+            opt_cache = cache["opt"]
+
+        for i in range(self.att_layer_num):
+            decoder = self.decoders[i]
+            x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
+                x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
+                chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
+            )
+
+        if self.num_blocks - self.att_layer_num > 1:
+            for i in range(self.num_blocks - self.att_layer_num):
+                j = i + self.att_layer_num
+                decoder = self.decoders2[i]
+                x, memory, fsmn_cache[j], _  = decoder.forward_chunk(
+                    x, memory, fsmn_cache=fsmn_cache[j]
+                )
+
+        for decoder in self.decoders3:
+            x, memory, _, _ = decoder.forward_chunk(
+                x, memory
+            )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+
+        cache["decode_fsmn"] = fsmn_cache
+        if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
+            cache["opt"] = opt_cache
+        return x
+
+    def forward_one_step(
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        memory: torch.Tensor,
+        cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward one step.
+
+        Args:
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask,  (batch, maxlen_out)
+                      dtype=torch.uint8 in PyTorch < 1.2,
+                      dtype=torch.bool in PyTorch >= 1.2
+            memory: encoded memory, float32  (batch, maxlen_in, feat)
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+            y.shape` is (batch, maxlen_out, token)
+        """
+        x = self.embed(tgt)
+        if cache is None:
+            cache_layer_num = len(self.decoders)
+            if self.decoders2 is not None:
+                cache_layer_num += len(self.decoders2)
+            cache = [None] * cache_layer_num
+        new_cache = []
+        # for c, decoder in zip(cache, self.decoders):
+        for i in range(self.att_layer_num):
+            decoder = self.decoders[i]
+            c = cache[i]
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                x, tgt_mask, memory, None, cache=c
+            )
+            new_cache.append(c_ret)
+
+        if self.num_blocks - self.att_layer_num > 1:
+            for i in range(self.num_blocks - self.att_layer_num):
+                j = i + self.att_layer_num
+                decoder = self.decoders2[i]
+                c = cache[j]
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                    x, tgt_mask, memory, None, cache=c
+                )
+                new_cache.append(c_ret)
+
+        for decoder in self.decoders3:
+
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
+                x, tgt_mask, memory, None, cache=None
+            )
+
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.output_layer is not None:
+            y = torch.log_softmax(self.output_layer(y), dim=-1)
+
+        return y, new_cache
+
+
+@register_class("decoder_classes", "ParaformerDecoderSAN")
+class ParaformerDecoderSAN(BaseTransformerDecoder):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2006.01713
+    """
+    def __init__(
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            self_attention_dropout_rate: float = 0.0,
+            src_attention_dropout_rate: float = 0.0,
+            input_layer: str = "embed",
+            use_output_layer: bool = True,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            embeds_id: int = -1,
+    ):
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+        self.decoders = repeat(
+            num_blocks,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        self.embeds_id = embeds_id
+        self.attention_dim = attention_dim
+
+    def forward(
+            self,
+            hs_pad: torch.Tensor,
+            hlens: torch.Tensor,
+            ys_in_pad: torch.Tensor,
+            ys_in_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        tgt = ys_in_pad
+        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+
+        memory = hs_pad
+        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
+            memory.device
+        )
+        # Padding for Longformer
+        if memory_mask.shape[-1] != memory.shape[1]:
+            padlen = memory.shape[1] - memory_mask.shape[-1]
+            memory_mask = torch.nn.functional.pad(
+                memory_mask, (0, padlen), "constant", False
+            )
+
+        # x = self.embed(tgt)
+        x = tgt
+        embeds_outputs = None
+        for layer_id, decoder in enumerate(self.decoders):
+            x, tgt_mask, memory, memory_mask = decoder(
+                x, tgt_mask, memory, memory_mask
+            )
+            if layer_id == self.embeds_id:
+                embeds_outputs = x
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        if embeds_outputs is not None:
+            return x, olens, embeds_outputs
+        else:
+            return x, olens
+
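
A minimal usage sketch for the ParaformerDecoderSAN added above (editorial, not part of the patch). It assumes the class lives in funasr/models/paraformer/decoder.py and that the acoustic embeddings produced by the CIF predictor already have encoder_output_size dimensions, since forward() consumes them directly (x = tgt) instead of calling self.embed; shapes and the import path are illustrative assumptions.

import torch
from funasr.models.paraformer.decoder import ParaformerDecoderSAN  # assumed module path

batch, t_enc, t_out, dim, vocab = 2, 50, 12, 512, 4000
# embeds_id=3 makes forward() additionally return the hidden states of decoder layer 3
decoder = ParaformerDecoderSAN(vocab_size=vocab, encoder_output_size=dim, embeds_id=3)

encoder_out = torch.randn(batch, t_enc, dim)       # encoder memory (B, T_enc, D)
encoder_out_lens = torch.tensor([50, 42])          # valid frames per utterance
acoustic_embeds = torch.randn(batch, t_out, dim)   # CIF outputs, passed through as-is
token_lens = torch.tensor([12, 9])                 # predicted token counts

scores, olens, mid_embeds = decoder(encoder_out, encoder_out_lens, acoustic_embeds, token_lens)
# scores: (B, T_out, vocab) pre-softmax logits; olens: valid token counts per utterance;
# mid_embeds: layer-3 hidden states, returned only because embeds_id >= 0
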
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 46bd7b0..fad8385 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -1,56 +1,34 @@
+import os
 import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
+from typing import Union, Dict, List, Tuple, Optional
+
 import torch
 import torch.nn as nn
-import random
-import numpy as np
+
 import time
-# from funasr.layers.abs_normalize import AbsNormalize
+
 from funasr.losses.label_smoothing_loss import (
 	LabelSmoothingLoss,  # noqa: H301
 )
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
-# from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+
+from funasr.models.paraformer.cif_predictor import mae_loss
+
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
 from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
+
 from funasr.models.paraformer.search import Hypothesis
 
-from funasr.models.model_class_factory import *
+from torch.cuda.amp import autocast
 
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-	from torch.cuda.amp import autocast
-else:
-	# Nothing to do if torch<1.6.0
-	@contextmanager
-	def autocast(enabled=True):
-		yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
 from funasr.utils import postprocess_utils
 from funasr.utils.datadir_writer import DatadirWriter
-from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.utils.register import register_class, registry_tables
+from funasr.models.ctc.ctc import CTC
 
+@register_class("model_classes", "Paraformer")
 class Paraformer(nn.Module):
 	"""
 	Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -101,24 +79,19 @@
 	):
 
 		super().__init__()
-		
-		# import pdb;
-		# pdb.set_trace()
-		
-		if frontend is not None:
-			frontend_class = frontend_choices.get_class(frontend)
-			frontend = frontend_class(**frontend_conf)
+
 		if specaug is not None:
-			specaug_class = specaug_choices.get_class(specaug)
+			specaug_class = registry_tables.specaug_classes.get(specaug.lower())
 			specaug = specaug_class(**specaug_conf)
 		if normalize is not None:
-			normalize_class = normalize_choices.get_class(normalize)
+			normalize_class = registry_tables.normalize_classes.get(normalize.lower())
 			normalize = normalize_class(**normalize_conf)
-		encoder_class = encoder_choices.get_class(encoder)
+		encoder_class = registry_tables.encoder_classes.get(encoder.lower())
 		encoder = encoder_class(input_size=input_size, **encoder_conf)
 		encoder_output_size = encoder.output_size()
+
 		if decoder is not None:
-			decoder_class = decoder_choices.get_class(decoder)
+			decoder_class = registry_tables.decoder_classes.get(decoder.lower())
 			decoder = decoder_class(
 				vocab_size=vocab_size,
 				encoder_output_size=encoder_output_size,
@@ -133,7 +106,7 @@
 				odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
 			)
 		if predictor is not None:
-			predictor_class = predictor_choices.get_class(predictor)
+			predictor_class = registry_tables.predictor_classes.get(predictor.lower())
 			predictor = predictor_class(**predictor_conf)
 		
 		# note that eos is the same as sos (equivalent ID)
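
The constructor above now resolves every component (specaug, normalize, encoder, decoder, predictor) by name through registry_tables instead of the old *_choices.get_class helpers. Below is a minimal sketch of the decorator-based registry pattern this implies; the real funasr/utils/register.py may differ in details such as table creation and error handling.

# Hypothetical sketch of the register_class / registry_tables pattern; names mirror
# the lookups used in Paraformer.__init__, but the real implementation may differ.
class _RegistryTables:
    def __init__(self):
        self.model_classes = {}
        self.encoder_classes = {}
        self.decoder_classes = {}
        self.predictor_classes = {}

registry_tables = _RegistryTables()

def register_class(table_name: str, class_name: str):
    def wrapper(cls):
        # store under a lower-cased key so config strings are case-insensitive
        getattr(registry_tables, table_name)[class_name.lower()] = cls
        return cls
    return wrapper

@register_class("encoder_classes", "SANMEncoder")
class SANMEncoder:  # stand-in for the real encoder
    def __init__(self, input_size, **conf):
        self.input_size, self.conf = input_size, conf

# config-driven construction, mirroring the hunk above
encoder_class = registry_tables.encoder_classes.get("SANMEncoder".lower())
encoder = encoder_class(input_size=560, output_size=512)
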
@@ -145,7 +118,7 @@
 		self.ctc_weight = ctc_weight
 		# self.token_list = token_list.copy()
 		#
-		self.frontend = frontend
+		# self.frontend = frontend
 		self.specaug = specaug
 		self.normalize = normalize
 		# self.preencoder = preencoder
@@ -275,7 +248,7 @@
 	def encode(
 		self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
 	) -> Tuple[torch.Tensor, torch.Tensor]:
-		"""Frontend + Encoder. Note that this method is used by asr_inference.py
+		"""Encoder. Note that this method is used by asr_inference.py
 		Args:
 				speech: (Batch, Length, ...)
 				speech_lengths: (Batch, )
@@ -469,10 +442,11 @@
 		self.beam_search = beam_search
 		
 	def generate(self,
-             data_in: list,
-             data_lengths: list=None,
+             data_in,
+             data_lengths=None,
              key: list=None,
              tokenizer=None,
+             frontend=None,
              **kwargs,
              ):
 		
@@ -485,16 +459,23 @@
 			self.nbest = kwargs.get("nbest", 1)
 		
 		meta_data = {}
-		# extract fbank feats
-		time1 = time.perf_counter()
-		audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
-		time2 = time.perf_counter()
-		meta_data["load_data"] = f"{time2 - time1:0.3f}"
-		speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"), frontend=self.frontend)
-		time3 = time.perf_counter()
-		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-		meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
-		
+		if isinstance(data_in, torch.Tensor):  # pre-extracted fbank features
+			speech, speech_lengths = data_in, data_lengths
+			if len(speech.shape) < 3:
+				speech = speech[None, :, :]
+			if speech_lengths is None:
+				# keep lengths as a tensor so the .to(device) and masking calls below work
+				speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int32)
+		else:
+			# extract fbank feats
+			time1 = time.perf_counter()
+			audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+			time2 = time.perf_counter()
+			meta_data["load_data"] = f"{time2 - time1:0.3f}"
+			speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend)
+			time3 = time.perf_counter()
+			meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+			meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+			
 		speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
 
 		# Encoder
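
The rewritten generate() above accepts either raw audio (handed to load_audio/extract_fbank with an externally supplied frontend) or a pre-extracted fbank tensor. A standalone sketch of the tensor branch, assuming lengths are carried as an int32 tensor as in the branch above; the helper name is illustrative.

import torch

def normalize_fbank_input(speech: torch.Tensor, speech_lengths=None):
    # promote a single (T, D) utterance to a (1, T, D) batch
    if speech.dim() < 3:
        speech = speech[None, :, :]
    # derive a missing length from the time axis, keeping it a tensor
    if speech_lengths is None:
        speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int32)
    return speech, speech_lengths

speech, lens = normalize_fbank_input(torch.randn(200, 560))
print(speech.shape, lens)  # torch.Size([1, 200, 560]) tensor([200], dtype=torch.int32)
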
@@ -550,1211 +531,22 @@
 				# remove blank symbol id, which is assumed to be 0
 				token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
 				
-				# Change integer-ids to tokens
-				token = tokenizer.ids2tokens(token_int)
-				text = tokenizer.tokens2text(token)
-				
-				text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-				result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
-				results.append(result_i)
-				
-				if ibest_writer is not None:
-					ibest_writer["token"][key[i]] = " ".join(token)
-					ibest_writer["text"][key[i]] = text
-					ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
-		
-		return results, meta_data
-
-
-
-class BiCifParaformer(Paraformer):
-	"""
-	Author: Speech Lab of DAMO Academy, Alibaba Group
-	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
-	https://arxiv.org/abs/2206.08317
-	"""
-	
-	def __init__(
-		self,
-		*args,
-		**kwargs,
-	):
-		super().__init__(*args, **kwargs)
-		assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
-
-
-	def _calc_pre2_loss(
-		self,
-		encoder_out: torch.Tensor,
-		encoder_out_lens: torch.Tensor,
-		ys_pad: torch.Tensor,
-		ys_pad_lens: torch.Tensor,
-	):
-		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-			encoder_out.device)
-		if self.predictor_bias == 1:
-			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-			ys_pad_lens = ys_pad_lens + self.predictor_bias
-		_, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
-		
-		# loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-		loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
-		
-		return loss_pre2
-	
-	
-	def _calc_att_loss(
-		self,
-		encoder_out: torch.Tensor,
-		encoder_out_lens: torch.Tensor,
-		ys_pad: torch.Tensor,
-		ys_pad_lens: torch.Tensor,
-	):
-		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-			encoder_out.device)
-		if self.predictor_bias == 1:
-			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-			ys_pad_lens = ys_pad_lens + self.predictor_bias
-		pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
-		                                                                             encoder_out_mask,
-		                                                                             ignore_id=self.ignore_id)
-		
-		# 0. sampler
-		decoder_out_1st = None
-		if self.sampling_ratio > 0.0:
-			sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
-			                                               pre_acoustic_embeds)
-		else:
-			sematic_embeds = pre_acoustic_embeds
-		
-		# 1. Forward decoder
-		decoder_outs = self.decoder(
-			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
-		)
-		decoder_out, _ = decoder_outs[0], decoder_outs[1]
-		
-		if decoder_out_1st is None:
-			decoder_out_1st = decoder_out
-		# 2. Compute attention loss
-		loss_att = self.criterion_att(decoder_out, ys_pad)
-		acc_att = th_accuracy(
-			decoder_out_1st.view(-1, self.vocab_size),
-			ys_pad,
-			ignore_label=self.ignore_id,
-		)
-		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-		
-		# Compute cer/wer using attention-decoder
-		if self.training or self.error_calculator is None:
-			cer_att, wer_att = None, None
-		else:
-			ys_hat = decoder_out_1st.argmax(dim=-1)
-			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-		
-		return loss_att, acc_att, cer_att, wer_att, loss_pre
-
-
-	def calc_predictor(self, encoder_out, encoder_out_lens):
-		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-			encoder_out.device)
-		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
-		                                                                                                  None,
-		                                                                                                  encoder_out_mask,
-		                                                                                                  ignore_id=self.ignore_id)
-		return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
-
-	def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
-		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-			encoder_out.device)
-		ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
-		                                                                                    encoder_out_mask,
-		                                                                                    token_num)
-		return ds_alphas, ds_cif_peak, us_alphas, us_peaks
-	
-	
-	def forward(
-		self,
-		speech: torch.Tensor,
-		speech_lengths: torch.Tensor,
-		text: torch.Tensor,
-		text_lengths: torch.Tensor,
-		**kwargs,
-	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-		"""Frontend + Encoder + Decoder + Calc loss
-		Args:
-				speech: (Batch, Length, ...)
-				speech_lengths: (Batch, )
-				text: (Batch, Length)
-				text_lengths: (Batch,)
-		"""
-		if len(text_lengths.size()) > 1:
-			text_lengths = text_lengths[:, 0]
-		if len(speech_lengths.size()) > 1:
-			speech_lengths = speech_lengths[:, 0]
-		
-		batch_size = speech.shape[0]
-		
-		# Encoder
-		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
-
-		loss_ctc, cer_ctc = None, None
-		loss_pre = None
-		stats = dict()
-		
-		# decoder: CTC branch
-		if self.ctc_weight != 0.0:
-			loss_ctc, cer_ctc = self._calc_ctc_loss(
-				encoder_out, encoder_out_lens, text, text_lengths
-			)
-			
-			# Collect CTC branch stats
-			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-			stats["cer_ctc"] = cer_ctc
-
-
-		# decoder: Attention decoder branch
-		loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
-			encoder_out, encoder_out_lens, text, text_lengths
-		)
-		
-		loss_pre2 = self._calc_pre2_loss(
-			encoder_out, encoder_out_lens, text, text_lengths
-		)
-		
-		# 3. CTC-Att loss definition
-		if self.ctc_weight == 0.0:
-			loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
-		else:
-			loss = self.ctc_weight * loss_ctc + (
-				1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
-		
-		# Collect Attn branch stats
-		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
-		stats["acc"] = acc_att
-		stats["cer"] = cer_att
-		stats["wer"] = wer_att
-		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-		stats["loss_pre2"] = loss_pre2.detach().cpu()
-		
-		stats["loss"] = torch.clone(loss.detach())
-		
-		# force_gatherable: to-device and to-tensor if scalar for DataParallel
-		if self.length_normalized_loss:
-			batch_size = int((text_lengths + self.predictor_bias).sum())
-		
-		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-		return loss, stats, weight
-	
-	def generate(self,
-	             data_in: list,
-	             data_lengths: list = None,
-	             key: list = None,
-	             tokenizer=None,
-	             **kwargs,
-	             ):
-		
-		# init beamsearch
-		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
-		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-		if self.beam_search is None and (is_use_lm or is_use_ctc):
-			logging.info("enable beam_search")
-			self.init_beam_search(**kwargs)
-			self.nbest = kwargs.get("nbest", 1)
-		
-		meta_data = {}
-		# extract fbank feats
-		time1 = time.perf_counter()
-		audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
-		time2 = time.perf_counter()
-		meta_data["load_data"] = f"{time2 - time1:0.3f}"
-		speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"),
-		                                       frontend=self.frontend)
-		time3 = time.perf_counter()
-		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-		meta_data[
-			"batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
-		
-		speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
-		
-		# Encoder
-		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-		if isinstance(encoder_out, tuple):
-			encoder_out = encoder_out[0]
-		
-		# predictor
-		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
-		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
-		                                                                predictor_outs[2], predictor_outs[3]
-		pre_token_length = pre_token_length.round().long()
-		if torch.max(pre_token_length) < 1:
-			return []
-		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
-		                                               pre_token_length)
-		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-		
-		# BiCifParaformer, test no bias cif2
-
-		_, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
-			                                                                    pre_token_length)
-		
-		results = []
-		b, n, d = decoder_out.size()
-		for i in range(b):
-			x = encoder_out[i, :encoder_out_lens[i], :]
-			am_scores = decoder_out[i, :pre_token_length[i], :]
-			if self.beam_search is not None:
-				nbest_hyps = self.beam_search(
-					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
-					minlenratio=kwargs.get("minlenratio", 0.0)
-				)
-				
-				nbest_hyps = nbest_hyps[: self.nbest]
-			else:
-				
-				yseq = am_scores.argmax(dim=-1)
-				score = am_scores.max(dim=-1)[0]
-				score = torch.sum(score, dim=-1)
-				# pad with mask tokens to ensure compatibility with sos/eos tokens
-				yseq = torch.tensor(
-					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
-				)
-				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-			for nbest_idx, hyp in enumerate(nbest_hyps):
-				ibest_writer = None
-				if ibest_writer is None and kwargs.get("output_dir") is not None:
-					writer = DatadirWriter(kwargs.get("output_dir"))
-					ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
-				# remove sos/eos and get results
-				last_pos = -1
-				if isinstance(hyp.yseq, list):
-					token_int = hyp.yseq[1:last_pos]
-				else:
-					token_int = hyp.yseq[1:last_pos].tolist()
-				
-				# remove blank symbol id, which is assumed to be 0
-				token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-				
-				# Change integer-ids to tokens
-				token = tokenizer.ids2tokens(token_int)
-				text = tokenizer.tokens2text(token)
-				
-				_, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
-				                                           us_peaks[i][:encoder_out_lens[i] * 3],
-				                                           copy.copy(token),
-				                                           vad_offset=kwargs.get("begin_time", 0))
-				
-				text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token, timestamp)
-				
-				result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
-				            "time_stamp_postprocessed": time_stamp_postprocessed,
-				            "word_lists": word_lists
-				            }
-				results.append(result_i)
-				
-				if ibest_writer is not None:
-					ibest_writer["token"][key[i]] = " ".join(token)
-					ibest_writer["text"][key[i]] = text
-					ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+				if tokenizer is not None:
+					# Change integer-ids to tokens
+					token = tokenizer.ids2tokens(token_int)
+					text = tokenizer.tokens2text(token)
 					
-		
-		return results, meta_data
+					text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+					result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
 
-
-class NeatContextualParaformer(Paraformer):
-	"""
-	Author: Speech Lab of DAMO Academy, Alibaba Group
-	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
-	https://arxiv.org/abs/2206.08317
-	"""
-	
-	def __init__(
-		self,
-		*args,
-		**kwargs,
-	):
-		super().__init__(*args, **kwargs)
-		
-		self.target_buffer_length = kwargs.get("target_buffer_length", -1)
-		inner_dim = kwargs.get("inner_dim", 256)
-		bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
-		use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
-		crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
-		crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
-		bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
-
-
-		if bias_encoder_type == 'lstm':
-			logging.warning("enable bias encoder sampling and contextual training")
-			self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
-			self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
-		elif bias_encoder_type == 'mean':
-			logging.warning("enable bias encoder sampling and contextual training")
-			self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
-		else:
-			logging.error("Unsupport bias encoder type: {}".format(bias_encoder_type))
-		
-		if self.target_buffer_length > 0:
-			self.hotword_buffer = None
-			self.length_record = []
-			self.current_buffer_length = 0
-		self.use_decoder_embedding = use_decoder_embedding
-		self.crit_attn_weight = crit_attn_weight
-		if self.crit_attn_weight > 0:
-			self.attn_loss = torch.nn.L1Loss()
-		self.crit_attn_smooth = crit_attn_smooth
-
-
-	def forward(
-		self,
-		speech: torch.Tensor,
-		speech_lengths: torch.Tensor,
-		text: torch.Tensor,
-		text_lengths: torch.Tensor,
-		hotword_pad: torch.Tensor,
-		hotword_lengths: torch.Tensor,
-		dha_pad: torch.Tensor,
-		**kwargs,
-	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-		"""Frontend + Encoder + Decoder + Calc loss
-	
-		Args:
-				speech: (Batch, Length, ...)
-				speech_lengths: (Batch, )
-				text: (Batch, Length)
-				text_lengths: (Batch,)
-		"""
-		if len(text_lengths.size()) > 1:
-			text_lengths = text_lengths[:, 0]
-		if len(speech_lengths.size()) > 1:
-			speech_lengths = speech_lengths[:, 0]
-		
-		batch_size = speech.shape[0]
-
-		# 1. Encoder
-		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
-		
-		loss_ctc, cer_ctc = None, None
-		
-		stats = dict()
-		
-		# 1. CTC branch
-		if self.ctc_weight != 0.0:
-			loss_ctc, cer_ctc = self._calc_ctc_loss(
-				encoder_out, encoder_out_lens, text, text_lengths
-			)
-			
-			# Collect CTC branch stats
-			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-			stats["cer_ctc"] = cer_ctc
-		
-
-		# 2b. Attention decoder branch
-		loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
-			encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
-		)
-		
-		# 3. CTC-Att loss definition
-		if self.ctc_weight == 0.0:
-			loss = loss_att + loss_pre * self.predictor_weight
-		else:
-			loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
-		
-		if loss_ideal is not None:
-			loss = loss + loss_ideal * self.crit_attn_weight
-			stats["loss_ideal"] = loss_ideal.detach().cpu()
-		
-		# Collect Attn branch stats
-		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
-		stats["acc"] = acc_att
-		stats["cer"] = cer_att
-		stats["wer"] = wer_att
-		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-		
-		stats["loss"] = torch.clone(loss.detach())
-		# force_gatherable: to-device and to-tensor if scalar for DataParallel
-		if self.length_normalized_loss:
-			batch_size = int((text_lengths + self.predictor_bias).sum())
-		
-		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-		return loss, stats, weight
-	
-	
-	def _calc_att_clas_loss(
-		self,
-		encoder_out: torch.Tensor,
-		encoder_out_lens: torch.Tensor,
-		ys_pad: torch.Tensor,
-		ys_pad_lens: torch.Tensor,
-		hotword_pad: torch.Tensor,
-		hotword_lengths: torch.Tensor,
-	):
-		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-			encoder_out.device)
-		if self.predictor_bias == 1:
-			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-			ys_pad_lens = ys_pad_lens + self.predictor_bias
-		pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-		                                                             ignore_id=self.ignore_id)
-		
-		# -1. bias encoder
-		if self.use_decoder_embedding:
-			hw_embed = self.decoder.embed(hotword_pad)
-		else:
-			hw_embed = self.bias_embed(hotword_pad)
-		hw_embed, (_, _) = self.bias_encoder(hw_embed)
-		_ind = np.arange(0, hotword_pad.shape[0]).tolist()
-		selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
-		contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
-		
-		# 0. sampler
-		decoder_out_1st = None
-		if self.sampling_ratio > 0.0:
-			if self.step_cur < 2:
-				logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
-			sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
-			                                               pre_acoustic_embeds, contextual_info)
-		else:
-			if self.step_cur < 2:
-				logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
-			sematic_embeds = pre_acoustic_embeds
-		
-		# 1. Forward decoder
-		decoder_outs = self.decoder(
-			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
-		)
-		decoder_out, _ = decoder_outs[0], decoder_outs[1]
-		'''
-		if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
-			ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
-			attn_non_blank = attn[:,:,:,:-1]
-			ideal_attn_non_blank = ideal_attn[:,:,:-1]
-			loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
-		else:
-			loss_ideal = None
-		'''
-		loss_ideal = None
-		
-		if decoder_out_1st is None:
-			decoder_out_1st = decoder_out
-		# 2. Compute attention loss
-		loss_att = self.criterion_att(decoder_out, ys_pad)
-		acc_att = th_accuracy(
-			decoder_out_1st.view(-1, self.vocab_size),
-			ys_pad,
-			ignore_label=self.ignore_id,
-		)
-		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-		
-		# Compute cer/wer using attention-decoder
-		if self.training or self.error_calculator is None:
-			cer_att, wer_att = None, None
-		else:
-			ys_hat = decoder_out_1st.argmax(dim=-1)
-			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-		
-		return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal
-	
-	
-	def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
-		tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
-		ys_pad = ys_pad * tgt_mask[:, :, 0]
-		if self.share_embedding:
-			ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
-		else:
-			ys_pad_embed = self.decoder.embed(ys_pad)
-		with torch.no_grad():
-			decoder_outs = self.decoder(
-				encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
-			)
-			decoder_out, _ = decoder_outs[0], decoder_outs[1]
-			pred_tokens = decoder_out.argmax(-1)
-			nonpad_positions = ys_pad.ne(self.ignore_id)
-			seq_lens = (nonpad_positions).sum(1)
-			same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
-			input_mask = torch.ones_like(nonpad_positions)
-			bsz, seq_len = ys_pad.size()
-			for li in range(bsz):
-				target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
-				if target_num > 0:
-					input_mask[li].scatter_(dim=0,
-					                        index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device),
-					                        value=0)
-			input_mask = input_mask.eq(1)
-			input_mask = input_mask.masked_fill(~nonpad_positions, False)
-			input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-		
-		sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
-			input_mask_expand_dim, 0)
-		return sematic_embeds * tgt_mask, decoder_out * tgt_mask
-	
-	
-	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
-	                               clas_scale=1.0):
-		if hw_list is None:
-			hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
-			hw_list_pad = pad_list(hw_list, 0)
-			if self.use_decoder_embedding:
-				hw_embed = self.decoder.embed(hw_list_pad)
-			else:
-				hw_embed = self.bias_embed(hw_list_pad)
-			hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
-			hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
-		else:
-			hw_lengths = [len(i) for i in hw_list]
-			hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
-			if self.use_decoder_embedding:
-				hw_embed = self.decoder.embed(hw_list_pad)
-			else:
-				hw_embed = self.bias_embed(hw_list_pad)
-			hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
-			                                                   enforce_sorted=False)
-			_, (h_n, _) = self.bias_encoder(hw_embed)
-			hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
-		
-		decoder_outs = self.decoder(
-			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
-		)
-		decoder_out = decoder_outs[0]
-		decoder_out = torch.log_softmax(decoder_out, dim=-1)
-		return decoder_out, ys_pad_lens
-		
-	def generate(self,
-	             data_in: list,
-	             data_lengths: list = None,
-	             key: list = None,
-	             tokenizer=None,
-	             **kwargs,
-	             ):
-		
-		# init beamsearch
-		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
-		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-		if self.beam_search is None and (is_use_lm or is_use_ctc):
-			logging.info("enable beam_search")
-			self.init_beam_search(**kwargs)
-			self.nbest = kwargs.get("nbest", 1)
-		
-		meta_data = {}
-		
-		# extract fbank feats
-		time1 = time.perf_counter()
-		audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
-		time2 = time.perf_counter()
-		meta_data["load_data"] = f"{time2 - time1:0.3f}"
-		speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"),
-		                                       frontend=self.frontend)
-		time3 = time.perf_counter()
-		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-		meta_data[
-			"batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
-		
-		speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
-
-		# hotword
-		self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer)
-		
-		# Encoder
-		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-		if isinstance(encoder_out, tuple):
-			encoder_out = encoder_out[0]
-		
-		# predictor
-		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
-		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
-		                                                                predictor_outs[2], predictor_outs[3]
-		pre_token_length = pre_token_length.round().long()
-		if torch.max(pre_token_length) < 1:
-			return []
-
-
-		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
-		                                                         pre_acoustic_embeds,
-		                                                         pre_token_length,
-		                                                         hw_list=self.hotword_list,
-		                                                         clas_scale=kwargs.get("clas_scale", 1.0))
-		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-		
-		results = []
-		b, n, d = decoder_out.size()
-		for i in range(b):
-			x = encoder_out[i, :encoder_out_lens[i], :]
-			am_scores = decoder_out[i, :pre_token_length[i], :]
-			if self.beam_search is not None:
-				nbest_hyps = self.beam_search(
-					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
-					minlenratio=kwargs.get("minlenratio", 0.0)
-				)
-				
-				nbest_hyps = nbest_hyps[: self.nbest]
-			else:
-				
-				yseq = am_scores.argmax(dim=-1)
-				score = am_scores.max(dim=-1)[0]
-				score = torch.sum(score, dim=-1)
-				# pad with mask tokens to ensure compatibility with sos/eos tokens
-				yseq = torch.tensor(
-					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
-				)
-				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-			for nbest_idx, hyp in enumerate(nbest_hyps):
-				ibest_writer = None
-				if ibest_writer is None and kwargs.get("output_dir") is not None:
-					writer = DatadirWriter(kwargs.get("output_dir"))
-					ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
-				# remove sos/eos and get results
-				last_pos = -1
-				if isinstance(hyp.yseq, list):
-					token_int = hyp.yseq[1:last_pos]
+					
+					if ibest_writer is not None:
+						ibest_writer["token"][key[i]] = " ".join(token)
+						ibest_writer["text"][key[i]] = text
+						ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
 				else:
-					token_int = hyp.yseq[1:last_pos].tolist()
-				
-				# remove blank symbol id, which is assumed to be 0
-				token_int = list(
-					filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-				
-				# Change integer-ids to tokens
-				token = tokenizer.ids2tokens(token_int)
-				text = tokenizer.tokens2text(token)
-				
-				text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-				result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+					result_i = {"key": key[i], "token_int": token_int}
 				results.append(result_i)
 				
-				if ibest_writer is not None:
-					ibest_writer["token"][key[i]] = " ".join(token)
-					ibest_writer["text"][key[i]] = text
-					ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
-		
 		return results, meta_data
-
-
-	def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None):
-		def load_seg_dict(seg_dict_file):
-			seg_dict = {}
-			assert isinstance(seg_dict_file, str)
-			with open(seg_dict_file, "r", encoding="utf8") as f:
-				lines = f.readlines()
-				for line in lines:
-					s = line.strip().split()
-					key = s[0]
-					value = s[1:]
-					seg_dict[key] = " ".join(value)
-			return seg_dict
-		
-		def seg_tokenize(txt, seg_dict):
-			pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
-			out_txt = ""
-			for word in txt:
-				word = word.lower()
-				if word in seg_dict:
-					out_txt += seg_dict[word] + " "
-				else:
-					if pattern.match(word):
-						for char in word:
-							if char in seg_dict:
-								out_txt += seg_dict[char] + " "
-							else:
-								out_txt += "<unk>" + " "
-					else:
-						out_txt += "<unk>" + " "
-			return out_txt.strip().split()
-		
-		seg_dict = None
-		if self.frontend.cmvn_file is not None:
-			model_dir = os.path.dirname(self.frontend.cmvn_file)
-			seg_dict_file = os.path.join(model_dir, 'seg_dict')
-			if os.path.exists(seg_dict_file):
-				seg_dict = load_seg_dict(seg_dict_file)
-			else:
-				seg_dict = None
-		# for None
-		if hotword_list_or_file is None:
-			hotword_list = None
-		# for local txt inputs
-		elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
-			logging.info("Attempting to parse hotwords from local txt...")
-			hotword_list = []
-			hotword_str_list = []
-			with codecs.open(hotword_list_or_file, 'r') as fin:
-				for line in fin.readlines():
-					hw = line.strip()
-					hw_list = hw.split()
-					if seg_dict is not None:
-						hw_list = seg_tokenize(hw_list, seg_dict)
-					hotword_str_list.append(hw)
-					hotword_list.append(tokenizer.tokens2ids(hw_list))
-				hotword_list.append([self.sos])
-				hotword_str_list.append('<s>')
-			logging.info("Initialized hotword list from file: {}, hotword list: {}."
-			             .format(hotword_list_or_file, hotword_str_list))
-		# for url, download and generate txt
-		elif hotword_list_or_file.startswith('http'):
-			logging.info("Attempting to parse hotwords from url...")
-			work_dir = tempfile.TemporaryDirectory().name
-			if not os.path.exists(work_dir):
-				os.makedirs(work_dir)
-			text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
-			local_file = requests.get(hotword_list_or_file)
-			open(text_file_path, "wb").write(local_file.content)
-			hotword_list_or_file = text_file_path
-			hotword_list = []
-			hotword_str_list = []
-			with codecs.open(hotword_list_or_file, 'r') as fin:
-				for line in fin.readlines():
-					hw = line.strip()
-					hw_list = hw.split()
-					if seg_dict is not None:
-						hw_list = seg_tokenize(hw_list, seg_dict)
-					hotword_str_list.append(hw)
-					hotword_list.append(tokenizer.tokens2ids(hw_list))
-				hotword_list.append([self.sos])
-				hotword_str_list.append('<s>')
-			logging.info("Initialized hotword list from file: {}, hotword list: {}."
-			             .format(hotword_list_or_file, hotword_str_list))
-		# for text str input
-		elif not hotword_list_or_file.endswith('.txt'):
-			logging.info("Attempting to parse hotwords as str...")
-			hotword_list = []
-			hotword_str_list = []
-			for hw in hotword_list_or_file.strip().split():
-				hotword_str_list.append(hw)
-				hw_list = hw.strip().split()
-				if seg_dict is not None:
-					hw_list = seg_tokenize(hw_list, seg_dict)
-				hotword_list.append(tokenizer.tokens2ids(hw_list))
-			hotword_list.append([self.sos])
-			hotword_str_list.append('<s>')
-			logging.info("Hotword list: {}.".format(hotword_str_list))
-		else:
-			hotword_list = None
-		return hotword_list
-
-
-class ParaformerOnline(Paraformer):
-	"""
-	Author: Speech Lab of DAMO Academy, Alibaba Group
-	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
-	https://arxiv.org/abs/2206.08317
-	"""
-	
-	def __init__(
-		self,
-		*args,
-		**kwargs,
-	):
-		
-		super().__init__(*args, **kwargs)
-		
-		# import pdb;
-		# pdb.set_trace()
-		self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
-
-
-		self.scama_mask = None
-		if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
-			from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
-			self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
-			self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
-
-
-	
-	def forward(
-		self,
-		speech: torch.Tensor,
-		speech_lengths: torch.Tensor,
-		text: torch.Tensor,
-		text_lengths: torch.Tensor,
-		**kwargs,
-	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-		"""Encoder + Decoder + Calc loss
-		Args:
-				speech: (Batch, Length, ...)
-				speech_lengths: (Batch, )
-				text: (Batch, Length)
-				text_lengths: (Batch,)
-		"""
-		# import pdb;
-		# pdb.set_trace()
-		decoding_ind = kwargs.get("decoding_ind")
-		if len(text_lengths.size()) > 1:
-			text_lengths = text_lengths[:, 0]
-		if len(speech_lengths.size()) > 1:
-			speech_lengths = speech_lengths[:, 0]
-		
-		batch_size = speech.shape[0]
-		
-		# Encoder
-		if hasattr(self.encoder, "overlap_chunk_cls"):
-			ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
-			encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
-		else:
-			encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-		
-		loss_ctc, cer_ctc = None, None
-		loss_pre = None
-		stats = dict()
-		
-		# decoder: CTC branch
-
-		if self.ctc_weight > 0.0:
-			if hasattr(self.encoder, "overlap_chunk_cls"):
-				encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
-				                                                                                    encoder_out_lens,
-				                                                                                    chunk_outs=None)
-			else:
-				encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
-				
-			loss_ctc, cer_ctc = self._calc_ctc_loss(
-				encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
-			)
-			# Collect CTC branch stats
-			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-			stats["cer_ctc"] = cer_ctc
-		
-		# decoder: Attention decoder branch
-		loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
-			encoder_out, encoder_out_lens, text, text_lengths
-		)
-		
-		# 3. CTC-Att loss definition
-		if self.ctc_weight == 0.0:
-			loss = loss_att + loss_pre * self.predictor_weight
-		else:
-			loss = self.ctc_weight * loss_ctc + (
-					1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
-		
-		# Collect Attn branch stats
-		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
-		stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
-		stats["acc"] = acc_att
-		stats["cer"] = cer_att
-		stats["wer"] = wer_att
-		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-		
-		stats["loss"] = torch.clone(loss.detach())
-		
-		# force_gatherable: to-device and to-tensor if scalar for DataParallel
-		if self.length_normalized_loss:
-			batch_size = (text_lengths + self.predictor_bias).sum()
-		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-		return loss, stats, weight
-	
-	def encode_chunk(
-		self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
-	) -> Tuple[torch.Tensor, torch.Tensor]:
-		"""Frontend + Encoder. Note that this method is used by asr_inference.py
-		Args:
-				speech: (Batch, Length, ...)
-				speech_lengths: (Batch, )
-				ind: int
-		"""
-		with autocast(False):
-			
-			# Data augmentation
-			if self.specaug is not None and self.training:
-				speech, speech_lengths = self.specaug(speech, speech_lengths)
-			
-			# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-			if self.normalize is not None:
-				speech, speech_lengths = self.normalize(speech, speech_lengths)
-		
-		# Forward encoder
-		encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
-		if isinstance(encoder_out, tuple):
-			encoder_out = encoder_out[0]
-		
-		return encoder_out, torch.tensor([encoder_out.size(1)])
-	
-	def _calc_att_predictor_loss(
-		self,
-		encoder_out: torch.Tensor,
-		encoder_out_lens: torch.Tensor,
-		ys_pad: torch.Tensor,
-		ys_pad_lens: torch.Tensor,
-	):
-		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-			encoder_out.device)
-		if self.predictor_bias == 1:
-			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-			ys_pad_lens = ys_pad_lens + self.predictor_bias
-		mask_chunk_predictor = None
-		if self.encoder.overlap_chunk_cls is not None:
-			mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
-			                                                                               device=encoder_out.device,
-			                                                                               batch_size=encoder_out.size(
-				                                                                               0))
-			mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
-			                                                                       batch_size=encoder_out.size(0))
-			encoder_out = encoder_out * mask_shfit_chunk
-		pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
-		                                                                      ys_pad,
-		                                                                      encoder_out_mask,
-		                                                                      ignore_id=self.ignore_id,
-		                                                                      mask_chunk_predictor=mask_chunk_predictor,
-		                                                                      target_label_length=ys_pad_lens,
-		                                                                      )
-		predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
-		                                                                                     encoder_out_lens)
-		
-		scama_mask = None
-		if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
-			encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
-			attention_chunk_center_bias = 0
-			attention_chunk_size = encoder_chunk_size
-			decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
-			mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
-				get_mask_shift_att_chunk_decoder(None,
-			                                     device=encoder_out.device,
-			                                     batch_size=encoder_out.size(0)
-			                                     )
-			scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
-				predictor_alignments=predictor_alignments,
-				encoder_sequence_length=encoder_out_lens,
-				chunk_size=1,
-				encoder_chunk_size=encoder_chunk_size,
-				attention_chunk_center_bias=attention_chunk_center_bias,
-				attention_chunk_size=attention_chunk_size,
-				attention_chunk_type=self.decoder_attention_chunk_type,
-				step=None,
-				predictor_mask_chunk_hopping=mask_chunk_predictor,
-				decoder_att_look_back_factor=decoder_att_look_back_factor,
-				mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
-				target_length=ys_pad_lens,
-				is_training=self.training,
-			)
-		elif self.encoder.overlap_chunk_cls is not None:
-			encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
-			                                                                            encoder_out_lens,
-			                                                                            chunk_outs=None)
-		# 0. sampler
-		decoder_out_1st = None
-		pre_loss_att = None
-		if self.sampling_ratio > 0.0:
-			if self.step_cur < 2:
-				logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
-			if self.use_1st_decoder_loss:
-				sematic_embeds, decoder_out_1st, pre_loss_att = \
-					self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
-					                       ys_pad_lens, pre_acoustic_embeds, scama_mask)
-			else:
-				sematic_embeds, decoder_out_1st = \
-					self.sampler(encoder_out, encoder_out_lens, ys_pad,
-					             ys_pad_lens, pre_acoustic_embeds, scama_mask)
-		else:
-			if self.step_cur < 2:
-				logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
-			sematic_embeds = pre_acoustic_embeds
-		
-		# 1. Forward decoder
-		decoder_outs = self.decoder(
-			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
-		)
-		decoder_out, _ = decoder_outs[0], decoder_outs[1]
-		
-		if decoder_out_1st is None:
-			decoder_out_1st = decoder_out
-		# 2. Compute attention loss
-		loss_att = self.criterion_att(decoder_out, ys_pad)
-		acc_att = th_accuracy(
-			decoder_out_1st.view(-1, self.vocab_size),
-			ys_pad,
-			ignore_label=self.ignore_id,
-		)
-		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-		
-		# Compute cer/wer using attention-decoder
-		if self.training or self.error_calculator is None:
-			cer_att, wer_att = None, None
-		else:
-			ys_hat = decoder_out_1st.argmax(dim=-1)
-			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-		
-		return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
-	
-	def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
-		
-		tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
-		ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
-		if self.share_embedding:
-			ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
-		else:
-			ys_pad_embed = self.decoder.embed(ys_pad_masked)
-		with torch.no_grad():
-			decoder_outs = self.decoder(
-				encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
-			)
-			decoder_out, _ = decoder_outs[0], decoder_outs[1]
-			pred_tokens = decoder_out.argmax(-1)
-			nonpad_positions = ys_pad.ne(self.ignore_id)
-			seq_lens = (nonpad_positions).sum(1)
-			same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
-			input_mask = torch.ones_like(nonpad_positions)
-			bsz, seq_len = ys_pad.size()
-			for li in range(bsz):
-				target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
-				if target_num > 0:
-					input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
-			input_mask = input_mask.eq(1)
-			input_mask = input_mask.masked_fill(~nonpad_positions, False)
-			input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
-		
-		sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
-			input_mask_expand_dim, 0)
-		return sematic_embeds * tgt_mask, decoder_out * tgt_mask
-	
-
-	def calc_predictor(self, encoder_out, encoder_out_lens):
-		
-		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
-			encoder_out.device)
-		mask_chunk_predictor = None
-		if self.encoder.overlap_chunk_cls is not None:
-			mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
-			                                                                               device=encoder_out.device,
-			                                                                               batch_size=encoder_out.size(
-				                                                                               0))
-			mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
-			                                                                       batch_size=encoder_out.size(0))
-			encoder_out = encoder_out * mask_shfit_chunk
-		pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
-		                                                                                   None,
-		                                                                                   encoder_out_mask,
-		                                                                                   ignore_id=self.ignore_id,
-		                                                                                   mask_chunk_predictor=mask_chunk_predictor,
-		                                                                                   target_label_length=None,
-		                                                                                   )
-		predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
-		                                                                                     encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
-		
-		scama_mask = None
-		if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
-			encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
-			attention_chunk_center_bias = 0
-			attention_chunk_size = encoder_chunk_size
-			decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
-			mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
-				get_mask_shift_att_chunk_decoder(None,
-			                                     device=encoder_out.device,
-			                                     batch_size=encoder_out.size(0)
-			                                     )
-			scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
-				predictor_alignments=predictor_alignments,
-				encoder_sequence_length=encoder_out_lens,
-				chunk_size=1,
-				encoder_chunk_size=encoder_chunk_size,
-				attention_chunk_center_bias=attention_chunk_center_bias,
-				attention_chunk_size=attention_chunk_size,
-				attention_chunk_type=self.decoder_attention_chunk_type,
-				step=None,
-				predictor_mask_chunk_hopping=mask_chunk_predictor,
-				decoder_att_look_back_factor=decoder_att_look_back_factor,
-				mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
-				target_length=None,
-				is_training=self.training,
-			)
-		self.scama_mask = scama_mask
-		
-		return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
-	
-	def calc_predictor_chunk(self, encoder_out, cache=None):
-		
-		pre_acoustic_embeds, pre_token_length = \
-			self.predictor.forward_chunk(encoder_out, cache["encoder"])
-		return pre_acoustic_embeds, pre_token_length
-	
-	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
-		decoder_outs = self.decoder(
-			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
-		)
-		decoder_out = decoder_outs[0]
-		decoder_out = torch.log_softmax(decoder_out, dim=-1)
-		return decoder_out, ys_pad_lens
-	
-	def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
-		decoder_outs = self.decoder.forward_chunk(
-			encoder_out, sematic_embeds, cache["decoder"]
-		)
-		decoder_out = decoder_outs
-		decoder_out = torch.log_softmax(decoder_out, dim=-1)
-		return decoder_out
-
-	def generate(self,
-	             speech: torch.Tensor,
-	             speech_lengths: torch.Tensor,
-	             tokenizer=None,
-	             **kwargs,
-	             ):
-		
-		is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc != None
-		print(is_use_ctc)
-		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
-		
-		if self.beam_search is None and (is_use_lm or is_use_ctc):
-			logging.info("enable beam_search")
-			self.init_beam_search(speech, speech_lengths, **kwargs)
-			self.nbest = kwargs.get("nbest", 1)
-		
-		# Forward Encoder
-		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-		if isinstance(encoder_out, tuple):
-			encoder_out = encoder_out[0]
-		
-		# predictor
-		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
-		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
-		                                                                predictor_outs[2], predictor_outs[3]
-		pre_token_length = pre_token_length.round().long()
-		if torch.max(pre_token_length) < 1:
-			return []
-		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
-		                                               pre_token_length)
-		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-		
-		results = []
-		b, n, d = decoder_out.size()
-		for i in range(b):
-			x = encoder_out[i, :encoder_out_lens[i], :]
-			am_scores = decoder_out[i, :pre_token_length[i], :]
-			if self.beam_search is not None:
-				nbest_hyps = self.beam_search(
-					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
-					minlenratio=kwargs.get("minlenratio", 0.0)
-				)
-				
-				nbest_hyps = nbest_hyps[: self.nbest]
-			else:
-				
-				yseq = am_scores.argmax(dim=-1)
-				score = am_scores.max(dim=-1)[0]
-				score = torch.sum(score, dim=-1)
-				# pad with mask tokens to ensure compatibility with sos/eos tokens
-				yseq = torch.tensor(
-					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
-				)
-				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-			for hyp in nbest_hyps:
-				assert isinstance(hyp, (Hypothesis)), type(hyp)
-				
-				# remove sos/eos and get results
-				last_pos = -1
-				if isinstance(hyp.yseq, list):
-					token_int = hyp.yseq[1:last_pos]
-				else:
-					token_int = hyp.yseq[1:last_pos].tolist()
-				
-				# remove blank symbol id, which is assumed to be 0
-				token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-				
-				# Change integer-ids to tokens
-				token = tokenizer.ids2tokens(token_int)
-				text = tokenizer.tokens2text(token)
-				
-				timestamp = []
-				
-				results.append((text, token, timestamp))
-		
-		return results
 
diff --git a/funasr/models/paraformer/search.py b/funasr/models/paraformer/search.py
index 8789025..250baad 100644
--- a/funasr/models/paraformer/search.py
+++ b/funasr/models/paraformer/search.py
@@ -9,7 +9,7 @@
 
 import torch
 
-from funasr.metrics import end_detect
+from funasr.metrics.common import end_detect
 from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
 from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
 
diff --git a/funasr/models/paraformer/template.yaml b/funasr/models/paraformer/template.yaml
new file mode 100644
index 0000000..1909600
--- /dev/null
+++ b/funasr/models/paraformer/template.yaml
@@ -0,0 +1,126 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# network architecture
+#model: funasr.models.paraformer.model:Paraformer
+model: Paraformer
+model_conf:
+    ctc_weight: 0.0
+    lsm_weight: 0.1
+    length_normalized_loss: true
+    predictor_weight: 1.0
+    predictor_bias: 1
+    sampling_ratio: 0.75
+
+# encoder
+encoder: SANMEncoder
+encoder_conf:
+    output_size: 512
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 50
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: pe
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+
+# decoder
+decoder: ParaformerSANMDecoder
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 16
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.1
+    src_attention_dropout_rate: 0.1
+    att_layer_num: 16
+    kernel_size: 11
+    sanm_shfit: 0
+
+predictor: CifPredictorV2
+predictor_conf:
+    idim: 512
+    threshold: 1.0
+    l_order: 1
+    r_order: 1
+    tail_threshold: 0.45
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 7
+    lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 6
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  val_scheduler_criterion:
+      - valid
+      - acc
+  best_model_criterion:
+  -   - valid
+      - acc
+      - max
+  keep_nbest_models: 10
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: DynamicBatchLocalShuffleSampler
+    batch_type: example # example or length
+    batch_size: 1 # if batch_type is example, batch_size is the number of samples; if length, batch_size counts source_token_len + target_token_len
+    max_token_length: 2048 # filter out samples whose source_token_len + target_token_len exceeds max_token_length
+    buffer_size: 500
+    shuffle: True
+    num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  split_with_space: true
+
+
+input_size: 560
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+normalize: null
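
For reference, a minimal sketch of how a template like the one above might be consumed is shown below. It is illustrative only: it assumes PyYAML is available and that component names resolve through the lowercased registry lookups used in funasr/models/paraformer_online/model.py further down; the build_encoder helper and the config path are made up for this example and are not part of the patch.

    import yaml

    from funasr.utils.register import registry_tables


    def build_encoder(config_path: str):
        """Instantiate only the encoder described by a template.yaml (illustrative sketch)."""
        with open(config_path, "r", encoding="utf-8") as f:
            conf = yaml.safe_load(f)
        # Component names in the YAML (e.g. "SANMEncoder") are looked up in the registry
        # by their lowercased names, mirroring the constructor logic in model.py below.
        encoder_class = registry_tables.encoder_classes.get_class(conf["encoder"].lower())
        return encoder_class(input_size=conf.get("input_size", 80), **conf["encoder_conf"])
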
diff --git a/funasr/models/paraformer_online/model.py b/funasr/models/paraformer_online/model.py
new file mode 100644
index 0000000..5cbed26
--- /dev/null
+++ b/funasr/models/paraformer_online/model.py
@@ -0,0 +1,1284 @@
+import os
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import tempfile
+import codecs
+import requests
+import re
+import copy
+import torch
+import torch.nn as nn
+import random
+import numpy as np
+import time
+# from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+	LabelSmoothingLoss,  # noqa: H301
+)
+
+from funasr.models.paraformer.cif_predictor import mae_loss
+
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import force_gatherable
+
+from funasr.models.paraformer.search import Hypothesis
+
+# from funasr.models.model_class_factory import *
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+	from torch.cuda.amp import autocast
+else:
+	# Nothing to do if torch<1.6.0
+	@contextmanager
+	def autocast(enabled=True):
+		yield
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.utils.register import registry_tables
+from funasr.models.ctc.ctc import CTC
+
+class Paraformer(nn.Module):
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+	https://arxiv.org/abs/2206.08317
+	"""
+	
+	def __init__(
+		self,
+		# token_list: Union[Tuple[str, ...], List[str]],
+		frontend: Optional[str] = None,
+		frontend_conf: Optional[Dict] = None,
+		specaug: Optional[str] = None,
+		specaug_conf: Optional[Dict] = None,
+		normalize: str = None,
+		normalize_conf: Optional[Dict] = None,
+		encoder: str = None,
+		encoder_conf: Optional[Dict] = None,
+		decoder: str = None,
+		decoder_conf: Optional[Dict] = None,
+		ctc: str = None,
+		ctc_conf: Optional[Dict] = None,
+		predictor: str = None,
+		predictor_conf: Optional[Dict] = None,
+		ctc_weight: float = 0.5,
+		input_size: int = 80,
+		vocab_size: int = -1,
+		ignore_id: int = -1,
+		blank_id: int = 0,
+		sos: int = 1,
+		eos: int = 2,
+		lsm_weight: float = 0.0,
+		length_normalized_loss: bool = False,
+		# report_cer: bool = True,
+		# report_wer: bool = True,
+		# sym_space: str = "<space>",
+		# sym_blank: str = "<blank>",
+		# extract_feats_in_collect_stats: bool = True,
+		# predictor=None,
+		predictor_weight: float = 0.0,
+		predictor_bias: int = 0,
+		sampling_ratio: float = 0.2,
+		share_embedding: bool = False,
+		# preencoder: Optional[AbsPreEncoder] = None,
+		# postencoder: Optional[AbsPostEncoder] = None,
+		use_1st_decoder_loss: bool = False,
+		**kwargs,
+	):
+
+		super().__init__()
+		
+		# import pdb;
+		# pdb.set_trace()
+		
+		if frontend is not None:
+			frontend_class = registry_tables.frontend_classes.get_class(frontend.lower())
+			frontend = frontend_class(**frontend_conf)
+		if specaug is not None:
+			specaug_class = registry_tables.specaug_classes.get_class(specaug.lower())
+			specaug = specaug_class(**specaug_conf)
+		if normalize is not None:
+			normalize_class = registry_tables.normalize_classes.get_class(normalize.lower())
+			normalize = normalize_class(**normalize_conf)
+		encoder_class = registry_tables.encoder_classes.get_class(encoder.lower())
+		encoder = encoder_class(input_size=input_size, **encoder_conf)
+		encoder_output_size = encoder.output_size()
+		if decoder is not None:
+			decoder_class = registry_tables.decoder_classes.get_class(decoder.lower())
+			decoder = decoder_class(
+				vocab_size=vocab_size,
+				encoder_output_size=encoder_output_size,
+				**decoder_conf,
+			)
+		if ctc_weight > 0.0:
+			
+			if ctc_conf is None:
+				ctc_conf = {}
+			
+			ctc = CTC(
+				odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+			)
+		if predictor is not None:
+			predictor_class = registry_tables.predictor_classes.get_class(predictor.lower())
+			predictor = predictor_class(**predictor_conf)
+		
+		# note that eos is the same as sos (equivalent ID)
+		self.blank_id = blank_id
+		self.sos = sos if sos is not None else vocab_size - 1
+		self.eos = eos if eos is not None else vocab_size - 1
+		self.vocab_size = vocab_size
+		self.ignore_id = ignore_id
+		self.ctc_weight = ctc_weight
+		# self.token_list = token_list.copy()
+		#
+		self.frontend = frontend
+		self.specaug = specaug
+		self.normalize = normalize
+		# self.preencoder = preencoder
+		# self.postencoder = postencoder
+		self.encoder = encoder
+		#
+		# if not hasattr(self.encoder, "interctc_use_conditioning"):
+		# 	self.encoder.interctc_use_conditioning = False
+		# if self.encoder.interctc_use_conditioning:
+		# 	self.encoder.conditioning_layer = torch.nn.Linear(
+		# 		vocab_size, self.encoder.output_size()
+		# 	)
+		#
+		self.error_calculator = None  # referenced below when computing CER/WER outside training
+		#
+		if ctc_weight == 1.0:
+			self.decoder = None
+		else:
+			self.decoder = decoder
+		
+		self.criterion_att = LabelSmoothingLoss(
+			size=vocab_size,
+			padding_idx=ignore_id,
+			smoothing=lsm_weight,
+			normalize_length=length_normalized_loss,
+		)
+		#
+		# if report_cer or report_wer:
+		# 	self.error_calculator = ErrorCalculator(
+		# 		token_list, sym_space, sym_blank, report_cer, report_wer
+		# 	)
+		#
+		if ctc_weight == 0.0:
+			self.ctc = None
+		else:
+			self.ctc = ctc
+		#
+		# self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+		self.predictor = predictor
+		self.predictor_weight = predictor_weight
+		self.predictor_bias = predictor_bias
+		self.sampling_ratio = sampling_ratio
+		self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+		self.step_cur = 0  # referenced by ParaformerOnline to log the sampler state once
+		#
+		self.share_embedding = share_embedding
+		if self.share_embedding:
+			self.decoder.embed = None
+		
+		self.use_1st_decoder_loss = use_1st_decoder_loss
+		self.length_normalized_loss = length_normalized_loss
+		self.beam_search = None
+	
+	def forward(
+		self,
+		speech: torch.Tensor,
+		speech_lengths: torch.Tensor,
+		text: torch.Tensor,
+		text_lengths: torch.Tensor,
+		**kwargs,
+	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+		"""Encoder + Decoder + Calc loss
+		Args:
+				speech: (Batch, Length, ...)
+				speech_lengths: (Batch, )
+				text: (Batch, Length)
+				text_lengths: (Batch,)
+		"""
+		# import pdb;
+		# pdb.set_trace()
+		if len(text_lengths.size()) > 1:
+			text_lengths = text_lengths[:, 0]
+		if len(speech_lengths.size()) > 1:
+			speech_lengths = speech_lengths[:, 0]
+		
+		batch_size = speech.shape[0]
+		
+		
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+		
+		loss_ctc, cer_ctc = None, None
+		loss_pre = None
+		stats = dict()
+		
+		# decoder: CTC branch
+		if self.ctc_weight != 0.0:
+			loss_ctc, cer_ctc = self._calc_ctc_loss(
+				encoder_out, encoder_out_lens, text, text_lengths
+			)
+			
+			# Collect CTC branch stats
+			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+			stats["cer_ctc"] = cer_ctc
+		
+
+		# decoder: Attention decoder branch
+		loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
+			encoder_out, encoder_out_lens, text, text_lengths
+		)
+		
+		# 3. CTC-Att loss definition
+		if self.ctc_weight == 0.0:
+			loss = loss_att + loss_pre * self.predictor_weight
+		else:
+			loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+		
+		
+		# Collect Attn branch stats
+		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+		stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
+		stats["acc"] = acc_att
+		stats["cer"] = cer_att
+		stats["wer"] = wer_att
+		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+		
+		stats["loss"] = torch.clone(loss.detach())
+		
+		# force_gatherable: to-device and to-tensor if scalar for DataParallel
+		if self.length_normalized_loss:
+			batch_size = (text_lengths + self.predictor_bias).sum()
+		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+		return loss, stats, weight
+	
+
+	def encode(
+		self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+	) -> Tuple[torch.Tensor, torch.Tensor]:
+		"""Apply SpecAug/normalization and run the encoder. Note that this method is used by asr_inference.py
+		Args:
+				speech: (Batch, Length, ...)
+				speech_lengths: (Batch, )
+		"""
+		with autocast(False):
+
+			# Data augmentation
+			if self.specaug is not None and self.training:
+				speech, speech_lengths = self.specaug(speech, speech_lengths)
+			
+			# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+			if self.normalize is not None:
+				speech, speech_lengths = self.normalize(speech, speech_lengths)
+		
+
+		# Forward encoder
+		encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+
+		return encoder_out, encoder_out_lens
+	
+	def calc_predictor(self, encoder_out, encoder_out_lens):
+		
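+		# The CIF predictor consumes the padded encoder output and returns:
+		#   pre_acoustic_embeds: one integrated acoustic embedding per predicted token,
+		#   pre_token_length:    the (float) predicted token count per utterance,
+		#   alphas / pre_peak_index: per-frame firing weights and firing positions.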
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
+		                                                                               encoder_out_mask,
+		                                                                               ignore_id=self.ignore_id)
+		return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+	
+	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+		
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+		)
+		decoder_out = decoder_outs[0]
+		decoder_out = torch.log_softmax(decoder_out, dim=-1)
+		return decoder_out, ys_pad_lens
+
+	def _calc_att_loss(
+		self,
+		encoder_out: torch.Tensor,
+		encoder_out_lens: torch.Tensor,
+		ys_pad: torch.Tensor,
+		ys_pad_lens: torch.Tensor,
+	):
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
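+		# predictor_bias == 1 appends <eos> to the targets (via add_sos_eos below), so the
+		# predictor is trained to count the terminating token as well.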
+		if self.predictor_bias == 1:
+			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+			ys_pad_lens = ys_pad_lens + self.predictor_bias
+		pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+		                                                                          ignore_id=self.ignore_id)
+		
+		# 0. sampler
+		decoder_out_1st = None
+		pre_loss_att = None
+		if self.sampling_ratio > 0.0:
+
+			sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+			                                               pre_acoustic_embeds)
+		else:
+			sematic_embeds = pre_acoustic_embeds
+		
+		# 1. Forward decoder
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+		)
+		decoder_out, _ = decoder_outs[0], decoder_outs[1]
+		
+		if decoder_out_1st is None:
+			decoder_out_1st = decoder_out
+		# 2. Compute attention loss
+		loss_att = self.criterion_att(decoder_out, ys_pad)
+		acc_att = th_accuracy(
+			decoder_out_1st.view(-1, self.vocab_size),
+			ys_pad,
+			ignore_label=self.ignore_id,
+		)
+		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+		
+		# Compute cer/wer using attention-decoder
+		if self.training or self.error_calculator is None:
+			cer_att, wer_att = None, None
+		else:
+			ys_hat = decoder_out_1st.argmax(dim=-1)
+			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+		
+		return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+	
+	def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
+		
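+		# Glancing-style sampler: run the decoder once without gradients on the raw CIF
+		# embeddings, count how many positions it already predicts correctly, and then
+		# randomly swap in ground-truth character embeddings at a number of positions
+		# proportional to the error count (scaled by sampling_ratio) before the real pass.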
+		tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+		ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+		if self.share_embedding:
+			ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+		else:
+			ys_pad_embed = self.decoder.embed(ys_pad_masked)
+		with torch.no_grad():
+			decoder_outs = self.decoder(
+				encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
+			)
+			decoder_out, _ = decoder_outs[0], decoder_outs[1]
+			pred_tokens = decoder_out.argmax(-1)
+			nonpad_positions = ys_pad.ne(self.ignore_id)
+			seq_lens = (nonpad_positions).sum(1)
+			same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+			input_mask = torch.ones_like(nonpad_positions)
+			bsz, seq_len = ys_pad.size()
+			for li in range(bsz):
+				target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+				if target_num > 0:
+					input_mask[li].scatter_(dim=0,
+					                        index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
+					                        value=0)
+			input_mask = input_mask.eq(1)
+			input_mask = input_mask.masked_fill(~nonpad_positions, False)
+			input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+		
+		sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+			input_mask_expand_dim, 0)
+		return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+		
+	def _calc_ctc_loss(
+		self,
+		encoder_out: torch.Tensor,
+		encoder_out_lens: torch.Tensor,
+		ys_pad: torch.Tensor,
+		ys_pad_lens: torch.Tensor,
+	):
+		# Calc CTC loss
+		loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+		
+		# Calc CER using CTC
+		cer_ctc = None
+		if not self.training and self.error_calculator is not None:
+			ys_hat = self.ctc.argmax(encoder_out).data
+			cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+		return loss_ctc, cer_ctc
+
+	
+	def init_beam_search(self,
+	                     **kwargs,
+	                     ):
+		from funasr.models.paraformer.search import BeamSearchPara
+		from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+		from funasr.models.transformer.scorers.length_bonus import LengthBonus
+	
+		# 1. Build ASR model
+		scorers = {}
+		
+		if self.ctc is not None:
+			ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+			scorers.update(
+				ctc=ctc
+			)
+		token_list = kwargs.get("token_list")
+		scorers.update(
+			length_bonus=LengthBonus(len(token_list)),
+		)
+
+		
+		# 3. Build ngram model
+		# ngram is not supported now
+		ngram = None
+		scorers["ngram"] = ngram
+		
+		weights = dict(
+			decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.0),
+			ctc=kwargs.get("decoding_ctc_weight", 0.0),
+			lm=kwargs.get("lm_weight", 0.0),
+			ngram=kwargs.get("ngram_weight", 0.0),
+			length_bonus=kwargs.get("penalty", 0.0),
+		)
+		beam_search = BeamSearchPara(
+			beam_size=kwargs.get("beam_size", 2),
+			weights=weights,
+			scorers=scorers,
+			sos=self.sos,
+			eos=self.eos,
+			vocab_size=len(token_list),
+			token_list=token_list,
+			pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+		)
+		# beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+		# for scorer in scorers.values():
+		# 	if isinstance(scorer, torch.nn.Module):
+		# 		scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+		self.beam_search = beam_search
+		
+	def generate(self,
+	             data_in: list,
+	             data_lengths: list = None,
+	             key: list = None,
+	             tokenizer=None,
+	             **kwargs,
+	             ):
+		
+		# init beamsearch
+		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		if self.beam_search is None and (is_use_lm or is_use_ctc):
+			logging.info("enable beam_search")
+			self.init_beam_search(**kwargs)
+			self.nbest = kwargs.get("nbest", 1)
+		
+		meta_data = {}
+		# extract fbank feats
+		time1 = time.perf_counter()
+		audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+		time2 = time.perf_counter()
+		meta_data["load_data"] = f"{time2 - time1:0.3f}"
+		speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
+		time3 = time.perf_counter()
+		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+		meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+		
+		speech = speech.to(device=kwargs["device"])
+		speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+		
+		# predictor
+		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+		                                                                predictor_outs[2], predictor_outs[3]
+		pre_token_length = pre_token_length.round().long()
+		if torch.max(pre_token_length) < 1:
+			return []
+		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+		                                                         pre_token_length)
+		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+
+		results = []
+		b, n, d = decoder_out.size()
+		for i in range(b):
+			x = encoder_out[i, :encoder_out_lens[i], :]
+			am_scores = decoder_out[i, :pre_token_length[i], :]
+			if self.beam_search is not None:
+				nbest_hyps = self.beam_search(
+					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+				)
+				
+				nbest_hyps = nbest_hyps[: self.nbest]
+			else:
+
+				yseq = am_scores.argmax(dim=-1)
+				score = am_scores.max(dim=-1)[0]
+				score = torch.sum(score, dim=-1)
+				# wrap with sos/eos so the same stripping logic as beam-search hypotheses applies below
+				yseq = torch.tensor(
+					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+				)
+				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+			for nbest_idx, hyp in enumerate(nbest_hyps):
+				ibest_writer = None
+				if kwargs.get("output_dir") is not None:
+					writer = DatadirWriter(kwargs.get("output_dir"))
+					ibest_writer = writer[f"{nbest_idx+1}best_recog"]
+				# remove sos/eos and get results
+				last_pos = -1
+				if isinstance(hyp.yseq, list):
+					token_int = hyp.yseq[1:last_pos]
+				else:
+					token_int = hyp.yseq[1:last_pos].tolist()
+					
+				# remove sos/eos and blank symbol ids
+				token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+				
+				# Change integer-ids to tokens
+				token = tokenizer.ids2tokens(token_int)
+				text = tokenizer.tokens2text(token)
+				
+				text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+				result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+				results.append(result_i)
+				
+				if ibest_writer is not None:
+					ibest_writer["token"][key[i]] = " ".join(token)
+					ibest_writer["text"][key[i]] = text
+					ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+		
+		return results, meta_data
+
+
+
+class BiCifParaformer(Paraformer):
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+	https://arxiv.org/abs/2206.08317
+	"""
+	
+	def __init__(
+		self,
+		*args,
+		**kwargs,
+	):
+		super().__init__(*args, **kwargs)
+
+
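+	# BiCif's predictor returns a second length estimate (pre_token_length2); _calc_pre2_loss
+	# supervises it with the same MAE criterion as the main length loss, and forward() below
+	# weights it by 0.5 * predictor_weight.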
+	def _calc_pre2_loss(
+		self,
+		encoder_out: torch.Tensor,
+		encoder_out_lens: torch.Tensor,
+		ys_pad: torch.Tensor,
+		ys_pad_lens: torch.Tensor,
+	):
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		if self.predictor_bias == 1:
+			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+			ys_pad_lens = ys_pad_lens + self.predictor_bias
+		_, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
+		
+		# loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+		loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
+		
+		return loss_pre2
+	
+	
+	def _calc_att_loss(
+		self,
+		encoder_out: torch.Tensor,
+		encoder_out_lens: torch.Tensor,
+		ys_pad: torch.Tensor,
+		ys_pad_lens: torch.Tensor,
+	):
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		if self.predictor_bias == 1:
+			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+			ys_pad_lens = ys_pad_lens + self.predictor_bias
+		pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
+		                                                                             encoder_out_mask,
+		                                                                             ignore_id=self.ignore_id)
+		
+		# 0. sampler
+		decoder_out_1st = None
+		if self.sampling_ratio > 0.0:
+			sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+			                                               pre_acoustic_embeds)
+		else:
+			sematic_embeds = pre_acoustic_embeds
+		
+		# 1. Forward decoder
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+		)
+		decoder_out, _ = decoder_outs[0], decoder_outs[1]
+		
+		if decoder_out_1st is None:
+			decoder_out_1st = decoder_out
+		# 2. Compute attention loss
+		loss_att = self.criterion_att(decoder_out, ys_pad)
+		acc_att = th_accuracy(
+			decoder_out_1st.view(-1, self.vocab_size),
+			ys_pad,
+			ignore_label=self.ignore_id,
+		)
+		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+		
+		# Compute cer/wer using attention-decoder
+		if self.training or self.error_calculator is None:
+			cer_att, wer_att = None, None
+		else:
+			ys_hat = decoder_out_1st.argmax(dim=-1)
+			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+		
+		return loss_att, acc_att, cer_att, wer_att, loss_pre
+
+
+	def calc_predictor(self, encoder_out, encoder_out_lens):
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
+		                                                                                                  None,
+		                                                                                                  encoder_out_mask,
+		                                                                                                  ignore_id=self.ignore_id)
+		return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+
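+	# calc_predictor_timestamp exposes the predictor's upsampled firing weights/peaks
+	# (us_alphas, us_peaks), which generate() feeds into ts_prediction_lfr6_standard to
+	# derive per-token timestamps.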
+	def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
+		                                                                                    encoder_out_mask,
+		                                                                                    token_num)
+		return ds_alphas, ds_cif_peak, us_alphas, us_peaks
+	
+	
+	def forward(
+		self,
+		speech: torch.Tensor,
+		speech_lengths: torch.Tensor,
+		text: torch.Tensor,
+		text_lengths: torch.Tensor,
+		**kwargs,
+	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+		"""Frontend + Encoder + Decoder + Calc loss
+		Args:
+				speech: (Batch, Length, ...)
+				speech_lengths: (Batch, )
+				text: (Batch, Length)
+				text_lengths: (Batch,)
+		"""
+		if len(text_lengths.size()) > 1:
+			text_lengths = text_lengths[:, 0]
+		if len(speech_lengths.size()) > 1:
+			speech_lengths = speech_lengths[:, 0]
+		
+		batch_size = speech.shape[0]
+		
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+
+		loss_ctc, cer_ctc = None, None
+		loss_pre = None
+		stats = dict()
+		
+		# decoder: CTC branch
+		if self.ctc_weight != 0.0:
+			loss_ctc, cer_ctc = self._calc_ctc_loss(
+				encoder_out, encoder_out_lens, text, text_lengths
+			)
+			
+			# Collect CTC branch stats
+			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+			stats["cer_ctc"] = cer_ctc
+
+
+		# decoder: Attention decoder branch
+		loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
+			encoder_out, encoder_out_lens, text, text_lengths
+		)
+		
+		loss_pre2 = self._calc_pre2_loss(
+			encoder_out, encoder_out_lens, text, text_lengths
+		)
+		
+		# 3. CTC-Att loss definition
+		if self.ctc_weight == 0.0:
+			loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+		else:
+			loss = self.ctc_weight * loss_ctc + (
+				1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+		
+		# Collect Attn branch stats
+		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+		stats["acc"] = acc_att
+		stats["cer"] = cer_att
+		stats["wer"] = wer_att
+		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+		stats["loss_pre2"] = loss_pre2.detach().cpu()
+		
+		stats["loss"] = torch.clone(loss.detach())
+		
+		# force_gatherable: to-device and to-tensor if scalar for DataParallel
+		if self.length_normalized_loss:
+			batch_size = int((text_lengths + self.predictor_bias).sum())
+		
+		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+		return loss, stats, weight
+	
+	def generate(self,
+	             data_in: list,
+	             data_lengths: list = None,
+	             key: list = None,
+	             tokenizer=None,
+	             **kwargs,
+	             ):
+		
+		# init beamsearch
+		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		if self.beam_search is None and (is_use_lm or is_use_ctc):
+			logging.info("enable beam_search")
+			self.init_beam_search(**kwargs)
+			self.nbest = kwargs.get("nbest", 1)
+		
+		meta_data = {}
+		# extract fbank feats
+		time1 = time.perf_counter()
+		audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+		time2 = time.perf_counter()
+		meta_data["load_data"] = f"{time2 - time1:0.3f}"
+		speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+		                                       frontend=self.frontend)
+		time3 = time.perf_counter()
+		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+		meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+		
+		speech = speech.to(device=kwargs["device"])
+		speech_lengths = speech_lengths.to(device=kwargs["device"])
+		
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+		
+		# predictor
+		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+		                                                                predictor_outs[2], predictor_outs[3]
+		pre_token_length = pre_token_length.round().long()
+		if torch.max(pre_token_length) < 1:
+			return []
+		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+		                                               pre_token_length)
+		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+		
+		# BiCifParaformer, test no bias cif2
+
+		_, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
+			                                                                    pre_token_length)
+		
+		results = []
+		b, n, d = decoder_out.size()
+		for i in range(b):
+			x = encoder_out[i, :encoder_out_lens[i], :]
+			am_scores = decoder_out[i, :pre_token_length[i], :]
+			if self.beam_search is not None:
+				nbest_hyps = self.beam_search(
+					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+					minlenratio=kwargs.get("minlenratio", 0.0)
+				)
+				
+				nbest_hyps = nbest_hyps[: self.nbest]
+			else:
+				
+				yseq = am_scores.argmax(dim=-1)
+				score = am_scores.max(dim=-1)[0]
+				score = torch.sum(score, dim=-1)
+				# wrap with sos/eos so the same stripping logic as beam-search hypotheses applies below
+				yseq = torch.tensor(
+					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+				)
+				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+			for nbest_idx, hyp in enumerate(nbest_hyps):
+				ibest_writer = None
+				if kwargs.get("output_dir") is not None:
+					writer = DatadirWriter(kwargs.get("output_dir"))
+					ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+				# remove sos/eos and get results
+				last_pos = -1
+				if isinstance(hyp.yseq, list):
+					token_int = hyp.yseq[1:last_pos]
+				else:
+					token_int = hyp.yseq[1:last_pos].tolist()
+				
+				# remove sos/eos and blank symbol ids
+				token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+				
+				# Change integer-ids to tokens
+				token = tokenizer.ids2tokens(token_int)
+				text = tokenizer.tokens2text(token)
+				
+				_, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
+				                                           us_peaks[i][:encoder_out_lens[i] * 3],
+				                                           copy.copy(token),
+				                                           vad_offset=kwargs.get("begin_time", 0))
+				
+				text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token, timestamp)
+				
+				result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
+				            "time_stamp_postprocessed": time_stamp_postprocessed,
+				            "word_lists": word_lists
+				            }
+				results.append(result_i)
+				
+				if ibest_writer is not None:
+					ibest_writer["token"][key[i]] = " ".join(token)
+					ibest_writer["text"][key[i]] = text
+					ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+					
+		
+		return results, meta_data
+
+
+class ParaformerOnline(Paraformer):
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+	https://arxiv.org/abs/2206.08317
+	"""
+	
+	def __init__(
+		self,
+		*args,
+		**kwargs,
+	):
+		
+		super().__init__(*args, **kwargs)
+		
+		# import pdb;
+		# pdb.set_trace()
+		self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
+
+
+		self.scama_mask = None
+		if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
+			from funasr.models.scama.chunk_utilis import build_scama_mask_for_cross_attention_decoder
+			self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
+			self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
+
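+	# Streaming (SCAMA) setup: when the encoder uses overlap chunking, the decoder's cross
+	# attention is restricted by a mask built from the CIF frame alignments, so each emitted
+	# token only attends to encoder chunks that are already available at that point in the
+	# stream (see _calc_att_predictor_loss and calc_predictor below).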
+
+	
+	def forward(
+		self,
+		speech: torch.Tensor,
+		speech_lengths: torch.Tensor,
+		text: torch.Tensor,
+		text_lengths: torch.Tensor,
+		**kwargs,
+	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+		"""Encoder + Decoder + Calc loss
+		Args:
+				speech: (Batch, Length, ...)
+				speech_lengths: (Batch, )
+				text: (Batch, Length)
+				text_lengths: (Batch,)
+		"""
+		# import pdb;
+		# pdb.set_trace()
+		decoding_ind = kwargs.get("decoding_ind")
+		if len(text_lengths.size()) > 1:
+			text_lengths = text_lengths[:, 0]
+		if len(speech_lengths.size()) > 1:
+			speech_lengths = speech_lengths[:, 0]
+		
+		batch_size = speech.shape[0]
+		
+		# Encoder
+		if hasattr(self.encoder, "overlap_chunk_cls"):
+			ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+			encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+		else:
+			encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		
+		loss_ctc, cer_ctc = None, None
+		loss_pre = None
+		stats = dict()
+		
+		# decoder: CTC branch
+
+		if self.ctc_weight > 0.0:
+			if hasattr(self.encoder, "overlap_chunk_cls"):
+				encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+				                                                                                    encoder_out_lens,
+				                                                                                    chunk_outs=None)
+			else:
+				encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
+				
+			loss_ctc, cer_ctc = self._calc_ctc_loss(
+				encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
+			)
+			# Collect CTC branch stats
+			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+			stats["cer_ctc"] = cer_ctc
+		
+		# decoder: Attention decoder branch
+		loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
+			encoder_out, encoder_out_lens, text, text_lengths
+		)
+		
+		# 3. CTC-Att loss definition
+		if self.ctc_weight == 0.0:
+			loss = loss_att + loss_pre * self.predictor_weight
+		else:
+			loss = self.ctc_weight * loss_ctc + (
+					1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+		
+		# Collect Attn branch stats
+		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+		stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
+		stats["acc"] = acc_att
+		stats["cer"] = cer_att
+		stats["wer"] = wer_att
+		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+		
+		stats["loss"] = torch.clone(loss.detach())
+		
+		# force_gatherable: to-device and to-tensor if scalar for DataParallel
+		if self.length_normalized_loss:
+			batch_size = (text_lengths + self.predictor_bias).sum()
+		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+		return loss, stats, weight
+	
+	def encode_chunk(
+		self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
+	) -> Tuple[torch.Tensor, torch.Tensor]:
+		"""Streaming counterpart of encode(): apply SpecAug/normalization and run the encoder chunk by chunk.
+		Args:
+				speech: (Batch, Length, ...)
+				speech_lengths: (Batch, )
+				cache: dict holding the running encoder state
+		"""
+		with autocast(False):
+			
+			# Data augmentation
+			if self.specaug is not None and self.training:
+				speech, speech_lengths = self.specaug(speech, speech_lengths)
+			
+			# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+			if self.normalize is not None:
+				speech, speech_lengths = self.normalize(speech, speech_lengths)
+		
+		# Forward encoder
+		encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+		
+		return encoder_out, torch.tensor([encoder_out.size(1)])
+	
+	def _calc_att_predictor_loss(
+		self,
+		encoder_out: torch.Tensor,
+		encoder_out_lens: torch.Tensor,
+		ys_pad: torch.Tensor,
+		ys_pad_lens: torch.Tensor,
+	):
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		if self.predictor_bias == 1:
+			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+			ys_pad_lens = ys_pad_lens + self.predictor_bias
+		mask_chunk_predictor = None
+		if self.encoder.overlap_chunk_cls is not None:
+			mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+			                                                                               device=encoder_out.device,
+			                                                                               batch_size=encoder_out.size(
+				                                                                               0))
+			mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+			                                                                       batch_size=encoder_out.size(0))
+			encoder_out = encoder_out * mask_shfit_chunk
+		pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
+		                                                                      ys_pad,
+		                                                                      encoder_out_mask,
+		                                                                      ignore_id=self.ignore_id,
+		                                                                      mask_chunk_predictor=mask_chunk_predictor,
+		                                                                      target_label_length=ys_pad_lens,
+		                                                                      )
+		predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+		                                                                                     encoder_out_lens)
+		
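+		# Build the SCAMA cross-attention mask from the frame-level alignments: each target
+		# step may only look at its own encoder chunk plus a limited number of earlier chunks
+		# (controlled by decoder_att_look_back_factor), mirroring what the decoder sees at
+		# streaming inference time.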
+		scama_mask = None
+		if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+			encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+			attention_chunk_center_bias = 0
+			attention_chunk_size = encoder_chunk_size
+			decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+			mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
+				get_mask_shift_att_chunk_decoder(None,
+			                                     device=encoder_out.device,
+			                                     batch_size=encoder_out.size(0)
+			                                     )
+			scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+				predictor_alignments=predictor_alignments,
+				encoder_sequence_length=encoder_out_lens,
+				chunk_size=1,
+				encoder_chunk_size=encoder_chunk_size,
+				attention_chunk_center_bias=attention_chunk_center_bias,
+				attention_chunk_size=attention_chunk_size,
+				attention_chunk_type=self.decoder_attention_chunk_type,
+				step=None,
+				predictor_mask_chunk_hopping=mask_chunk_predictor,
+				decoder_att_look_back_factor=decoder_att_look_back_factor,
+				mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+				target_length=ys_pad_lens,
+				is_training=self.training,
+			)
+		elif self.encoder.overlap_chunk_cls is not None:
+			encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+			                                                                            encoder_out_lens,
+			                                                                            chunk_outs=None)
+		# 0. sampler
+		decoder_out_1st = None
+		pre_loss_att = None
+		if self.sampling_ratio > 0.0:
+			if self.step_cur < 2:
+				logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+			if self.use_1st_decoder_loss:
+				sematic_embeds, decoder_out_1st, pre_loss_att = \
+					self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
+					                       ys_pad_lens, pre_acoustic_embeds, scama_mask)
+			else:
+				sematic_embeds, decoder_out_1st = \
+					self.sampler(encoder_out, encoder_out_lens, ys_pad,
+					             ys_pad_lens, pre_acoustic_embeds, scama_mask)
+		else:
+			if self.step_cur < 2:
+				logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+			sematic_embeds = pre_acoustic_embeds
+		
+		# 1. Forward decoder
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
+		)
+		decoder_out, _ = decoder_outs[0], decoder_outs[1]
+		
+		if decoder_out_1st is None:
+			decoder_out_1st = decoder_out
+		# 2. Compute attention loss
+		loss_att = self.criterion_att(decoder_out, ys_pad)
+		acc_att = th_accuracy(
+			decoder_out_1st.view(-1, self.vocab_size),
+			ys_pad,
+			ignore_label=self.ignore_id,
+		)
+		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+		
+		# Compute cer/wer using attention-decoder
+		if self.training or self.error_calculator is None:
+			cer_att, wer_att = None, None
+		else:
+			ys_hat = decoder_out_1st.argmax(dim=-1)
+			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+		
+		return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+	
+	def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
+		
+		tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+		ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+		if self.share_embedding:
+			ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+		else:
+			ys_pad_embed = self.decoder.embed(ys_pad_masked)
+		with torch.no_grad():
+			decoder_outs = self.decoder(
+				encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
+			)
+			decoder_out, _ = decoder_outs[0], decoder_outs[1]
+			pred_tokens = decoder_out.argmax(-1)
+			nonpad_positions = ys_pad.ne(self.ignore_id)
+			seq_lens = (nonpad_positions).sum(1)
+			same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+			input_mask = torch.ones_like(nonpad_positions)
+			bsz, seq_len = ys_pad.size()
+			for li in range(bsz):
+				target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+				if target_num > 0:
+					input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device), value=0)
+			input_mask = input_mask.eq(1)
+			input_mask = input_mask.masked_fill(~nonpad_positions, False)
+			input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+		
+		sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+			input_mask_expand_dim, 0)
+		return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+	
+
+	def calc_predictor(self, encoder_out, encoder_out_lens):
+		
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		mask_chunk_predictor = None
+		if self.encoder.overlap_chunk_cls is not None:
+			mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+			                                                                               device=encoder_out.device,
+			                                                                               batch_size=encoder_out.size(
+				                                                                               0))
+			mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+			                                                                       batch_size=encoder_out.size(0))
+			encoder_out = encoder_out * mask_shfit_chunk
+		pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
+		                                                                                   None,
+		                                                                                   encoder_out_mask,
+		                                                                                   ignore_id=self.ignore_id,
+		                                                                                   mask_chunk_predictor=mask_chunk_predictor,
+		                                                                                   target_label_length=None,
+		                                                                                   )
+		predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+		                                                                                     encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
+		
+		scama_mask = None
+		if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+			encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+			attention_chunk_center_bias = 0
+			attention_chunk_size = encoder_chunk_size
+			decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+			mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
+				get_mask_shift_att_chunk_decoder(None,
+			                                     device=encoder_out.device,
+			                                     batch_size=encoder_out.size(0)
+			                                     )
+			scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+				predictor_alignments=predictor_alignments,
+				encoder_sequence_length=encoder_out_lens,
+				chunk_size=1,
+				encoder_chunk_size=encoder_chunk_size,
+				attention_chunk_center_bias=attention_chunk_center_bias,
+				attention_chunk_size=attention_chunk_size,
+				attention_chunk_type=self.decoder_attention_chunk_type,
+				step=None,
+				predictor_mask_chunk_hopping=mask_chunk_predictor,
+				decoder_att_look_back_factor=decoder_att_look_back_factor,
+				mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+				target_length=None,
+				is_training=self.training,
+			)
+		self.scama_mask = scama_mask
+		
+		return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
+	
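+	# The *_chunk variants below are the streaming entry points: they reuse the running
+	# encoder/decoder state stored in the caller-provided cache dict (cache["encoder"],
+	# cache["decoder"]) instead of recomputing full-utterance masks.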
+	def calc_predictor_chunk(self, encoder_out, cache=None):
+		
+		pre_acoustic_embeds, pre_token_length = \
+			self.predictor.forward_chunk(encoder_out, cache["encoder"])
+		return pre_acoustic_embeds, pre_token_length
+	
+	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
+		)
+		decoder_out = decoder_outs[0]
+		decoder_out = torch.log_softmax(decoder_out, dim=-1)
+		return decoder_out, ys_pad_lens
+	
+	def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+		decoder_outs = self.decoder.forward_chunk(
+			encoder_out, sematic_embeds, cache["decoder"]
+		)
+		decoder_out = decoder_outs
+		decoder_out = torch.log_softmax(decoder_out, dim=-1)
+		return decoder_out
+
+	def generate(self,
+	             speech: torch.Tensor,
+	             speech_lengths: torch.Tensor,
+	             tokenizer=None,
+	             **kwargs,
+	             ):
+		
+		is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		
+		if self.beam_search is None and (is_use_lm or is_use_ctc):
+			logging.info("enable beam_search")
+			self.init_beam_search(**kwargs)  # init_beam_search only accepts keyword arguments
+			self.nbest = kwargs.get("nbest", 1)
+		
+		# Forward Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+		
+		# predictor
+		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+		                                                                predictor_outs[2], predictor_outs[3]
+		pre_token_length = pre_token_length.round().long()
+		if torch.max(pre_token_length) < 1:
+			return []
+		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+		                                               pre_token_length)
+		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+		
+		results = []
+		b, n, d = decoder_out.size()
+		for i in range(b):
+			x = encoder_out[i, :encoder_out_lens[i], :]
+			am_scores = decoder_out[i, :pre_token_length[i], :]
+			if self.beam_search is not None:
+				nbest_hyps = self.beam_search(
+					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+					minlenratio=kwargs.get("minlenratio", 0.0)
+				)
+				
+				nbest_hyps = nbest_hyps[: self.nbest]
+			else:
+				
+				yseq = am_scores.argmax(dim=-1)
+				score = am_scores.max(dim=-1)[0]
+				score = torch.sum(score, dim=-1)
+				# wrap with sos/eos so the same stripping logic as beam-search hypotheses applies below
+				yseq = torch.tensor(
+					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+				)
+				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+			for hyp in nbest_hyps:
+				assert isinstance(hyp, (Hypothesis)), type(hyp)
+				
+				# remove sos/eos and get results
+				last_pos = -1
+				if isinstance(hyp.yseq, list):
+					token_int = hyp.yseq[1:last_pos]
+				else:
+					token_int = hyp.yseq[1:last_pos].tolist()
+				
+				# remove blank (0) and eos (2) symbol ids
+				token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+				
+				# Change integer-ids to tokens
+				token = tokenizer.ids2tokens(token_int)
+				text = tokenizer.tokens2text(token)
+				
+				timestamp = []
+				
+				results.append((text, token, timestamp))
+		
+		return results
+
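
To make the streaming call sequence above concrete, here is a minimal, illustrative decoding loop. It is a sketch only: the chunking of the feature stream, the batch-of-one shapes, and the exact contents of the cache dict are assumptions (the code above only shows that cache["encoder"] and cache["decoder"] are passed through to the *_chunk methods), not an API guaranteed by this patch.

    import torch


    def stream_decode(model, feature_chunks, tokenizer, cache):
        """Illustrative streaming loop over pre-chunked fbank features with shape (1, T_chunk, D)."""
        texts = []
        for feats in feature_chunks:
            feats_len = torch.tensor([feats.size(1)])
            # 1) incremental encoding with the running encoder cache
            enc_out, enc_out_len = model.encode_chunk(feats, feats_len, cache=cache)
            # 2) CIF predictor on this chunk; it may fire zero tokens
            acoustic_embeds, token_length = model.calc_predictor_chunk(enc_out, cache=cache)
            if int(token_length.sum()) < 1:
                continue
            # 3) decode only the newly fired tokens (greedy over the log-probs)
            log_probs = model.cal_decoder_with_predictor_chunk(enc_out, acoustic_embeds, cache=cache)
            token_ids = log_probs.argmax(dim=-1)[0].tolist()
            texts.append(tokenizer.tokens2text(tokenizer.ids2tokens(token_ids)))
        return "".join(texts)
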
diff --git a/funasr/models/paraformer_online/sanm_decoder.py b/funasr/models/paraformer_online/sanm_decoder.py
new file mode 100644
index 0000000..b1e94d7
--- /dev/null
+++ b/funasr/models/paraformer_online/sanm_decoder.py
@@ -0,0 +1,507 @@
+from typing import List
+from typing import Tuple
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.models.scama import utils as myutils
+from funasr.models.transformer.decoder import BaseTransformerDecoder
+
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+
+from funasr.utils.register import register_class, registry_tables
+
+class DecoderLayerSANM(nn.Module):
+    """Single decoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Source (cross-) attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct a DecoderLayerSANM object."""
+        super(DecoderLayerSANM, self).__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        if self_attn is not None:
+            self.norm2 = LayerNorm(size)
+        if src_attn is not None:
+            self.norm3 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
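+        # Layer order in the SANM decoder layer: position-wise feed-forward first, then the
+        # FSMN memory block (self_attn), then source attention over the encoder memory.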
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, _ = self.self_attn(tgt, tgt_mask)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
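+            # the FSMN cache is only meaningful for incremental inference; training recomputes from scratch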
+            if self.training:
+                cache = None
+            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+        """Compute decoded features for one streaming chunk.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, chunk_len, size).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            fsmn_cache (torch.Tensor): Cached input of the FSMN memory block from the previous chunk.
+            opt_cache (dict): Cached key/value projections of the source-attention block.
+            chunk_size: Chunk-size configuration forwarded to the source attention.
+            look_back (int): Number of previous chunks the source attention may attend to.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, chunk_len, size).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Updated FSMN cache.
+            dict: Updated source-attention cache.
+
+        """
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
+            x = residual + x
+
+        return x, memory, fsmn_cache, opt_cache
+
+
+@register_class("decoder_classes", "ParaformerSANMDecoder")
+class ParaformerSANMDecoder(BaseTransformerDecoder):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
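+    # A minimal usage sketch (values are illustrative, not taken from a released config):
+    #   decoder = ParaformerSANMDecoder(vocab_size=4234, encoder_output_size=512,
+    #                                   attention_heads=4, linear_units=2048, num_blocks=16)
+    #   logits, olens = decoder(enc_out, enc_out_lens, acoustic_embeds, embed_lens)
+    # where `acoustic_embeds` are the predictor-generated embeddings in Paraformer-style decoding.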
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        att_layer_num: int = 6,
+        kernel_size: int = 21,
+        sanm_shfit: int = 0,
+        lora_list: List[str] = None,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
+        chunk_multiply_factor: tuple = (1,),
+        tf2torch_tensor_name_prefix_torch: str = "decoder",
+        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+    ):
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+
+        if input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(vocab_size, attention_dim),
+                # pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        elif input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(vocab_size, attention_dim),
+                torch.nn.LayerNorm(attention_dim),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        else:
+            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+        if use_output_layer:
+            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.output_layer = None
+
+        self.att_layer_num = att_layer_num
+        self.num_blocks = num_blocks
+        if sanm_shfit is None:
+            sanm_shfit = (kernel_size - 1) // 2
+        self.decoders = repeat(
+            att_layer_num,
+            lambda lnum: DecoderLayerSANM(
+                attention_dim,
+                MultiHeadedAttentionSANMDecoder(
+                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                ),
+                MultiHeadedAttentionCrossAtt(
+                    attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
+                ),
+                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if num_blocks - att_layer_num <= 0:
+            self.decoders2 = None
+        else:
+            self.decoders2 = repeat(
+                num_blocks - att_layer_num,
+                lambda lnum: DecoderLayerSANM(
+                    attention_dim,
+                    MultiHeadedAttentionSANMDecoder(
+                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+                    ),
+                    None,
+                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                ),
+            )
+
+        self.decoders3 = repeat(
+            1,
+            lambda lnum: DecoderLayerSANM(
+                attention_dim,
+                None,
+                None,
+                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
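+        # Three stacks: `decoders` (FSMN memory block + cross attention), `decoders2`
+        # (FSMN memory block only) and `decoders3` (a single feed-forward-only layer).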
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+        self.chunk_multiply_factor = chunk_multiply_factor
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        chunk_mask: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out) if input_layer == "embed";
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        tgt = ys_in_pad
+        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+        
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
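+        # in streaming mode an extra chunk mask restricts which encoder frames each token may attend to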
+        if chunk_mask is not None:
+            memory_mask = memory_mask * chunk_mask
+            if tgt_mask.size(1) != memory_mask.size(1):
+                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
+
+        x = tgt
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+        if self.decoders2 is not None:
+            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
+                x, tgt_mask, memory, memory_mask
+            )
+        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
+            x, tgt_mask, memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        return x, olens
+
+    def score(self, ys, state, x):
+        """Score."""
+        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+        logp, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
+        )
+        return logp.squeeze(0), state
+
+    def forward_chunk(
+        self,
+        memory: torch.Tensor,
+        tgt: torch.Tensor,
+        cache: dict = None,
+    ) -> torch.Tensor:
+        """Forward decoder for one streaming chunk.
+
+        Args:
+            memory: encoded memory from the encoder, float32 (batch, maxlen_in, feat)
+            tgt: decoder input embeddings for the current chunk (batch, chunk_len, feat)
+            cache: streaming cache dict holding "decode_fsmn", "opt", "chunk_size"
+                and "decoder_chunk_look_back"
+        Returns:
+            x: decoded token scores before softmax (batch, chunk_len, token)
+                if use_output_layer is True
+        """
+        x = tgt
+        if cache["decode_fsmn"] is None:
+            cache_layer_num = len(self.decoders)
+            if self.decoders2 is not None:
+                cache_layer_num += len(self.decoders2)
+            fsmn_cache = [None] * cache_layer_num
+        else:
+            fsmn_cache = cache["decode_fsmn"]
+
+        if cache["opt"] is None:
+            cache_layer_num = len(self.decoders)
+            opt_cache = [None] * cache_layer_num
+        else:
+            opt_cache = cache["opt"]
+
+        for i in range(self.att_layer_num):
+            decoder = self.decoders[i]
+            x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
+                x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
+                chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
+            )
+
+        if self.num_blocks - self.att_layer_num > 1:
+            for i in range(self.num_blocks - self.att_layer_num):
+                j = i + self.att_layer_num
+                decoder = self.decoders2[i]
+                x, memory, fsmn_cache[j], _  = decoder.forward_chunk(
+                    x, memory, fsmn_cache=fsmn_cache[j]
+                )
+
+        for decoder in self.decoders3:
+            x, memory, _, _ = decoder.forward_chunk(
+                x, memory
+            )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+
+        cache["decode_fsmn"] = fsmn_cache
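+        # the source-attention ("opt") cache is only kept when the decoder may look back at previous chunks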
+        if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
+            cache["opt"] = opt_cache
+        return x
+
+    def forward_one_step(
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        memory: torch.Tensor,
+        cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward one step.
+
+        Args:
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask,  (batch, maxlen_out)
+                      dtype=torch.uint8 in PyTorch 1.2-
+                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+            memory: encoded memory, float32  (batch, maxlen_in, feat)
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+                `y.shape` is (batch, token)
+        """
+        x = self.embed(tgt)
+        if cache is None:
+            cache_layer_num = len(self.decoders)
+            if self.decoders2 is not None:
+                cache_layer_num += len(self.decoders2)
+            cache = [None] * cache_layer_num
+        new_cache = []
+        # for c, decoder in zip(cache, self.decoders):
+        for i in range(self.att_layer_num):
+            decoder = self.decoders[i]
+            c = cache[i]
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                x, tgt_mask, memory, None, cache=c
+            )
+            new_cache.append(c_ret)
+
+        if self.num_blocks - self.att_layer_num > 1:
+            for i in range(self.num_blocks - self.att_layer_num):
+                j = i + self.att_layer_num
+                decoder = self.decoders2[i]
+                c = cache[j]
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                    x, tgt_mask, memory, None, cache=c
+                )
+                new_cache.append(c_ret)
+
+        for decoder in self.decoders3:
+
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
+                x, tgt_mask, memory, None, cache=None
+            )
+
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.output_layer is not None:
+            y = torch.log_softmax(self.output_layer(y), dim=-1)
+
+        return y, new_cache
+
diff --git a/funasr/models/sa_asr/attention.py b/funasr/models/sa_asr/attention.py
new file mode 100644
index 0000000..2cce9ec
--- /dev/null
+++ b/funasr/models/sa_asr/attention.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+from typing import Optional, Tuple
+
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+import funasr.models.lora.layers as lora
+
+
+
+class CosineDistanceAttention(nn.Module):
+    """Compute the cosine distance between the speaker-decoder output and the speaker
+    profiles and use it as an attention weight over the profile embeddings.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, spk_decoder_out, profile, profile_lens=None):
+        """
+        Args:
+            spk_decoder_out (torch.Tensor): (B, L, D)
+            profile (torch.Tensor): speaker profile embeddings (B, N, D)
+            profile_lens (torch.Tensor, optional): number of valid profiles per utterance (B,)
+        """
+        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
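+        # broadcast against profile (B, 1, N, D); cosine similarity over D yields weights of shape (B, L, N)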
+        if profile_lens is not None:
+            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
+            )
+            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
+            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
+        else:
+            x = x[:, -1:, :, :]
+            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
+            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
+        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
+
+        return spk_embedding, weights
diff --git a/funasr/models/sa_asr/e2e_sa_asr.py b/funasr/models/sa_asr/e2e_sa_asr.py
index e0cb69a..3bfecfc 100644
--- a/funasr/models/sa_asr/e2e_sa_asr.py
+++ b/funasr/models/sa_asr/e2e_sa_asr.py
@@ -20,13 +20,13 @@
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.models.base_model import FunASRModel
 
diff --git a/funasr/models/transformer/transformer_decoder.py b/funasr/models/sa_asr/transformer_decoder.py
similarity index 65%
rename from funasr/models/transformer/transformer_decoder.py
rename to funasr/models/sa_asr/transformer_decoder.py
index b2bea68..3319212 100644
--- a/funasr/models/transformer/transformer_decoder.py
+++ b/funasr/models/sa_asr/transformer_decoder.py
@@ -10,9 +10,9 @@
 import torch
 from torch import nn
 
-from funasr.models.decoder.abs_decoder import AbsDecoder
+
 from funasr.models.transformer.attention import MultiHeadedAttention
-from funasr.models.transformer.attention import CosineDistanceAttention
+from funasr.models.sa_asr.attention import CosineDistanceAttention
 from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
 from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
 from funasr.models.transformer.embedding import PositionalEncoding
@@ -24,9 +24,10 @@
 from funasr.models.transformer.positionwise_feed_forward import (
     PositionwiseFeedForward,  # noqa: H301
 )
-from funasr.models.transformer.repeat import repeat
+from funasr.models.transformer.utils.repeat import repeat
 from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
 
+from funasr.utils.register import register_class, registry_tables
 
 class DecoderLayer(nn.Module):
     """Single decoder layer module.
@@ -150,7 +151,7 @@
         return x, tgt_mask, memory, memory_mask
 
 
-class BaseTransformerDecoder(AbsDecoder, BatchScorerInterface):
+class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
     """Base class of Transfomer decoder module.
 
     Args:
@@ -352,7 +353,7 @@
         state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
         return logp, state_list
 
-
+@register_class("decoder_classes", "TransformerDecoder")
 class TransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -401,6 +402,7 @@
         )
 
 
+@register_class("decoder_classes", "ParaformerDecoderSAN")
 class ParaformerDecoderSAN(BaseTransformerDecoder):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -514,7 +516,7 @@
         else:
             return x, olens
 
-
+@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
 class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -575,7 +577,7 @@
             ),
         )
 
-
+@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
 class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -637,6 +639,7 @@
         )
 
 
+@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
 class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -697,7 +700,7 @@
             ),
         )
 
-
+@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
 class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -757,426 +760,3 @@
                 concat_after,
             ),
         )
-
-class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
-    
-    def __init__(
-        self,
-        vocab_size: int,
-        encoder_output_size: int,
-        spker_embedding_dim: int = 256,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        input_layer: str = "embed",
-        use_asr_output_layer: bool = True,
-        use_spk_output_layer: bool = True,
-        pos_enc_class=PositionalEncoding,
-        normalize_before: bool = True,
-    ):
-        super().__init__()
-        attention_dim = encoder_output_size
-
-        if input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(vocab_size, attention_dim),
-                pos_enc_class(attention_dim, positional_dropout_rate),
-            )
-        elif input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(vocab_size, attention_dim),
-                torch.nn.LayerNorm(attention_dim),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(attention_dim, positional_dropout_rate),
-            )
-        else:
-            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-
-        self.normalize_before = normalize_before
-        if self.normalize_before:
-            self.after_norm = LayerNorm(attention_dim)
-        if use_asr_output_layer:
-            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
-        else:
-            self.asr_output_layer = None
-
-        if use_spk_output_layer:
-            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
-        else:
-            self.spk_output_layer = None
-
-        self.cos_distance_att = CosineDistanceAttention()
-
-        self.decoder1 = None
-        self.decoder2 = None
-        self.decoder3 = None
-        self.decoder4 = None
-
-    def forward(
-        self,
-        asr_hs_pad: torch.Tensor,
-        spk_hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-        profile: torch.Tensor,
-        profile_lens: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        
-        tgt = ys_in_pad
-        # tgt_mask: (B, 1, L)
-        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
-        # m: (1, L, L)
-        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
-        # tgt_mask: (B, L, L)
-        tgt_mask = tgt_mask & m
-
-        asr_memory = asr_hs_pad
-        spk_memory = spk_hs_pad
-        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
-        # Spk decoder
-        x = self.embed(tgt)
-
-        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
-            x, tgt_mask, asr_memory, spk_memory, memory_mask
-        )
-        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
-            x, tgt_mask, spk_memory, memory_mask
-        )
-        if self.normalize_before:
-            x = self.after_norm(x)
-        if self.spk_output_layer is not None:
-            x = self.spk_output_layer(x)
-        dn, weights = self.cos_distance_att(x, profile, profile_lens)
-        # Asr decoder
-        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
-            z, tgt_mask, asr_memory, memory_mask, dn
-        )
-        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
-            x, tgt_mask, asr_memory, memory_mask
-        )
-
-        if self.normalize_before:
-            x = self.after_norm(x)
-        if self.asr_output_layer is not None:
-            x = self.asr_output_layer(x)
-
-        olens = tgt_mask.sum(1)
-        return x, weights, olens
-
-
-    def forward_one_step(
-        self,
-        tgt: torch.Tensor,
-        tgt_mask: torch.Tensor,
-        asr_memory: torch.Tensor,
-        spk_memory: torch.Tensor,
-        profile: torch.Tensor,
-        cache: List[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-        
-        x = self.embed(tgt)
-
-        if cache is None:
-            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
-        new_cache = []
-        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
-                x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
-        )
-        new_cache.append(x)
-        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
-            x, tgt_mask, spk_memory, _ = decoder(
-                x, tgt_mask, spk_memory, None, cache=c
-            )
-            new_cache.append(x)
-        if self.normalize_before:
-            x = self.after_norm(x)
-        else:
-            x = x
-        if self.spk_output_layer is not None:
-            x = self.spk_output_layer(x)
-        dn, weights = self.cos_distance_att(x, profile, None)
-
-        x, tgt_mask, asr_memory, _ = self.decoder3(
-            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
-        )
-        new_cache.append(x)
-
-        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
-            x, tgt_mask, asr_memory, _ = decoder(
-                x, tgt_mask, asr_memory, None, cache=c
-            )
-            new_cache.append(x)
-
-        if self.normalize_before:
-            y = self.after_norm(x[:, -1])
-        else:
-            y = x[:, -1]
-        if self.asr_output_layer is not None:
-            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
-
-        return y, weights, new_cache
-
-    def score(self, ys, state, asr_enc, spk_enc, profile):
-        """Score."""
-        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
-        logp, weights, state = self.forward_one_step(
-            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
-        )
-        return logp.squeeze(0), weights.squeeze(), state
-
-class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
-    def __init__(
-        self,
-        vocab_size: int,
-        encoder_output_size: int,
-        spker_embedding_dim: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        asr_num_blocks: int = 6,
-        spk_num_blocks: int = 3,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        self_attention_dropout_rate: float = 0.0,
-        src_attention_dropout_rate: float = 0.0,
-        input_layer: str = "embed",
-        use_asr_output_layer: bool = True,
-        use_spk_output_layer: bool = True,
-        pos_enc_class=PositionalEncoding,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-    ):
-        super().__init__(
-            vocab_size=vocab_size,
-            encoder_output_size=encoder_output_size,
-            spker_embedding_dim=spker_embedding_dim,
-            dropout_rate=dropout_rate,
-            positional_dropout_rate=positional_dropout_rate,
-            input_layer=input_layer,
-            use_asr_output_layer=use_asr_output_layer,
-            use_spk_output_layer=use_spk_output_layer,
-            pos_enc_class=pos_enc_class,
-            normalize_before=normalize_before,
-        )
-
-        attention_dim = encoder_output_size
-
-        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
-            attention_dim,
-            MultiHeadedAttention(
-                attention_heads, attention_dim, self_attention_dropout_rate
-            ),
-            MultiHeadedAttention(
-                attention_heads, attention_dim, src_attention_dropout_rate
-            ),
-            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-            dropout_rate,
-            normalize_before,
-            concat_after,
-        )
-        self.decoder2 = repeat(
-            spk_num_blocks - 1,
-            lambda lnum: DecoderLayer(
-                attention_dim,
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, self_attention_dropout_rate
-                ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
-                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        
-        
-        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
-            attention_dim,
-            spker_embedding_dim,
-            MultiHeadedAttention(
-                attention_heads, attention_dim, src_attention_dropout_rate
-            ),
-            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-            dropout_rate,
-            normalize_before,
-            concat_after,
-        )
-        self.decoder4 = repeat(
-            asr_num_blocks - 1,
-            lambda lnum: DecoderLayer(
-                attention_dim,
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, self_attention_dropout_rate
-                ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
-                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-
-class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
-
-    def __init__(
-        self,
-        size,
-        self_attn,
-        src_attn,
-        feed_forward,
-        dropout_rate,
-        normalize_before=True,
-        concat_after=False,
-    ):
-        """Construct an DecoderLayer object."""
-        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
-        self.size = size
-        self.self_attn = self_attn
-        self.src_attn = src_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(size)
-        self.norm2 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        if self.concat_after:
-            self.concat_linear1 = nn.Linear(size + size, size)
-            self.concat_linear2 = nn.Linear(size + size, size)
-
-    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
-        
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm1(tgt)
-
-        if cache is None:
-            tgt_q = tgt
-            tgt_q_mask = tgt_mask
-        else:
-            # compute only the last frame query keeping dim: max_time_out -> 1
-            assert cache.shape == (
-                tgt.shape[0],
-                tgt.shape[1] - 1,
-                self.size,
-            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
-            tgt_q = tgt[:, -1:, :]
-            residual = residual[:, -1:, :]
-            tgt_q_mask = None
-            if tgt_mask is not None:
-                tgt_q_mask = tgt_mask[:, -1:, :]
-
-        if self.concat_after:
-            tgt_concat = torch.cat(
-                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
-            )
-            x = residual + self.concat_linear1(tgt_concat)
-        else:
-            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
-        if not self.normalize_before:
-            x = self.norm1(x)
-        z = x
-        
-        residual = x
-        if self.normalize_before:
-            x = self.norm1(x)
-
-        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
-
-        if self.concat_after:
-            x_concat = torch.cat(
-                (x, skip), dim=-1
-            )
-            x = residual + self.concat_linear2(x_concat)
-        else:
-            x = residual + self.dropout(skip)
-        if not self.normalize_before:
-            x = self.norm1(x)
-        
-        residual = x
-        if self.normalize_before:
-            x = self.norm2(x)
-        x = residual + self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm2(x)
-
-        if cache is not None:
-            x = torch.cat([cache, x], dim=1)
-            
-        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
-
-class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
-    
-    def __init__(
-        self,
-        size,
-        d_size,
-        src_attn,
-        feed_forward,
-        dropout_rate,
-        normalize_before=True,
-        concat_after=False,
-    ):
-        """Construct an DecoderLayer object."""
-        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
-        self.size = size
-        self.src_attn = src_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(size)
-        self.norm2 = LayerNorm(size)
-        self.norm3 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        self.spk_linear = nn.Linear(d_size, size, bias=False)
-        if self.concat_after:
-            self.concat_linear1 = nn.Linear(size + size, size)
-            self.concat_linear2 = nn.Linear(size + size, size)
-
-    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
-        
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm1(tgt)
-
-        if cache is None:
-            tgt_q = tgt
-            tgt_q_mask = tgt_mask
-        else:
-            
-            tgt_q = tgt[:, -1:, :]
-            residual = residual[:, -1:, :]
-            tgt_q_mask = None
-            if tgt_mask is not None:
-                tgt_q_mask = tgt_mask[:, -1:, :]
-
-        x = tgt_q
-        if self.normalize_before:
-            x = self.norm2(x)
-        if self.concat_after:
-            x_concat = torch.cat(
-                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
-            )
-            x = residual + self.concat_linear2(x_concat)
-        else:
-            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
-        if not self.normalize_before:
-            x = self.norm2(x)
-        residual = x
-
-        if dn!=None:
-            x = x + self.spk_linear(dn)
-        if self.normalize_before:
-            x = self.norm3(x)
-        
-        x = residual + self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm3(x)
-
-        if cache is not None:
-            x = torch.cat([cache, x], dim=1)
-
-        return x, tgt_mask, memory, memory_mask
\ No newline at end of file
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
new file mode 100644
index 0000000..f48617c
--- /dev/null
+++ b/funasr/models/sanm/attention.py
@@ -0,0 +1,641 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+from typing import Optional, Tuple
+
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+import funasr.models.lora.layers as lora
+
+class MultiHeadedAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate):
+        """Construct a MultiHeadedAttention object."""
+        super(MultiHeadedAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, query, key, value):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        n_batch = query.size(0)
+        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q, k, v
+
+    def forward_attention(self, value, scores, mask):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, query, key, value, mask):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask)
+
+
+
+
+
+
+class MultiHeadedAttentionSANM(nn.Module):
+    """Multi-Head Attention layer with an FSMN memory block (SANM).
+
+    Args:
+        n_head (int): The number of heads.
+        in_feat (int): The input feature dimension.
+        n_feat (int): The attention feature dimension.
+        dropout_rate (float): Dropout rate.
+        kernel_size (int): Kernel size of the FSMN convolution.
+        sanm_shfit (int): Left shift of the FSMN padding.
+        lora_list / lora_rank / lora_alpha / lora_dropout: optional LoRA configuration.
+
+    """
+
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
+        """Construct a MultiHeadedAttentionSANM object."""
+        super().__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        # self.linear_q = nn.Linear(n_feat, n_feat)
+        # self.linear_k = nn.Linear(n_feat, n_feat)
+        # self.linear_v = nn.Linear(n_feat, n_feat)
+        if lora_list is not None:
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+            if lora_qkv_list == [False, False, False]:
+                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+            else:
+                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+        else:
+            self.linear_out = nn.Linear(n_feat, n_feat)
+            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+        # padding
+        left_padding = (kernel_size - 1) // 2
+        if sanm_shfit > 0:
+            left_padding = left_padding + sanm_shfit
+        right_padding = kernel_size - 1 - left_padding
+        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
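+        # asymmetric padding: a positive sanm_shfit shifts the FSMN window further into the past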
+
+    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
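+        # FSMN memory block: a depth-wise 1-D convolution over the (masked) value stream
+        # with a residual connection back onto the inputs.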
+        b, t, d = inputs.size()
+        if mask is not None:
+            mask = torch.reshape(mask, (b, -1, 1))
+            if mask_shfit_chunk is not None:
+                mask = mask * mask_shfit_chunk
+            inputs = inputs * mask
+
+        x = inputs.transpose(1, 2)
+        x = self.pad_fn(x)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        x += inputs
+        x = self.dropout(x)
+        if mask is not None:
+            x = x * mask
+        return x
+
+    def forward_qkv(self, x):
+        """Project the fused input into query, key and value.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Untransposed value tensor (#batch, time, n_feat), used by the FSMN block.
+
+        """
+        b, t, d = x.size()
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q_h, k_h, v_h, v
+
+    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            if mask_att_chunk_encoder is not None:
+                mask = mask * mask_att_chunk_encoder
+
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        """Compute SANM attention: scaled dot-product attention plus the FSMN memory block.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time) or (#batch, time, time).
+            mask_shfit_chunk (torch.Tensor, optional): Extra mask applied inside the FSMN block
+                for chunk-wise streaming.
+            mask_att_chunk_encoder (torch.Tensor, optional): Extra attention mask for chunk-wise streaming.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        return att_outs + fsmn_memory
+
+    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+        """Compute SANM attention for one streaming chunk, maintaining a key/value cache.
+
+        Args:
+            x (torch.Tensor): Input tensor for the current chunk (#batch, time, size).
+            cache (dict, optional): Cached key/value projections from previous chunks.
+            chunk_size: Chunk-size configuration (used to decide how many frames to cache).
+            look_back (int): Number of previous chunks to keep in the cache (-1 keeps all).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, d_model).
+            dict: Updated cache.
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
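+        # maintain the streaming k/v cache: append this chunk's projections (minus its last
+        # chunk_size[2] frames) and, unless look_back == -1, keep only the most recent
+        # look_back * chunk_size[1] frames for the next chunk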
+        if chunk_size is not None and look_back > 0 or look_back == -1:
+            if cache is not None:
+                k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
+                v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
+                k_h = torch.cat((cache["k"], k_h), dim=2)
+                v_h = torch.cat((cache["v"], v_h), dim=2)
+
+                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
+                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
+                if look_back != -1:
+                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
+                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
+            else:
+                cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
+                             "v": v_h[:, :, :-(chunk_size[2]), :]}
+                cache = cache_tmp
+        fsmn_memory = self.forward_fsmn(v, None)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, None)
+        return att_outs + fsmn_memory, cache
+
+
+
+class MultiHeadedAttentionSANMDecoder(nn.Module):
+    """FSMN memory block used as the "self-attention" of the SANM decoder layer.
+
+    Args:
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+        kernel_size (int): Kernel size of the FSMN convolution.
+        sanm_shfit (int): Left shift of the FSMN padding.
+
+    """
+
+    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+        """Construct a MultiHeadedAttentionSANMDecoder object."""
+        super(MultiHeadedAttentionSANMDecoder, self).__init__()
+
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+        self.fsmn_block = nn.Conv1d(n_feat, n_feat,
+                                    kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+        # padding
+        left_padding = (kernel_size - 1) // 2
+        if sanm_shfit > 0:
+            left_padding = left_padding + sanm_shfit
+        right_padding = kernel_size - 1 - left_padding
+        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+        self.kernel_size = kernel_size
+
+    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
+        """Apply the FSMN memory block.
+
+        Args:
+            inputs (torch.Tensor): Input tensor (#batch, time1, size).
+            mask (torch.Tensor): Mask tensor, reshaped internally to (#batch, time1, 1).
+            cache (torch.Tensor): Padded convolution input kept from the previous step (inference only).
+            mask_shfit_chunk (torch.Tensor): Extra mask applied for chunk-wise streaming.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, size).
+            torch.Tensor: Updated cache.
+        """
+        b, t, d = inputs.size()
+        if mask is not None:
+            mask = torch.reshape(mask, (b, -1, 1))
+            if mask_shfit_chunk is not None:
+                mask = mask * mask_shfit_chunk
+            inputs = inputs * mask
+
+        x = inputs.transpose(1, 2)
+        b, d, t = x.size()
+        if cache is None:
+            x = self.pad_fn(x)
+            if not self.training:
+                # keep the padded input so the next step can reuse it as left context
+                cache = x
+        else:
+            # prepend the cached left context (dropping its oldest frame) and keep only
+            # the last kernel_size + t - 1 frames needed by the convolution
+            x = torch.cat((cache[:, :, 1:], x), dim=2)
+            x = x[:, :, -(self.kernel_size + t - 1):]
+            cache = x
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        if x.size(1) != inputs.size(1):
+            inputs = inputs[:, -1, :]
+
+        x = x + inputs
+        x = self.dropout(x)
+        if mask is not None:
+            x = x * mask
+        return x, cache
+
+class MultiHeadedAttentionCrossAtt(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
+        """Construct a MultiHeadedAttentionCrossAtt object."""
+        super(MultiHeadedAttentionCrossAtt, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        if lora_list is not None:
+            if "q" in lora_list:
+                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_q = nn.Linear(n_feat, n_feat)
+            lora_kv_list = ["k" in lora_list, "v" in lora_list]
+            if lora_kv_list == [False, False]:
+                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            else:
+                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, 
+                                      r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+        else:
+            self.linear_q = nn.Linear(n_feat, n_feat)
+            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            self.linear_out = nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, x, memory):
+        """Transform query, key and value.
+
+        Args:
+            x (torch.Tensor): Query input tensor (#batch, time1, size).
+            memory (torch.Tensor): Memory tensor providing keys and values (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+
+        # print("in forward_qkv, x", x.size())
+        b = x.size(0)
+        q = self.linear_q(x)
+        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time1, d_k)
+
+        k_v = self.linear_k_v(memory)
+        k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
+        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
+
+
+        return q_h, k_h, v_h
+
+    def forward_attention(self, value, scores, mask):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            # logging.info(
+            #     "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, memory, memory_mask):
+        """Compute scaled dot product attention.
+
+        Args:
+            x (torch.Tensor): Query tensor (#batch, time1, size).
+            memory (torch.Tensor): Memory tensor providing keys and values (#batch, time2, size).
+            memory_mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h = self.forward_qkv(x, memory)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        return self.forward_attention(v_h, scores, memory_mask)
+
+    def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
+        """Compute scaled dot product attention.
+
+        Args:
+            x (torch.Tensor): Query tensor for the current chunk (#batch, time1, size).
+            memory (torch.Tensor): Memory tensor for the current chunk (#batch, time2, size).
+            cache (dict): Key/value cache carried over from previous chunks.
+            chunk_size: Chunk-size configuration used to bound the cached context.
+            look_back (int): Number of previous chunks kept as attention context.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+            dict: Updated key/value cache.
+
+        """
+        q_h, k_h, v_h = self.forward_qkv(x, memory)
+        if chunk_size is not None and look_back > 0:
+            if cache is not None:
+                k_h = torch.cat((cache["k"], k_h), dim=2)
+                v_h = torch.cat((cache["v"], v_h), dim=2)
+                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
+                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
+            else:
+                cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
+                             "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
+                cache = cache_tmp
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        return self.forward_attention(v_h, scores, None), cache
+
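+# A minimal usage sketch (editorial; hypothetical sizes, not exercised by this patch):
+# cross-attention from decoder states to encoder memory, optionally adapting the query
+# projection with LoRA via lora_list.
+#
+#   cross_att = MultiHeadedAttentionCrossAtt(4, 512, 0.1, lora_list=["q"], lora_rank=8)
+#   out = cross_att(x, memory, memory_mask)                 # (batch, time1, 512)
+#   # streaming: out, kv_cache = cross_att.forward_chunk(x, memory_chunk, cache=kv_cache,
+#   #                                                    chunk_size=chunk_size, look_back=1)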
+
+class MultiHeadSelfAttention(nn.Module):
+    """Multi-Head Self-Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        in_feat (int): The input feature dimension.
+        n_feat (int): The attention (output) feature dimension.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
+        """Construct a MultiHeadSelfAttention object."""
+        super(MultiHeadSelfAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, x):
+        """Transform query, key and value.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, in_feat).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
+            torch.Tensor: Value tensor before head splitting (#batch, time, n_feat).
+
+        """
+        b, t, d = x.size()
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q_h, k_h, v_h, v
+
+    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask multiplied into `mask`.
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            if mask_att_chunk_encoder is not None:
+                mask = mask * mask_att_chunk_encoder
+
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, mask, mask_att_chunk_encoder=None):
+        """Compute scaled dot product attention.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, in_feat).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
+                (#batch, time, time).
+            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask for streaming encoding.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        return att_outs
+
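+# A minimal usage sketch (editorial; hypothetical sizes, not exercised by this patch):
+# self-attention that projects an in_feat-dimensional input to an n_feat-dimensional output.
+#
+#   self_att = MultiHeadSelfAttention(n_head=4, in_feat=560, n_feat=512, dropout_rate=0.1)
+#   out = self_att(x, mask)                                 # (batch, time, 560) -> (batch, time, 512)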
+
+
diff --git a/funasr/models/sanm/decoder.py b/funasr/models/sanm/decoder.py
new file mode 100644
index 0000000..64033ad
--- /dev/null
+++ b/funasr/models/sanm/decoder.py
@@ -0,0 +1,474 @@
+from typing import List
+from typing import Tuple
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.models.scama import utils as myutils
+from funasr.models.transformer.decoder import BaseTransformerDecoder
+
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+
+from funasr.utils.register import register_class, registry_tables
+
+class DecoderLayerSANM(nn.Module):
+    """Single decoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Source-attention (encoder-decoder cross-attention) module instance.
+            `MultiHeadedAttentionCrossAtt` instance can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct a DecoderLayerSANM object."""
+        super(DecoderLayerSANM, self).__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        if self_attn is not None:
+            self.norm2 = LayerNorm(size)
+        if src_attn is not None:
+            self.norm3 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, _ = self.self_attn(tgt, tgt_mask)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            if self.training:
+                cache = None
+            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            fsmn_cache (torch.Tensor): Cache of the FSMN self-attention block.
+            opt_cache (dict): Key/value cache of the cross-attention block.
+            chunk_size: Chunk-size configuration used for streaming decoding.
+            look_back (int): Number of previous chunks kept as cross-attention context.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, maxlen_out, size).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            Updated fsmn_cache and opt_cache.
+
+        """
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
+            x = residual + x
+
+        return x, memory, fsmn_cache, opt_cache
+
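+# A minimal streaming sketch (editorial; hypothetical names and sizes, not exercised by this
+# patch): each DecoderLayerSANM threads one FSMN cache and one cross-attention key/value cache
+# when decoding chunk by chunk.
+#
+#   fsmn_cache, opt_cache = None, None
+#   for tgt_chunk, mem_chunk in stream:                     # (batch, t_out, size), (batch, t_in, size)
+#       y, mem_chunk, fsmn_cache, opt_cache = layer.forward_chunk(
+#           tgt_chunk, mem_chunk, fsmn_cache, opt_cache, chunk_size=chunk_size, look_back=1)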
+
+@register_class("decoder_classes", "FsmnDecoder")
+class FsmnDecoder(BaseTransformerDecoder):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
+    https://arxiv.org/abs/2006.01713
+
+    """
+    
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        att_layer_num: int = 6,
+        kernel_size: int = 21,
+        sanm_shfit: int = None,
+        concat_embeds: bool = False,
+        attention_dim: int = None,
+        tf2torch_tensor_name_prefix_torch: str = "decoder",
+        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+        embed_tensor_name_prefix_tf: str = None,
+    ):
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+        if attention_dim is None:
+            attention_dim = encoder_output_size
+        
+        if input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(vocab_size, attention_dim),
+            )
+        elif input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(vocab_size, attention_dim),
+                torch.nn.LayerNorm(attention_dim),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        else:
+            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+        
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+        if use_output_layer:
+            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.output_layer = None
+        
+        self.att_layer_num = att_layer_num
+        self.num_blocks = num_blocks
+        if sanm_shfit is None:
+            sanm_shfit = (kernel_size - 1) // 2
+        self.decoders = repeat(
+            att_layer_num,
+            lambda lnum: DecoderLayerSANM(
+                attention_dim,
+                MultiHeadedAttentionSANMDecoder(
+                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                ),
+                MultiHeadedAttentionCrossAtt(
+                    attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
+                ),
+                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if num_blocks - att_layer_num <= 0:
+            self.decoders2 = None
+        else:
+            self.decoders2 = repeat(
+                num_blocks - att_layer_num,
+                lambda lnum: DecoderLayerSANM(
+                    attention_dim,
+                    MultiHeadedAttentionSANMDecoder(
+                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                    ),
+                    None,
+                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                ),
+            )
+        
+        self.decoders3 = repeat(
+            1,
+            lambda lnum: DecoderLayerSANM(
+                attention_dim,
+                None,
+                None,
+                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if concat_embeds:
+            self.embed_concat_ffn = repeat(
+                1,
+                lambda lnum: DecoderLayerSANM(
+                    attention_dim + encoder_output_size,
+                    None,
+                    None,
+                    PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
+                                                       adim=attention_dim),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                ),
+            )
+        else:
+            self.embed_concat_ffn = None
+        self.concat_embeds = concat_embeds
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+        self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
+    
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        chunk_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        tgt = ys_in_pad
+        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+        
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        if chunk_mask is not None:
+            memory_mask = memory_mask * chunk_mask
+            if tgt_mask.size(1) != memory_mask.size(1):
+                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
+        
+        x = self.embed(tgt)
+        
+        if pre_acoustic_embeds is not None and self.concat_embeds:
+            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+        
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+        if self.decoders2 is not None:
+            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
+                x, tgt_mask, memory, memory_mask
+            )
+        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
+            x, tgt_mask, memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+        
+        olens = tgt_mask.sum(1)
+        return x, olens
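+
+    # A minimal usage sketch (editorial; hypothetical sizes, not exercised by this patch):
+    #   decoder = FsmnDecoder(vocab_size=4234, encoder_output_size=512, attention_heads=4,
+    #                         linear_units=2048, num_blocks=16, att_layer_num=16)
+    #   logits, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
+    #   # logits: (batch, maxlen_out, vocab_size) token scores before softmax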
+    
+    def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
+        """Score."""
+        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+        logp, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
+            cache=state
+        )
+        return logp.squeeze(0), state
+    
+    def forward_one_step(
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        memory: torch.Tensor,
+        memory_mask: torch.Tensor = None,
+        pre_acoustic_embeds: torch.Tensor = None,
+        cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward one step.
+
+        Args:
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask,  (batch, maxlen_out)
+                      dtype=torch.uint8 before PyTorch 1.2,
+                      dtype=torch.bool in PyTorch 1.2 and later
+            memory: encoded memory, float32  (batch, maxlen_in, feat)
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+            y.shape` is (batch, maxlen_out, token)
+        """
+        
+        x = tgt[:, -1:]
+        tgt_mask = None
+        x = self.embed(x)
+        
+        if pre_acoustic_embeds is not None and self.concat_embeds:
+            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+        
+        if cache is None:
+            cache_layer_num = len(self.decoders)
+            if self.decoders2 is not None:
+                cache_layer_num += len(self.decoders2)
+            cache = [None] * cache_layer_num
+        new_cache = []
+        # for c, decoder in zip(cache, self.decoders):
+        for i in range(self.att_layer_num):
+            decoder = self.decoders[i]
+            c = cache[i]
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                x, tgt_mask, memory, memory_mask, cache=c
+            )
+            new_cache.append(c_ret)
+        
+        if self.num_blocks - self.att_layer_num >= 1:
+            for i in range(self.num_blocks - self.att_layer_num):
+                j = i + self.att_layer_num
+                decoder = self.decoders2[i]
+                c = cache[j]
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                    x, tgt_mask, memory, memory_mask, cache=c
+                )
+                new_cache.append(c_ret)
+        
+        for decoder in self.decoders3:
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
+                x, tgt_mask, memory, None, cache=None
+            )
+        
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.output_layer is not None:
+            y = self.output_layer(y)
+            y = torch.log_softmax(y, dim=-1)
+        
+        return y, new_cache
+    
+    
\ No newline at end of file
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
new file mode 100644
index 0000000..8e159e2
--- /dev/null
+++ b/funasr/models/sanm/encoder.py
@@ -0,0 +1,454 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import numpy as np
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.positionwise_feed_forward import (
+    PositionwiseFeedForward,  # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+
+
+from funasr.models.ctc.ctc import CTC
+
+from funasr.utils.register import register_class
+
+class EncoderLayerSANM(nn.Module):
+    def __init__(
+        self,
+        in_size,
+        size,
+        self_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        stochastic_depth_rate=0.0,
+    ):
+        """Construct an EncoderLayerSANM object."""
+        super(EncoderLayerSANM, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(in_size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.in_size = in_size
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+        self.stochastic_depth_rate = stochastic_depth_rate
+        self.dropout_rate = dropout_rate
+
+    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        """Compute encoded features.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+            mask_shfit_chunk (torch.Tensor): Optional chunk mask passed to the SANM attention.
+            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask applied to attention scores.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time); cache and the chunk masks are passed through.
+
+        """
+        skip_layer = False
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        stoch_layer_coeff = 1.0
+        if self.training and self.stochastic_depth_rate > 0:
+            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+        if skip_layer:
+            if cache is not None:
+                x = torch.cat([cache, x], dim=1)
+            return x, mask
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+            else:
+                x = stoch_layer_coeff * self.concat_linear(x_concat)
+        else:
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+            else:
+                x = stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
+    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+        """Compute encoded features.
+
+        Args:
+            x (torch.Tensor): Input tensor for the current chunk (#batch, time, size).
+            cache: Attention cache carried over from previous chunks.
+            chunk_size: Chunk-size configuration used to bound the cached context.
+            look_back (int): Number of previous chunks kept as attention context.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            Updated attention cache.
+
+        """
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.in_size == self.size:
+            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+            x = residual + attn
+        else:
+            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.feed_forward(x)
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        return x, cache
+
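+# A minimal streaming sketch (editorial; hypothetical names, not exercised by this patch):
+# a single EncoderLayerSANM can be driven chunk by chunk, threading its attention cache.
+#
+#   att_cache = None
+#   for x_chunk in chunks:                                  # (batch, time, size)
+#       x_chunk, att_cache = layer.forward_chunk(x_chunk, att_cache, chunk_size, look_back=1)
+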
+@register_class("encoder_classes", "SANMEncoder")
+class SANMEncoder(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    San-m: Memory equipped self-attention for end-to-end speech recognition
+    https://arxiv.org/abs/2006.01713
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
+        kernel_size: int = 11,
+        sanm_shfit: int = 0,
+        lora_list: List[str] = None,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
+        selfattention_layer_type: str = "sanm",
+        tf2torch_tensor_name_prefix_torch: str = "encoder",
+        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+    ):
+        super().__init__()
+        self._output_size = output_size
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                SinusoidalPositionEncoder(),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        elif input_layer == "pe_online":
+            self.embed = StreamSinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+
+        elif selfattention_layer_type == "sanm":
+            encoder_selfattn_layer = MultiHeadedAttentionSANM
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+                lora_list,
+                lora_rank,
+                lora_alpha,
+                lora_dropout,
+            )
+
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+                lora_list,
+                lora_rank,
+                lora_alpha,
+                lora_dropout,
+            )
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.encoders = repeat(
+            num_blocks-1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+        self.dropout = nn.Dropout(dropout_rate)
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Encode padded input features.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: not used for now
+        Returns:
+            encoded tensor, output lengths, and None (intermediate CTC outputs are returned
+            together with the encoded tensor when interctc_layer_idx is set)
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        xs_pad = xs_pad * self.output_size()**0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling2)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        # xs_pad = self.dropout(xs_pad)
+        encoder_outs = self.encoders0(xs_pad, masks)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            encoder_outs = self.encoders(xs_pad, masks)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                encoder_outs = encoder_layer(xs_pad, masks)
+                xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None
+
+    def _add_overlap_chunk(self, feats: torch.Tensor, cache: dict = None):
+        if cache is None or len(cache) == 0:
+            return feats
+        cache["feats"] = to_device(cache["feats"], device=feats.device)
+        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+        return overlap_feats
+
+    def forward_chunk(self,
+                      xs_pad: torch.Tensor,
+                      ilens: torch.Tensor,
+                      cache: dict = None,
+                      ctc: CTC = None,
+                      ):
+        xs_pad *= self.output_size() ** 0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        else:
+            xs_pad = self.embed(xs_pad, cache)
+        if cache["tail_chunk"]:
+            xs_pad = to_device(cache["feats"], device=xs_pad.device)
+        else:
+            xs_pad = self._add_overlap_chunk(xs_pad, cache)
+        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            encoder_outs = self.encoders(xs_pad, None, None, None, None)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
+                xs_pad, masks = encoder_outs[0], encoder_outs[1]
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), None, None
+        return xs_pad, ilens, None
+
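+# A minimal usage sketch (editorial; hypothetical sizes, not exercised by this patch):
+#
+#   encoder = SANMEncoder(input_size=560, output_size=512, attention_heads=4,
+#                         linear_units=2048, num_blocks=50, input_layer="pe")
+#   xs, olens, _ = encoder(feats, feats_lens)               # feats: (batch, frames, 560)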
diff --git a/funasr/models/sanm/model.py b/funasr/models/sanm/model.py
new file mode 100644
index 0000000..e01394c
--- /dev/null
+++ b/funasr/models/sanm/model.py
@@ -0,0 +1,18 @@
+import logging
+
+import torch
+
+from funasr.models.transformer.model import Transformer
+from funasr.utils.register import register_class, registry_tables
+
+@register_class("model_classes", "SANM")
+class SANM(Transformer):
+    """CTC-attention hybrid Encoder-Decoder model."""
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+
+        super().__init__(*args, **kwargs)
diff --git a/funasr/models/sanm/positionwise_feed_forward.py b/funasr/models/sanm/positionwise_feed_forward.py
new file mode 100644
index 0000000..7125854
--- /dev/null
+++ b/funasr/models/sanm/positionwise_feed_forward.py
@@ -0,0 +1,34 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+
+"""Positionwise feed forward layer definition."""
+
+import torch
+
+from funasr.models.transformer.layer_norm import LayerNorm
+
+
+
+class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
+    """Positionwise feed forward layer.
+
+    Args:
+        idim (int): Input dimension.
+        hidden_units (int): The number of hidden units.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, idim, hidden_units, dropout_rate, adim=None, activation=torch.nn.ReLU()):
+        """Construct a PositionwiseFeedForwardDecoderSANM object."""
+        super(PositionwiseFeedForwardDecoderSANM, self).__init__()
+        self.w_1 = torch.nn.Linear(idim, hidden_units)
+        self.w_2 = torch.nn.Linear(hidden_units, idim if adim is None else adim, bias=False)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.activation = activation
+        self.norm = LayerNorm(hidden_units)
+
+    def forward(self, x):
+        """Forward function."""
+        return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x)))))
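+
+# A minimal usage sketch (editorial; hypothetical sizes, not exercised by this patch). This
+# decoder-SANM variant applies LayerNorm on the hidden units and omits the bias of w_2:
+#
+#   ffn = PositionwiseFeedForwardDecoderSANM(idim=512, hidden_units=2048, dropout_rate=0.1)
+#   y = ffn(x)                                              # (batch, time, 512) -> (batch, time, 512)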
diff --git a/funasr/models/sanm/sanm_decoder.py b/funasr/models/sanm/sanm_decoder.py
deleted file mode 100644
index cb38c12..0000000
--- a/funasr/models/sanm/sanm_decoder.py
+++ /dev/null
@@ -1,1541 +0,0 @@
-from typing import List
-from typing import Tuple
-import logging
-import torch
-import torch.nn as nn
-import numpy as np
-
-from funasr.models.scama import utils as myutils
-from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
-
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
-from funasr.models.transformer.embedding import PositionalEncoding
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.transformer.repeat import repeat
-
-
-class DecoderLayerSANM(nn.Module):
-    """Single decoder layer module.
-
-    Args:
-        size (int): Input dimension.
-        self_attn (torch.nn.Module): Self-attention module instance.
-            `MultiHeadedAttention` instance can be used as the argument.
-        src_attn (torch.nn.Module): Self-attention module instance.
-            `MultiHeadedAttention` instance can be used as the argument.
-        feed_forward (torch.nn.Module): Feed-forward module instance.
-            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
-            can be used as the argument.
-        dropout_rate (float): Dropout rate.
-        normalize_before (bool): Whether to use layer_norm before the first block.
-        concat_after (bool): Whether to concat attention layer's input and output.
-            if True, additional linear will be applied.
-            i.e. x -> x + linear(concat(x, att(x)))
-            if False, no additional linear will be applied. i.e. x -> x + att(x)
-
-
-    """
-
-    def __init__(
-        self,
-        size,
-        self_attn,
-        src_attn,
-        feed_forward,
-        dropout_rate,
-        normalize_before=True,
-        concat_after=False,
-    ):
-        """Construct an DecoderLayer object."""
-        super(DecoderLayerSANM, self).__init__()
-        self.size = size
-        self.self_attn = self_attn
-        self.src_attn = src_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(size)
-        if self_attn is not None:
-            self.norm2 = LayerNorm(size)
-        if src_attn is not None:
-            self.norm3 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        if self.concat_after:
-            self.concat_linear1 = nn.Linear(size + size, size)
-            self.concat_linear2 = nn.Linear(size + size, size)
-
-    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
-        """Compute decoded features.
-
-        Args:
-            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
-            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
-            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
-            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
-            cache (List[torch.Tensor]): List of cached tensors.
-                Each tensor shape should be (#batch, maxlen_out - 1, size).
-
-        Returns:
-            torch.Tensor: Output tensor(#batch, maxlen_out, size).
-            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
-            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
-            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
-
-        """
-        # tgt = self.dropout(tgt)
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm1(tgt)
-        tgt = self.feed_forward(tgt)
-
-        x = tgt
-        if self.self_attn:
-            if self.normalize_before:
-                tgt = self.norm2(tgt)
-            x, _ = self.self_attn(tgt, tgt_mask)
-            x = residual + self.dropout(x)
-
-        if self.src_attn is not None:
-            residual = x
-            if self.normalize_before:
-                x = self.norm3(x)
-
-            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
-
-        return x, tgt_mask, memory, memory_mask, cache
-
-    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
-        """Compute decoded features.
-
-        Args:
-            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
-            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
-            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
-            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
-            cache (List[torch.Tensor]): List of cached tensors.
-                Each tensor shape should be (#batch, maxlen_out - 1, size).
-
-        Returns:
-            torch.Tensor: Output tensor(#batch, maxlen_out, size).
-            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
-            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
-            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
-
-        """
-        # tgt = self.dropout(tgt)
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm1(tgt)
-        tgt = self.feed_forward(tgt)
-
-        x = tgt
-        if self.self_attn:
-            if self.normalize_before:
-                tgt = self.norm2(tgt)
-            if self.training:
-                cache = None
-            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
-            x = residual + self.dropout(x)
-
-        if self.src_attn is not None:
-            residual = x
-            if self.normalize_before:
-                x = self.norm3(x)
-
-            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
-
-
-        return x, tgt_mask, memory, memory_mask, cache
-
-    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
-        """Compute decoded features.
-
-        Args:
-            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
-            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
-            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
-            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
-            cache (List[torch.Tensor]): List of cached tensors.
-                Each tensor shape should be (#batch, maxlen_out - 1, size).
-
-        Returns:
-            torch.Tensor: Output tensor(#batch, maxlen_out, size).
-            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
-            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
-            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
-
-        """
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm1(tgt)
-        tgt = self.feed_forward(tgt)
-
-        x = tgt
-        if self.self_attn:
-            if self.normalize_before:
-                tgt = self.norm2(tgt)
-            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
-            x = residual + self.dropout(x)
-
-        if self.src_attn is not None:
-            residual = x
-            if self.normalize_before:
-                x = self.norm3(x)
-
-            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
-            x = residual + x
-
-        return x, memory, fsmn_cache, opt_cache
-
-
-class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
-    https://arxiv.org/abs/2006.01713
-
-    """
-    def __init__(
-            self,
-            vocab_size: int,
-            encoder_output_size: int,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            self_attention_dropout_rate: float = 0.0,
-            src_attention_dropout_rate: float = 0.0,
-            input_layer: str = "embed",
-            use_output_layer: bool = True,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            att_layer_num: int = 6,
-            kernel_size: int = 21,
-            sanm_shfit: int = None,
-            concat_embeds: bool = False,
-            attention_dim: int = None,
-            tf2torch_tensor_name_prefix_torch: str = "decoder",
-            tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
-            embed_tensor_name_prefix_tf: str = None,
-    ):
-        super().__init__(
-            vocab_size=vocab_size,
-            encoder_output_size=encoder_output_size,
-            dropout_rate=dropout_rate,
-            positional_dropout_rate=positional_dropout_rate,
-            input_layer=input_layer,
-            use_output_layer=use_output_layer,
-            pos_enc_class=pos_enc_class,
-            normalize_before=normalize_before,
-        )
-        if attention_dim is None:
-            attention_dim = encoder_output_size
-
-        if input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(vocab_size, attention_dim),
-            )
-        elif input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(vocab_size, attention_dim),
-                torch.nn.LayerNorm(attention_dim),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(attention_dim, positional_dropout_rate),
-            )
-        else:
-            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-
-        self.normalize_before = normalize_before
-        if self.normalize_before:
-            self.after_norm = LayerNorm(attention_dim)
-        if use_output_layer:
-            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
-        else:
-            self.output_layer = None
-
-        self.att_layer_num = att_layer_num
-        self.num_blocks = num_blocks
-        if sanm_shfit is None:
-            sanm_shfit = (kernel_size - 1) // 2
-        self.decoders = repeat(
-            att_layer_num,
-            lambda lnum: DecoderLayerSANM(
-                attention_dim,
-                MultiHeadedAttentionSANMDecoder(
-                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
-                ),
-                MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
-                ),
-                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if num_blocks - att_layer_num <= 0:
-            self.decoders2 = None
-        else:
-            self.decoders2 = repeat(
-                num_blocks - att_layer_num,
-                lambda lnum: DecoderLayerSANM(
-                    attention_dim,
-                    MultiHeadedAttentionSANMDecoder(
-                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
-                    ),
-                    None,
-                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
-                    dropout_rate,
-                    normalize_before,
-                    concat_after,
-                ),
-            )
-
-        self.decoders3 = repeat(
-            1,
-            lambda lnum: DecoderLayerSANM(
-                attention_dim,
-                None,
-                None,
-                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if concat_embeds:
-            self.embed_concat_ffn = repeat(
-                1,
-                lambda lnum: DecoderLayerSANM(
-                    attention_dim + encoder_output_size,
-                    None,
-                    None,
-                    PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
-                                                      adim=attention_dim),
-                    dropout_rate,
-                    normalize_before,
-                    concat_after,
-                ),
-            )
-        else:
-            self.embed_concat_ffn = None
-        self.concat_embeds = concat_embeds
-        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
-        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-        self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
-
-    def forward(
-            self,
-            hs_pad: torch.Tensor,
-            hlens: torch.Tensor,
-            ys_in_pad: torch.Tensor,
-            ys_in_lens: torch.Tensor,
-            chunk_mask: torch.Tensor = None,
-            pre_acoustic_embeds: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward decoder.
-
-        Args:
-            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
-            hlens: (batch)
-            ys_in_pad:
-                input token ids, int64 (batch, maxlen_out)
-                if input_layer == "embed"
-                input tensor (batch, maxlen_out, #mels) in the other cases
-            ys_in_lens: (batch)
-            chunk_mask: optional chunk-wise attention mask applied to the memory mask
-            pre_acoustic_embeds: optional acoustic embeddings concatenated with the
-                token embeddings when ``concat_embeds`` is True
-        Returns:
-            (tuple): tuple containing:
-
-            x: decoded token score before softmax (batch, maxlen_out, token)
-                if use_output_layer is True,
-            olens: (batch, )
-        """
-        tgt = ys_in_pad
-        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-
-        memory = hs_pad
-        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-        if chunk_mask is not None:
-            memory_mask = memory_mask * chunk_mask
-            if tgt_mask.size(1) != memory_mask.size(1):
-                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
-
-        x = self.embed(tgt)
-
-        if pre_acoustic_embeds is not None and self.concat_embeds:
-            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
-            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
-
-        x, tgt_mask, memory, memory_mask, _ = self.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
-        if self.decoders2 is not None:
-            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
-                x, tgt_mask, memory, memory_mask
-            )
-        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
-            x, tgt_mask, memory, memory_mask
-        )
-        if self.normalize_before:
-            x = self.after_norm(x)
-        if self.output_layer is not None:
-            x = self.output_layer(x)
-
-        olens = tgt_mask.sum(1)
-        return x, olens
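-
-    # Minimal offline usage sketch (illustrative; the sizes below are assumptions):
-    #
-    #   decoder = FsmnDecoderSCAMAOpt(vocab_size=4235, encoder_output_size=256)
-    #   hs_pad = torch.randn(2, 100, 256)              # encoder output
-    #   hlens = torch.tensor([100, 80])
-    #   ys_in_pad = torch.randint(0, 4235, (2, 20))    # input token ids
-    #   ys_in_lens = torch.tensor([20, 15])
-    #   logits, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
-    #   # logits: (2, 20, 4235) token scores before softmax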
-
-    def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
-        """Score."""
-        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
-        logp, state = self.forward_one_step(
-            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
-            cache=state
-        )
-        return logp.squeeze(0), state
-
-    def forward_one_step(
-            self,
-            tgt: torch.Tensor,
-            tgt_mask: torch.Tensor,
-            memory: torch.Tensor,
-            memory_mask: torch.Tensor = None,
-            pre_acoustic_embeds: torch.Tensor = None,
-            cache: List[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-        """Forward one step.
-
-        Args:
-            tgt: input token ids, int64 (batch, maxlen_out)
-            tgt_mask: input token mask, (batch, maxlen_out)
-                      dtype=torch.uint8 in PyTorch 1.2-
-                      dtype=torch.bool in PyTorch 1.2+ (1.2 included)
-            memory: encoded memory, float32 (batch, maxlen_in, feat)
-            cache: cached output list of (batch, max_time_out-1, size)
-        Returns:
-            y, cache: NN output value and cache per `self.decoders`.
-            `y.shape` is (batch, token); only the last step is returned.
-        """
-
-        x = tgt[:, -1:]
-        tgt_mask = None
-        x = self.embed(x)
-
-        if pre_acoustic_embeds is not None and self.concat_embeds:
-            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
-            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
-
-        if cache is None:
-            cache_layer_num = len(self.decoders)
-            if self.decoders2 is not None:
-                cache_layer_num += len(self.decoders2)
-            cache = [None] * cache_layer_num
-        new_cache = []
-        # for c, decoder in zip(cache, self.decoders):
-        for i in range(self.att_layer_num):
-            decoder = self.decoders[i]
-            c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
-                x, tgt_mask, memory, memory_mask, cache=c
-            )
-            new_cache.append(c_ret)
-
-        if self.num_blocks - self.att_layer_num >= 1:
-            for i in range(self.num_blocks - self.att_layer_num):
-                j = i + self.att_layer_num
-                decoder = self.decoders2[i]
-                c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
-                    x, tgt_mask, memory, memory_mask, cache=c
-                )
-                new_cache.append(c_ret)
-
-        for decoder in self.decoders3:
-            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
-                x, tgt_mask, memory, None, cache=None
-            )
-
-        if self.normalize_before:
-            y = self.after_norm(x[:, -1])
-        else:
-            y = x[:, -1]
-        if self.output_layer is not None:
-            y = self.output_layer(y)
-            y = torch.log_softmax(y, dim=-1)
-
-        return y, new_cache
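-
-    # Incremental (step-by-step) decoding sketch (illustrative; `decoder`,
-    # `memory`, `sos_id` and `max_len` are assumed to exist):
-    #
-    #   ys = torch.tensor([[sos_id]])                  # (1, 1) running hypothesis
-    #   cache = None
-    #   for _ in range(max_len):
-    #       logp, cache = decoder.forward_one_step(ys, None, memory, cache=cache)
-    #       next_id = logp.argmax(-1, keepdim=True)    # greedy pick
-    #       ys = torch.cat([ys, next_id], dim=-1)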
-
-    def gen_tf2torch_map_dict(self):
-    
-        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
-        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
-        embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
-        map_dict_local = {
-        
-            ## decoder
-            # ffn
-            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-        
-            # fsmn
-            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
-                    tensor_name_prefix_tf),
-                    "squeeze": None,
-                    "transpose": None,
-                },  # (256,),(256,)
-            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
-                    tensor_name_prefix_tf),
-                    "squeeze": None,
-                    "transpose": None,
-                },  # (256,),(256,)
-            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
-                    tensor_name_prefix_tf),
-                    "squeeze": 0,
-                    "transpose": (1, 2, 0),
-                },  # (256,1,31),(1,31,256,1)
-            # src att
-            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # dnn
-            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-        
-            # embed_concat_ffn
-            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-        
-            # out norm
-            "{}.after_norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.after_norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-        
-            # in embed
-            "{}.embed.0.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (4235,256),(4235,256)
-        
-            # out layer
-            "{}.output_layer.weight".format(tensor_name_prefix_torch):
-                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
-                          "{}/w_embs".format(embed_tensor_name_prefix_tf)],
-                 "squeeze": [None, None],
-                 "transpose": [(1, 0), None],
-                 },  # (4235,256),(256,4235)
-            "{}.output_layer.bias".format(tensor_name_prefix_torch):
-                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
-                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
-                 "squeeze": [None, None],
-                 "transpose": [None, None],
-                 },  # (4235,),(4235,)
-        
-        }
-        return map_dict_local
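-
-    # How one mapping entry is applied when porting a TF checkpoint (sketch;
-    # `tf_value` is an assumed numpy array read from the checkpoint, and the
-    # key uses the default "decoder" torch prefix):
-    #
-    #   entry = map_dict["decoder.decoders.layeridx.feed_forward.w_1.weight"]
-    #   tf_name = entry["name"].replace("layeridx", "0")   # pick layer 0
-    #   data = tf_value                                    # TF shape (1, 256, 1024)
-    #   data = np.squeeze(data, axis=entry["squeeze"])     # -> (256, 1024)
-    #   data = np.transpose(data, entry["transpose"])      # -> (1024, 256), torch layout
-    #   weight = torch.from_numpy(data).float()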
-
-    def convert_tf2torch(self,
-                         var_dict_tf,
-                         var_dict_torch,
-                         ):
-    
-        map_dict = self.gen_tf2torch_map_dict()
-        var_dict_torch_update = dict()
-        decoder_layeridx_sets = set()
-        for name in sorted(var_dict_torch.keys(), reverse=False):
-            names = name.split('.')
-            if names[0] == self.tf2torch_tensor_name_prefix_torch:
-                if names[1] == "decoders":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "decoders2":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                    name_q = name_q.replace("decoders2", "decoders")
-                    layeridx_bias = len(decoder_layeridx_sets)
-                
-                    layeridx += layeridx_bias
-                    if "decoders." in name:
-                        decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "decoders3":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    if "decoders." in name:
-                        decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "embed" or names[1] == "output_layer":
-                    name_tf = map_dict[name]["name"]
-                    if isinstance(name_tf, list):
-                        idx_list = 0
-                        if name_tf[idx_list] in var_dict_tf.keys():
-                            pass
-                        else:
-                            idx_list = 1
-                        data_tf = var_dict_tf[name_tf[idx_list]]
-                        if map_dict[name]["squeeze"][idx_list] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
-                        if map_dict[name]["transpose"][idx_list] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
-                                                                                                   name_tf[idx_list],
-                                                                                                   var_dict_tf[name_tf[
-                                                                                                       idx_list]].shape))
-                
-                    else:
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
-                        if map_dict[name]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "after_norm":
-                    name_tf = map_dict[name]["name"]
-                    data_tf = var_dict_tf[name_tf]
-                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                    var_dict_torch_update[name] = data_tf
-                    logging.info(
-                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                      var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "embed_concat_ffn":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    if "decoders." in name:
-                        decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-    
-        return var_dict_torch_update
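-
-    # Loading sketch (illustrative; `tf_vars` is an assumed name -> numpy-array
-    # dict read from the TF checkpoint, `decoder` an instance of this class):
-    #
-    #   updated = decoder.convert_tf2torch(tf_vars, decoder.state_dict())
-    #   state = decoder.state_dict()
-    #   state.update(updated)
-    #   decoder.load_state_dict(state)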
-
-
-class ParaformerSANMDecoder(BaseTransformerDecoder):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
-    https://arxiv.org/abs/2206.08317
-    """
-    def __init__(
-        self,
-        vocab_size: int,
-        encoder_output_size: int,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        self_attention_dropout_rate: float = 0.0,
-        src_attention_dropout_rate: float = 0.0,
-        input_layer: str = "embed",
-        use_output_layer: bool = True,
-        pos_enc_class=PositionalEncoding,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        att_layer_num: int = 6,
-        kernel_size: int = 21,
-        sanm_shfit: int = 0,
-        lora_list: List[str] = None,
-        lora_rank: int = 8,
-        lora_alpha: int = 16,
-        lora_dropout: float = 0.1,
-        chunk_multiply_factor: tuple = (1,),
-        tf2torch_tensor_name_prefix_torch: str = "decoder",
-        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
-    ):
-        super().__init__(
-            vocab_size=vocab_size,
-            encoder_output_size=encoder_output_size,
-            dropout_rate=dropout_rate,
-            positional_dropout_rate=positional_dropout_rate,
-            input_layer=input_layer,
-            use_output_layer=use_output_layer,
-            pos_enc_class=pos_enc_class,
-            normalize_before=normalize_before,
-        )
-
-        attention_dim = encoder_output_size
-
-        if input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(vocab_size, attention_dim),
-                # pos_enc_class(attention_dim, positional_dropout_rate),
-            )
-        elif input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(vocab_size, attention_dim),
-                torch.nn.LayerNorm(attention_dim),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(attention_dim, positional_dropout_rate),
-            )
-        else:
-            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-
-        self.normalize_before = normalize_before
-        if self.normalize_before:
-            self.after_norm = LayerNorm(attention_dim)
-        if use_output_layer:
-            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
-        else:
-            self.output_layer = None
-
-        self.att_layer_num = att_layer_num
-        self.num_blocks = num_blocks
-        if sanm_shfit is None:
-            sanm_shfit = (kernel_size - 1) // 2
-        self.decoders = repeat(
-            att_layer_num,
-            lambda lnum: DecoderLayerSANM(
-                attention_dim,
-                MultiHeadedAttentionSANMDecoder(
-                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
-                ),
-                MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
-                ),
-                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if num_blocks - att_layer_num <= 0:
-            self.decoders2 = None
-        else:
-            self.decoders2 = repeat(
-                num_blocks - att_layer_num,
-                lambda lnum: DecoderLayerSANM(
-                    attention_dim,
-                    MultiHeadedAttentionSANMDecoder(
-                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
-                    ),
-                    None,
-                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
-                    dropout_rate,
-                    normalize_before,
-                    concat_after,
-                ),
-            )
-
-        self.decoders3 = repeat(
-            1,
-            lambda lnum: DecoderLayerSANM(
-                attention_dim,
-                None,
-                None,
-                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
-        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-        self.chunk_multiply_factor = chunk_multiply_factor
-
-    def forward(
-        self,
-        hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-        chunk_mask: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Forward decoder.
-
-        Args:
-            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
-            hlens: (batch)
-            ys_in_pad:
-                decoder input embeddings, float32 (batch, maxlen_out, feat);
-                token ids are not embedded here, the embeddings produced by
-                the predictor are consumed directly
-            ys_in_lens: (batch)
-            chunk_mask: optional chunk-wise attention mask applied to the memory mask
-        Returns:
-            (tuple): tuple containing:
-
-            x: decoded token score before softmax (batch, maxlen_out, token)
-                if use_output_layer is True,
-            olens: (batch, )
-        """
-        tgt = ys_in_pad
-        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-        
-        memory = hs_pad
-        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-        if chunk_mask is not None:
-            memory_mask = memory_mask * chunk_mask
-            if tgt_mask.size(1) != memory_mask.size(1):
-                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
-
-        x = tgt
-        x, tgt_mask, memory, memory_mask, _ = self.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
-        if self.decoders2 is not None:
-            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
-                x, tgt_mask, memory, memory_mask
-            )
-        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
-            x, tgt_mask, memory, memory_mask
-        )
-        if self.normalize_before:
-            x = self.after_norm(x)
-        if self.output_layer is not None:
-            x = self.output_layer(x)
-
-        olens = tgt_mask.sum(1)
-        return x, olens
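-
-    # NAR usage sketch: this decoder consumes float embeddings directly (no
-    # token embedding is applied in forward), so ys_in_pad is typically the
-    # CIF predictor output. Illustrative only; sizes below are assumptions.
-    #
-    #   decoder = ParaformerSANMDecoder(vocab_size=4235, encoder_output_size=256)
-    #   hs_pad = torch.randn(2, 100, 256)
-    #   hlens = torch.tensor([100, 80])
-    #   acoustic_embeds = torch.randn(2, 20, 256)      # from the predictor
-    #   token_lens = torch.tensor([20, 15])
-    #   logits, olens = decoder(hs_pad, hlens, acoustic_embeds, token_lens)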
-
-    def score(self, ys, state, x):
-        """Score."""
-        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
-        logp, state = self.forward_one_step(
-            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
-        )
-        return logp.squeeze(0), state
-
-    def forward_chunk(
-        self,
-        memory: torch.Tensor,
-        tgt: torch.Tensor,
-        cache: dict = None,
-    ) -> torch.Tensor:
-        """Forward decoder for one streaming chunk.
-
-        Args:
-            memory: encoded memory chunk, float32 (batch, chunk_in, feat)
-            tgt: decoder input embeddings for the chunk, float32 (batch, chunk_out, feat)
-            cache: dict holding the streaming state; expected keys are
-                "decode_fsmn", "opt", "chunk_size" and "decoder_chunk_look_back"
-        Returns:
-            x: decoded token score before softmax (batch, chunk_out, token)
-                if use_output_layer is True
-        """
-        x = tgt
-        if cache["decode_fsmn"] is None:
-            cache_layer_num = len(self.decoders)
-            if self.decoders2 is not None:
-                cache_layer_num += len(self.decoders2)
-            fsmn_cache = [None] * cache_layer_num
-        else:
-            fsmn_cache = cache["decode_fsmn"]
-
-        if cache["opt"] is None:
-            cache_layer_num = len(self.decoders)
-            opt_cache = [None] * cache_layer_num
-        else:
-            opt_cache = cache["opt"]
-
-        for i in range(self.att_layer_num):
-            decoder = self.decoders[i]
-            x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
-                x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
-                chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
-            )
-
-        if self.num_blocks - self.att_layer_num >= 1:
-            for i in range(self.num_blocks - self.att_layer_num):
-                j = i + self.att_layer_num
-                decoder = self.decoders2[i]
-                x, memory, fsmn_cache[j], _  = decoder.forward_chunk(
-                    x, memory, fsmn_cache=fsmn_cache[j]
-                )
-
-        for decoder in self.decoders3:
-            x, memory, _, _ = decoder.forward_chunk(
-                x, memory
-            )
-        if self.normalize_before:
-            x = self.after_norm(x)
-        if self.output_layer is not None:
-            x = self.output_layer(x)
-
-        cache["decode_fsmn"] = fsmn_cache
-        if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
-            cache["opt"] = opt_cache
-        return x
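-
-    # Streaming cache sketch: forward_chunk reads and updates a dict with the
-    # keys below (values are illustrative; the real ones come from the model's
-    # streaming configuration):
-    #
-    #   cache = {
-    #       "decode_fsmn": None,                 # per-layer FSMN states, filled on first call
-    #       "opt": None,                         # per-layer cross-attention key/value cache
-    #       "chunk_size": chunk_size,
-    #       "decoder_chunk_look_back": 1,
-    #   }
-    #   logits = decoder.forward_chunk(memory_chunk, acoustic_embeds_chunk, cache=cache)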
-
-    def forward_one_step(
-        self,
-        tgt: torch.Tensor,
-        tgt_mask: torch.Tensor,
-        memory: torch.Tensor,
-        cache: List[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
-        """Forward one step.
-
-        Args:
-            tgt: input token ids, int64 (batch, maxlen_out)
-            tgt_mask: input token mask, (batch, maxlen_out)
-                      dtype=torch.uint8 in PyTorch 1.2-
-                      dtype=torch.bool in PyTorch 1.2+ (1.2 included)
-            memory: encoded memory, float32 (batch, maxlen_in, feat)
-            cache: cached output list of (batch, max_time_out-1, size)
-        Returns:
-            y, cache: NN output value and cache per `self.decoders`.
-            `y.shape` is (batch, token); only the last step is returned.
-        """
-        x = self.embed(tgt)
-        if cache is None:
-            cache_layer_num = len(self.decoders)
-            if self.decoders2 is not None:
-                cache_layer_num += len(self.decoders2)
-            cache = [None] * cache_layer_num
-        new_cache = []
-        # for c, decoder in zip(cache, self.decoders):
-        for i in range(self.att_layer_num):
-            decoder = self.decoders[i]
-            c = cache[i]
-            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
-                x, tgt_mask, memory, None, cache=c
-            )
-            new_cache.append(c_ret)
-
-        if self.num_blocks - self.att_layer_num >= 1:
-            for i in range(self.num_blocks - self.att_layer_num):
-                j = i + self.att_layer_num
-                decoder = self.decoders2[i]
-                c = cache[j]
-                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
-                    x, tgt_mask, memory, None, cache=c
-                )
-                new_cache.append(c_ret)
-
-        for decoder in self.decoders3:
-
-            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
-                x, tgt_mask, memory, None, cache=None
-            )
-
-        if self.normalize_before:
-            y = self.after_norm(x[:, -1])
-        else:
-            y = x[:, -1]
-        if self.output_layer is not None:
-            y = torch.log_softmax(self.output_layer(y), dim=-1)
-
-        return y, new_cache
-
-    def gen_tf2torch_map_dict(self):
-    
-        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
-        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
-        map_dict_local = {
-        
-            ## decoder
-            # ffn
-            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-        
-            # fsmn
-            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
-                    tensor_name_prefix_tf),
-                    "squeeze": None,
-                    "transpose": None,
-                },  # (256,),(256,)
-            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
-                    tensor_name_prefix_tf),
-                    "squeeze": None,
-                    "transpose": None,
-                },  # (256,),(256,)
-            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
-                    tensor_name_prefix_tf),
-                    "squeeze": 0,
-                    "transpose": (1, 2, 0),
-                },  # (256,1,31),(1,31,256,1)
-            # src att
-            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # dnn
-            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-        
-            # embed_concat_ffn
-            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-        
-            # out norm
-            "{}.after_norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.after_norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-        
-            # in embed
-            "{}.embed.0.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/w_embs".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (4235,256),(4235,256)
-        
-            # out layer
-            "{}.output_layer.weight".format(tensor_name_prefix_torch):
-                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
-                 "squeeze": [None, None],
-                 "transpose": [(1, 0), None],
-                 },  # (4235,256),(256,4235)
-            "{}.output_layer.bias".format(tensor_name_prefix_torch):
-                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
-                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
-                 "squeeze": [None, None],
-                 "transpose": [None, None],
-                 },  # (4235,),(4235,)
-        
-        }
-        return map_dict_local
-
-    def convert_tf2torch(self,
-                         var_dict_tf,
-                         var_dict_torch,
-                         ):
-        map_dict = self.gen_tf2torch_map_dict()
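-        # Each map entry pairs a torch parameter name with the TF variable it comes from,
-        # plus optional squeeze/transpose steps. Illustrative example (shapes taken from the
-        # comments in the map above): a TF conv1d kernel stored as (1, 256, 1024) is
-        # np.squeeze'd on axis 0 to (256, 1024) and np.transpose'd with (1, 0) to
-        # (1024, 256), matching the torch Linear weight layout.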
-        var_dict_torch_update = dict()
-        decoder_layeridx_sets = set()
-        for name in sorted(var_dict_torch.keys(), reverse=False):
-            names = name.split('.')
-            if names[0] == self.tf2torch_tensor_name_prefix_torch:
-                if names[1] == "decoders":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "decoders2":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                    name_q = name_q.replace("decoders2", "decoders")
-                    layeridx_bias = len(decoder_layeridx_sets)
-                
-                    layeridx += layeridx_bias
-                    if "decoders." in name:
-                        decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "decoders3":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    if "decoders." in name:
-                        decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "embed" or names[1] == "output_layer":
-                    name_tf = map_dict[name]["name"]
-                    if isinstance(name_tf, list):
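-                        # The map lists two candidate TF names for this tensor; use the
-                        # first one that actually exists in the TF checkpoint.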
-                        idx_list = 0
-                        if name_tf[idx_list] in var_dict_tf.keys():
-                            pass
-                        else:
-                            idx_list = 1
-                        data_tf = var_dict_tf[name_tf[idx_list]]
-                        if map_dict[name]["squeeze"][idx_list] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
-                        if map_dict[name]["transpose"][idx_list] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
-                                                                                                   name_tf[idx_list],
-                                                                                                   var_dict_tf[name_tf[
-                                                                                                       idx_list]].shape))
-                
-                    else:
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
-                        if map_dict[name]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "after_norm":
-                    name_tf = map_dict[name]["name"]
-                    data_tf = var_dict_tf[name_tf]
-                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                    var_dict_torch_update[name] = data_tf
-                    logging.info(
-                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                      var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "embed_concat_ffn":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    if "decoders." in name:
-                        decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-    
-        return var_dict_torch_update
diff --git a/funasr/models/sanm/sanm_encoder.py b/funasr/models/sanm/sanm_encoder.py
deleted file mode 100644
index 83edbe7..0000000
--- a/funasr/models/sanm/sanm_encoder.py
+++ /dev/null
@@ -1,1293 +0,0 @@
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-import logging
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from funasr.models.scama.chunk_utilis import overlap_chunk
-import numpy as np
-from funasr.train_utils.device_funcs import to_device
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.transformer.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
-from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
-from funasr.models.transformer.positionwise_feed_forward import (
-    PositionwiseFeedForward,  # noqa: H301
-)
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
-from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
-
-from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
-
-class EncoderLayerSANM(nn.Module):
-    def __init__(
-        self,
-        in_size,
-        size,
-        self_attn,
-        feed_forward,
-        dropout_rate,
-        normalize_before=True,
-        concat_after=False,
-        stochastic_depth_rate=0.0,
-    ):
-        """Construct an EncoderLayer object."""
-        super(EncoderLayerSANM, self).__init__()
-        self.self_attn = self_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(in_size)
-        self.norm2 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.in_size = in_size
-        self.size = size
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        if self.concat_after:
-            self.concat_linear = nn.Linear(size + size, size)
-        self.stochastic_depth_rate = stochastic_depth_rate
-        self.dropout_rate = dropout_rate
-
-    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
-        """Compute encoded features.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, size).
-            mask (torch.Tensor): Mask tensor for the input (#batch, time).
-            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-            mask_shfit_chunk (torch.Tensor): Optional chunk mask for the SANM self-attention.
-            mask_att_chunk_encoder (torch.Tensor): Optional chunk-wise attention mask.
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, size).
-            torch.Tensor: Mask tensor (#batch, time).
-            The cache and the two chunk masks are returned unchanged as well.
-
-        """
-        skip_layer = False
-        # with stochastic depth, residual connection `x + f(x)` becomes
-        # `x <- x + 1 / (1 - p) * f(x)` at training time.
-        stoch_layer_coeff = 1.0
-        if self.training and self.stochastic_depth_rate > 0:
-            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
-            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
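-        # For example, with stochastic_depth_rate = 0.1 the layer is skipped on roughly 10%
-        # of training steps, and surviving passes are scaled by 1 / (1 - 0.1) ≈ 1.11 so the
-        # expected layer output matches inference, where every layer always runs.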
-
-        if skip_layer:
-            if cache is not None:
-                x = torch.cat([cache, x], dim=1)
-            return x, mask
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm1(x)
-
-        if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
-            if self.in_size == self.size:
-                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
-            else:
-                x = stoch_layer_coeff * self.concat_linear(x_concat)
-        else:
-            if self.in_size == self.size:
-                x = residual + stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
-                )
-            else:
-                x = stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
-                )
-        if not self.normalize_before:
-            x = self.norm1(x)
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm2(x)
-        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm2(x)
-
-        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
-
-    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
-        """Compute encoded features.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, size).
-            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-            chunk_size (int): Streaming chunk size.
-            look_back (int): How much cached history the chunked attention may attend to.
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, size).
-            torch.Tensor: Updated cache.
-
-        """
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm1(x)
-
-        if self.in_size == self.size:
-            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
-            x = residual + attn
-        else:
-            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
-
-        if not self.normalize_before:
-            x = self.norm1(x)
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm2(x)
-        x = residual + self.feed_forward(x)
-        if not self.normalize_before:
-            x = self.norm2(x)
-
-        return x, cache
-
-
-class SANMEncoder(AbsEncoder):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    SAN-M: Memory equipped self-attention for end-to-end speech recognition
-    https://arxiv.org/abs/2006.01713
-
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        input_layer: Optional[str] = "conv2d",
-        pos_enc_class=SinusoidalPositionEncoder,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        positionwise_layer_type: str = "linear",
-        positionwise_conv_kernel_size: int = 1,
-        padding_idx: int = -1,
-        interctc_layer_idx: List[int] = [],
-        interctc_use_conditioning: bool = False,
-        kernel_size: int = 11,
-        sanm_shfit: int = 0,
-        lora_list: List[str] = None,
-        lora_rank: int = 8,
-        lora_alpha: int = 16,
-        lora_dropout: float = 0.1,
-        selfattention_layer_type: str = "sanm",
-        tf2torch_tensor_name_prefix_torch: str = "encoder",
-        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
-    ):
-        super().__init__()
-        self._output_size = output_size
-
-        if input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(input_size, output_size),
-                torch.nn.LayerNorm(output_size),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2d":
-            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d2":
-            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d6":
-            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d8":
-            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
-        elif input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
-                SinusoidalPositionEncoder(),
-            )
-        elif input_layer is None:
-            if input_size == output_size:
-                self.embed = None
-            else:
-                self.embed = torch.nn.Linear(input_size, output_size)
-        elif input_layer == "pe":
-            self.embed = SinusoidalPositionEncoder()
-        elif input_layer == "pe_online":
-            self.embed = StreamSinusoidalPositionEncoder()
-        else:
-            raise ValueError("unknown input_layer: " + input_layer)
-        self.normalize_before = normalize_before
-        if positionwise_layer_type == "linear":
-            positionwise_layer = PositionwiseFeedForward
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d":
-            positionwise_layer = MultiLayeredConv1d
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d-linear":
-            positionwise_layer = Conv1dLinear
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        else:
-            raise NotImplementedError("Support only linear or conv1d.")
-
-        if selfattention_layer_type == "selfattn":
-            encoder_selfattn_layer = MultiHeadedAttention
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                attention_dropout_rate,
-            )
-
-        elif selfattention_layer_type == "sanm":
-            encoder_selfattn_layer = MultiHeadedAttentionSANM
-            encoder_selfattn_layer_args0 = (
-                attention_heads,
-                input_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-                lora_list,
-                lora_rank,
-                lora_alpha,
-                lora_dropout,
-            )
-
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-                lora_list,
-                lora_rank,
-                lora_alpha,
-                lora_dropout,
-            )
-        self.encoders0 = repeat(
-            1,
-            lambda lnum: EncoderLayerSANM(
-                input_size,
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-
-        self.encoders = repeat(
-            num_blocks-1,
-            lambda lnum: EncoderLayerSANM(
-                output_size,
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-
-        self.interctc_layer_idx = interctc_layer_idx
-        if len(interctc_layer_idx) > 0:
-            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
-        self.interctc_use_conditioning = interctc_use_conditioning
-        self.conditioning_layer = None
-        self.dropout = nn.Dropout(dropout_rate)
-        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
-        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        prev_states: torch.Tensor = None,
-        ctc: CTC = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Embed positions in tensor.
-
-        Args:
-            xs_pad: input tensor (B, L, D)
-            ilens: input length (B)
-            prev_states: Not to be used now.
-        Returns:
-            encoder output tensor (B, L', D) (with any intermediate CTC outputs), output lengths (B), and None
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
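-        # Scale features by sqrt(output_size), the standard transformer embedding scaling.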
-        xs_pad = xs_pad * self.output_size()**0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        elif (
-            isinstance(self.embed, Conv2dSubsampling)
-            or isinstance(self.embed, Conv2dSubsampling2)
-            or isinstance(self.embed, Conv2dSubsampling6)
-            or isinstance(self.embed, Conv2dSubsampling8)
-        ):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
-                    + f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-
-        # xs_pad = self.dropout(xs_pad)
-        encoder_outs = self.encoders0(xs_pad, masks)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        intermediate_outs = []
-        if len(self.interctc_layer_idx) == 0:
-            encoder_outs = self.encoders(xs_pad, masks)
-            xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        else:
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                encoder_outs = encoder_layer(xs_pad, masks)
-                xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
-                if layer_idx + 1 in self.interctc_layer_idx:
-                    encoder_out = xs_pad
-
-                    # intermediate outputs are also normalized
-                    if self.normalize_before:
-                        encoder_out = self.after_norm(encoder_out)
-
-                    intermediate_outs.append((layer_idx + 1, encoder_out))
-
-                    if self.interctc_use_conditioning:
-                        ctc_out = ctc.softmax(encoder_out)
-                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        olens = masks.squeeze(1).sum(1)
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), olens, None
-        return xs_pad, olens, None
-
-    def _add_overlap_chunk(self, feats: torch.Tensor, cache: dict = {}):
-        if len(cache) == 0:
-            return feats
-        cache["feats"] = to_device(cache["feats"], device=feats.device)
-        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
-        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
-        return overlap_feats
-
-    def forward_chunk(self,
-                      xs_pad: torch.Tensor,
-                      ilens: torch.Tensor,
-                      cache: dict = None,
-                      ctc: CTC = None,
-                      ):
-        xs_pad *= self.output_size() ** 0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        else:
-            xs_pad = self.embed(xs_pad, cache)
-        if cache["tail_chunk"]:
-            xs_pad = to_device(cache["feats"], device=xs_pad.device)
-        else:
-            xs_pad = self._add_overlap_chunk(xs_pad, cache)
-        encoder_outs = self.encoders0(xs_pad, None, None, None, None)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        intermediate_outs = []
-        if len(self.interctc_layer_idx) == 0:
-            encoder_outs = self.encoders(xs_pad, None, None, None, None)
-            xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        else:
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                encoder_outs = encoder_layer(xs_pad, None, None, None, None)
-                xs_pad, masks = encoder_outs[0], encoder_outs[1]
-                if layer_idx + 1 in self.interctc_layer_idx:
-                    encoder_out = xs_pad
-
-                    # intermediate outputs are also normalized
-                    if self.normalize_before:
-                        encoder_out = self.after_norm(encoder_out)
-
-                    intermediate_outs.append((layer_idx + 1, encoder_out))
-
-                    if self.interctc_use_conditioning:
-                        ctc_out = ctc.softmax(encoder_out)
-                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), None, None
-        return xs_pad, ilens, None
-
-    def gen_tf2torch_map_dict(self):
-        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
-        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
-        map_dict_local = {
-            ## encoder
-            # cicd
-            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (768,256),(1,256,768)
-            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (768,),(768,)
-            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 2, 0),
-                 },  # (256,1,31),(1,31,256,1)
-            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # ffn
-            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # out norm
-            "{}.after_norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.after_norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-        
-        }
-    
-        return map_dict_local
-
-    def convert_tf2torch(self,
-                         var_dict_tf,
-                         var_dict_torch,
-                         ):
-        
-        map_dict = self.gen_tf2torch_map_dict()
-    
-        var_dict_torch_update = dict()
-        for name in sorted(var_dict_torch.keys(), reverse=False):
-            names = name.split('.')
-            if names[0] == self.tf2torch_tensor_name_prefix_torch:
-                if names[1] == "encoders0":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                
-                    name_q = name_q.replace("encoders0", "encoders")
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-                elif names[1] == "encoders":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                    layeridx_bias = 1
-                    layeridx += layeridx_bias
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "after_norm":
-                    name_tf = map_dict[name]["name"]
-                    data_tf = var_dict_tf[name_tf]
-                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                    var_dict_torch_update[name] = data_tf
-                    logging.info(
-                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                      var_dict_tf[name_tf].shape))
-    
-        return var_dict_torch_update
-
-
-class SANMEncoderChunkOpt(AbsEncoder):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
-    https://arxiv.org/abs/2006.01713
-
-    """
-
-    def __init__(
-            self,
-            input_size: int,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.0,
-            input_layer: Optional[str] = "conv2d",
-            pos_enc_class=SinusoidalPositionEncoder,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            positionwise_layer_type: str = "linear",
-            positionwise_conv_kernel_size: int = 1,
-            padding_idx: int = -1,
-            interctc_layer_idx: List[int] = [],
-            interctc_use_conditioning: bool = False,
-            kernel_size: int = 11,
-            sanm_shfit: int = 0,
-            selfattention_layer_type: str = "sanm",
-            chunk_size: Union[int, Sequence[int]] = (16,),
-            stride: Union[int, Sequence[int]] = (10,),
-            pad_left: Union[int, Sequence[int]] = (0,),
-            encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
-            decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
-            tf2torch_tensor_name_prefix_torch: str = "encoder",
-            tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
-    ):
-        super().__init__()
-        self._output_size = output_size
-
-        if input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(input_size, output_size),
-                torch.nn.LayerNorm(output_size),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2d":
-            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d2":
-            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d6":
-            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d8":
-            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
-        elif input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer is None:
-            if input_size == output_size:
-                self.embed = None
-            else:
-                self.embed = torch.nn.Linear(input_size, output_size)
-        elif input_layer == "pe":
-            self.embed = SinusoidalPositionEncoder()
-        elif input_layer == "pe_online":
-            self.embed = StreamSinusoidalPositionEncoder()
-        else:
-            raise ValueError("unknown input_layer: " + input_layer)
-        self.normalize_before = normalize_before
-        if positionwise_layer_type == "linear":
-            positionwise_layer = PositionwiseFeedForward
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d":
-            positionwise_layer = MultiLayeredConv1d
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d-linear":
-            positionwise_layer = Conv1dLinear
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        else:
-            raise NotImplementedError("Support only linear or conv1d.")
-
-        if selfattention_layer_type == "selfattn":
-            encoder_selfattn_layer = MultiHeadedAttention
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                attention_dropout_rate,
-            )
-        elif selfattention_layer_type == "sanm":
-            encoder_selfattn_layer = MultiHeadedAttentionSANM
-            encoder_selfattn_layer_args0 = (
-                attention_heads,
-                input_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-        self.encoders0 = repeat(
-            1,
-            lambda lnum: EncoderLayerSANM(
-                input_size,
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-
-        self.encoders = repeat(
-            num_blocks - 1,
-            lambda lnum: EncoderLayerSANM(
-                output_size,
-                output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-
-        self.interctc_layer_idx = interctc_layer_idx
-        if len(interctc_layer_idx) > 0:
-            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
-        self.interctc_use_conditioning = interctc_use_conditioning
-        self.conditioning_layer = None
-        shfit_fsmn = (kernel_size - 1) // 2
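-        # e.g. kernel_size = 11 gives shfit_fsmn = 5, half the FSMN kernel; overlap_chunk
-        # presumably reserves this many frames of FSMN context when building chunks.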
-        self.overlap_chunk_cls = overlap_chunk(
-            chunk_size=chunk_size,
-            stride=stride,
-            pad_left=pad_left,
-            shfit_fsmn=shfit_fsmn,
-            encoder_att_look_back_factor=encoder_att_look_back_factor,
-            decoder_att_look_back_factor=decoder_att_look_back_factor,
-        )
-        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
-        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-            prev_states: torch.Tensor = None,
-            ctc: CTC = None,
-            ind: int = 0,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Embed positions in tensor.
-
-        Args:
-            xs_pad: input tensor (B, L, D)
-            ilens: input length (B)
-            prev_states: Not to be used now.
-        Returns:
-            encoder output tensor (B, L', D) (with any intermediate CTC outputs), output lengths (B), and None
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad *= self.output_size() ** 0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        elif (
-                isinstance(self.embed, Conv2dSubsampling)
-                or isinstance(self.embed, Conv2dSubsampling2)
-                or isinstance(self.embed, Conv2dSubsampling6)
-                or isinstance(self.embed, Conv2dSubsampling8)
-        ):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
-                    + f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-
-        mask_shfit_chunk, mask_att_chunk_encoder = None, None
-        if self.overlap_chunk_cls is not None:
-            ilens = masks.squeeze(1).sum(1)
-            chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
-            xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
-            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
-                                                                           dtype=xs_pad.dtype)
-            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
-                                                                                       xs_pad.size(0),
-                                                                                       dtype=xs_pad.dtype)
-
-        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        intermediate_outs = []
-        if len(self.interctc_layer_idx) == 0:
-            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
-            xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        else:
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
-                xs_pad, masks = encoder_outs[0], encoder_outs[1]
-                if layer_idx + 1 in self.interctc_layer_idx:
-                    encoder_out = xs_pad
-
-                    # intermediate outputs are also normalized
-                    if self.normalize_before:
-                        encoder_out = self.after_norm(encoder_out)
-
-                    intermediate_outs.append((layer_idx + 1, encoder_out))
-
-                    if self.interctc_use_conditioning:
-                        ctc_out = ctc.softmax(encoder_out)
-                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        olens = masks.squeeze(1).sum(1)
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), olens, None
-        return xs_pad, olens, None
-
-    def _add_overlap_chunk(self, feats: torch.Tensor, cache: dict = {}):
-        if len(cache) == 0:
-            return feats
-        cache["feats"] = to_device(cache["feats"], device=feats.device)
-        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
-        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
-        return overlap_feats
-
-    def forward_chunk(self,
-                      xs_pad: torch.Tensor,
-                      ilens: torch.Tensor,
-                      cache: dict = None,
-                      ):
-        xs_pad *= self.output_size() ** 0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        else:
-            xs_pad = self.embed(xs_pad, cache)
-        if cache["tail_chunk"]:
-            xs_pad = to_device(cache["feats"], device=xs_pad.device)
-        else:
-            xs_pad = self._add_overlap_chunk(xs_pad, cache)
-        if cache["opt"] is None:
-            cache_layer_num = len(self.encoders0) + len(self.encoders)
-            new_cache = [None] * cache_layer_num
-        else:
-            new_cache = cache["opt"]
-
-        for layer_idx, encoder_layer in enumerate(self.encoders0):
-            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
-            xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
-
-        for layer_idx, encoder_layer in enumerate(self.encoders):
-            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
-            xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-        if cache["encoder_chunk_look_back"] > 0 or cache["encoder_chunk_look_back"] == -1:
-            cache["opt"] = new_cache
-
-        return xs_pad, ilens, None
-
-    def gen_tf2torch_map_dict(self):
-        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
-        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
-        map_dict_local = {
-            ## encoder
-            # cicd
-            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (768,256),(1,256,768)
-            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (768,),(768,)
-            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 2, 0),
-                 },  # (256,1,31),(1,31,256,1)
-            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # ffn
-            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # out norm
-            "{}.after_norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.after_norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-        
-        }
-    
-        return map_dict_local
-
-    def convert_tf2torch(self,
-                         var_dict_tf,
-                         var_dict_torch,
-                         ):
-    
-        map_dict = self.gen_tf2torch_map_dict()
-    
-        var_dict_torch_update = dict()
-        for name in sorted(var_dict_torch.keys(), reverse=False):
-            names = name.split('.')
-            if names[0] == self.tf2torch_tensor_name_prefix_torch:
-                if names[1] == "encoders0":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                
-                    name_q = name_q.replace("encoders0", "encoders")
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-                elif names[1] == "encoders":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                    layeridx_bias = 1
-                    layeridx += layeridx_bias
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "after_norm":
-                    name_tf = map_dict[name]["name"]
-                    data_tf = var_dict_tf[name_tf]
-                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                    var_dict_torch_update[name] = data_tf
-                    logging.info(
-                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                      var_dict_tf[name_tf].shape))
-    
-        return var_dict_torch_update
-
-
-class SANMVadEncoder(AbsEncoder):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        input_layer: Optional[str] = "conv2d",
-        pos_enc_class=SinusoidalPositionEncoder,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        positionwise_layer_type: str = "linear",
-        positionwise_conv_kernel_size: int = 1,
-        padding_idx: int = -1,
-        interctc_layer_idx: List[int] = [],
-        interctc_use_conditioning: bool = False,
-        kernel_size : int = 11,
-        sanm_shfit : int = 0,
-        selfattention_layer_type: str = "sanm",
-    ):
-        super().__init__()
-        self._output_size = output_size
-
-        if input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(input_size, output_size),
-                torch.nn.LayerNorm(output_size),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2d":
-            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d2":
-            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d6":
-            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d8":
-            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
-        elif input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
-                SinusoidalPositionEncoder(),
-            )
-        elif input_layer is None:
-            if input_size == output_size:
-                self.embed = None
-            else:
-                self.embed = torch.nn.Linear(input_size, output_size)
-        elif input_layer == "pe":
-            self.embed = SinusoidalPositionEncoder()
-        else:
-            raise ValueError("unknown input_layer: " + input_layer)
-        self.normalize_before = normalize_before
-        if positionwise_layer_type == "linear":
-            positionwise_layer = PositionwiseFeedForward
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d":
-            positionwise_layer = MultiLayeredConv1d
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d-linear":
-            positionwise_layer = Conv1dLinear
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        else:
-            raise NotImplementedError("Support only linear or conv1d.")
-
-        if selfattention_layer_type == "selfattn":
-            encoder_selfattn_layer = MultiHeadedAttention
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                attention_dropout_rate,
-            )
-
-        elif selfattention_layer_type == "sanm":
-            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
-            encoder_selfattn_layer_args0 = (
-                attention_heads,
-                input_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-
-        self.encoders0 = repeat(
-            1,
-            lambda lnum: EncoderLayerSANM(
-                input_size,
-                output_size,
-                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-
-        self.encoders = repeat(
-            num_blocks-1,
-            lambda lnum: EncoderLayerSANM(
-                output_size,
-                output_size,
-                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-
-        self.interctc_layer_idx = interctc_layer_idx
-        if len(interctc_layer_idx) > 0:
-            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
-        self.interctc_use_conditioning = interctc_use_conditioning
-        self.conditioning_layer = None
-        self.dropout = nn.Dropout(dropout_rate)
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        vad_indexes: torch.Tensor,
-        prev_states: torch.Tensor = None,
-        ctc: CTC = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Embed positions in tensor.
-
-        Args:
-            xs_pad: input tensor (B, L, D)
-            ilens: input length (B)
-            prev_states: Not to be used now.
-        Returns:
-            position embedded tensor and mask
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
-        no_future_masks = masks & sub_masks
-        xs_pad *= self.output_size()**0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
-              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
-                    f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-
-        # xs_pad = self.dropout(xs_pad)
-        mask_tup0 = [masks, no_future_masks]
-        encoder_outs = self.encoders0(xs_pad, mask_tup0)
-        xs_pad, _ = encoder_outs[0], encoder_outs[1]
-        intermediate_outs = []
-
-
-        for layer_idx, encoder_layer in enumerate(self.encoders):
-                if layer_idx + 1 == len(self.encoders):
-                    # This is last layer.
-                    coner_mask = torch.ones(masks.size(0),
-                                            masks.size(-1),
-                                            masks.size(-1),
-                                            device=xs_pad.device,
-                                            dtype=torch.bool)
-                    for word_index, length in enumerate(ilens):
-                        coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
-                                                                vad_indexes[word_index],
-                                                                device=xs_pad.device)
-                    layer_mask = masks & coner_mask
-                else:
-                    layer_mask = no_future_masks
-                mask_tup1 = [masks, layer_mask]
-                encoder_outs = encoder_layer(xs_pad, mask_tup1)
-                xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        olens = masks.squeeze(1).sum(1)
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), olens, None
-        return xs_pad, olens, None
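The gen_tf2torch_map_dict/convert_tf2torch pair (dropped with the old encoder here, and still present in the SCAMA decoder below) maps each TensorFlow checkpoint tensor name to a PyTorch parameter name and records an optional squeeze axis and transpose order. A minimal sketch of how a single map entry would be applied, using a made-up tensor name and the (1, 256, 768) -> (768, 256) conv1d-kernel case annotated in the map:

    import numpy as np
    import torch

    # Hypothetical entry: TF conv1d kernel (1, 256, 768) -> torch linear weight (768, 256).
    entry = {"name": "encoder/layer_0/multi_head/conv1d/kernel",
             "squeeze": 0, "transpose": (1, 0)}
    var_dict_tf = {entry["name"]: np.zeros((1, 256, 768), dtype=np.float32)}

    data = var_dict_tf[entry["name"]]
    if entry["squeeze"] is not None:
        data = np.squeeze(data, axis=entry["squeeze"])   # (256, 768): drop the kernel-width axis
    if entry["transpose"] is not None:
        data = np.transpose(data, entry["transpose"])    # (768, 256): TF (in, out) -> torch (out, in)
    weight = torch.from_numpy(data).float()
    assert weight.shape == (768, 256)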
diff --git a/funasr/models/paraformer/contextual_decoder.py b/funasr/models/scama/sanm_decoder.py
similarity index 70%
copy from funasr/models/paraformer/contextual_decoder.py
copy to funasr/models/scama/sanm_decoder.py
index 626cdef..53423d0 100644
--- a/funasr/models/paraformer/contextual_decoder.py
+++ b/funasr/models/scama/sanm_decoder.py
@@ -6,17 +6,38 @@
 import numpy as np
 
 from funasr.models.scama import utils as myutils
-from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
+from funasr.models.transformer.decoder import BaseTransformerDecoder
 
-from funasr.models.transformer.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
 from funasr.models.transformer.embedding import PositionalEncoding
 from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.transformer.repeat import repeat
-from funasr.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.transformer.utils.repeat import repeat
+
+from funasr.utils.register import register_class, registry_tables
+
+class DecoderLayerSANM(nn.Module):
+    """Single decoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Source-attention (cross-attention) module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
 
 
-class ContextualDecoderLayer(nn.Module):
+    """
+
     def __init__(
         self,
         size,
@@ -28,7 +49,7 @@
         concat_after=False,
     ):
         """Construct an DecoderLayer object."""
-        super(ContextualDecoderLayer, self).__init__()
+        super(DecoderLayerSANM, self).__init__()
         self.size = size
         self.self_attn = self_attn
         self.src_attn = src_attn
@@ -45,85 +66,161 @@
             self.concat_linear1 = nn.Linear(size + size, size)
             self.concat_linear2 = nn.Linear(size + size, size)
 
-    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
+    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+            List[torch.Tensor]: Cache (passed through unchanged).
+
+        """
         # tgt = self.dropout(tgt)
-        if isinstance(tgt, Tuple):
-            tgt, _ = tgt
         residual = tgt
         if self.normalize_before:
             tgt = self.norm1(tgt)
         tgt = self.feed_forward(tgt)
 
         x = tgt
-        if self.normalize_before:
-            tgt = self.norm2(tgt)
-        if self.training:
-            cache = None
-        x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
-        x = residual + self.dropout(x)
-        x_self_attn = x
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, _ = self.self_attn(tgt, tgt_mask)
+            x = residual + self.dropout(x)
 
-        residual = x
-        if self.normalize_before:
-            x = self.norm3(x)
-        x = self.src_attn(x, memory, memory_mask)
-        x_src_attn = x
-
-        x = residual + self.dropout(x)
-        return x, tgt_mask, x_self_attn, x_src_attn
-
-
-class ContextualBiasDecoder(nn.Module):
-    def __init__(
-        self,
-        size,
-        src_attn,
-        dropout_rate,
-        normalize_before=True,
-    ):
-        """Construct an DecoderLayer object."""
-        super(ContextualBiasDecoder, self).__init__()
-        self.size = size
-        self.src_attn = src_attn
-        if src_attn is not None:
-            self.norm3 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.normalize_before = normalize_before
-
-    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
-        x = tgt
         if self.src_attn is not None:
+            residual = x
             if self.normalize_before:
                 x = self.norm3(x)
-            x =  self.dropout(self.src_attn(x, memory, memory_mask))
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
         return x, tgt_mask, memory, memory_mask, cache
 
+    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
 
-class ContextualParaformerDecoder(ParaformerSANMDecoder):
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+            List[torch.Tensor]: Updated self-attention cache.
+
+        """
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            if self.training:
+                cache = None
+            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            fsmn_cache: Cache of the FSMN self-attention block.
+            opt_cache: Cache of the chunk-wise cross-attention block.
+            chunk_size (int): Streaming chunk size.
+            look_back (int): Number of look-back chunks for cross-attention.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, maxlen_out, size).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            Updated FSMN cache.
+            Updated cross-attention cache.
+
+        """
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
+            x = residual + x
+
+        return x, memory, fsmn_cache, opt_cache
+
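Each DecoderLayerSANM above chains the same pre-norm residual step up to three times (feed-forward, FSMN self-attention, cross-attention), skipping a sub-block when its module is None. A minimal sketch of that residual pattern in isolation, with a generic sublayer standing in for the real attention/FFN modules:

    import torch
    import torch.nn as nn

    class PreNormResidual(nn.Module):
        # x + dropout(sublayer(LayerNorm(x))) -- the per-sub-block pattern used above.
        def __init__(self, size, sublayer, dropout_rate=0.1):
            super().__init__()
            self.norm = nn.LayerNorm(size)
            self.sublayer = sublayer
            self.dropout = nn.Dropout(dropout_rate)

        def forward(self, x):
            return x + self.dropout(self.sublayer(self.norm(x)))

    block = PreNormResidual(256, nn.Linear(256, 256))
    y = block(torch.randn(2, 10, 256))   # shape (batch, length, size) is preserved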
+@register_class("decoder_classes", "FsmnDecoderSCAMAOpt")
+class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
-    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
+
     """
     def __init__(
-        self,
-        vocab_size: int,
-        encoder_output_size: int,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        self_attention_dropout_rate: float = 0.0,
-        src_attention_dropout_rate: float = 0.0,
-        input_layer: str = "embed",
-        use_output_layer: bool = True,
-        pos_enc_class=PositionalEncoding,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        att_layer_num: int = 6,
-        kernel_size: int = 21,
-        sanm_shfit: int = 0,
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            self_attention_dropout_rate: float = 0.0,
+            src_attention_dropout_rate: float = 0.0,
+            input_layer: str = "embed",
+            use_output_layer: bool = True,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            att_layer_num: int = 6,
+            kernel_size: int = 21,
+            sanm_shfit: int = None,
+            concat_embeds: bool = False,
+            attention_dim: int = None,
+            tf2torch_tensor_name_prefix_torch: str = "decoder",
+            tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+            embed_tensor_name_prefix_tf: str = None,
     ):
         super().__init__(
             vocab_size=vocab_size,
@@ -135,14 +232,12 @@
             pos_enc_class=pos_enc_class,
             normalize_before=normalize_before,
         )
+        if attention_dim is None:
+            attention_dim = encoder_output_size
 
-        attention_dim = encoder_output_size
-        if input_layer == 'none':
-            self.embed = None
         if input_layer == "embed":
             self.embed = torch.nn.Sequential(
                 torch.nn.Embedding(vocab_size, attention_dim),
-                # pos_enc_class(attention_dim, positional_dropout_rate),
             )
         elif input_layer == "linear":
             self.embed = torch.nn.Sequential(
@@ -168,14 +263,14 @@
         if sanm_shfit is None:
             sanm_shfit = (kernel_size - 1) // 2
         self.decoders = repeat(
-            att_layer_num - 1,
+            att_layer_num,
             lambda lnum: DecoderLayerSANM(
                 attention_dim,
                 MultiHeadedAttentionSANMDecoder(
                     attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                 ),
                 MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate
+                    attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
                 ),
                 PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
@@ -183,29 +278,6 @@
                 concat_after,
             ),
         )
-        self.dropout = nn.Dropout(dropout_rate)
-        self.bias_decoder = ContextualBiasDecoder(
-            size=attention_dim,
-            src_attn=MultiHeadedAttentionCrossAtt(
-                attention_heads, attention_dim, src_attention_dropout_rate
-            ),
-            dropout_rate=dropout_rate,
-            normalize_before=True,
-        )
-        self.bias_output = torch.nn.Conv1d(attention_dim*2, attention_dim, 1, bias=False)
-        self.last_decoder = ContextualDecoderLayer(
-                attention_dim,
-                MultiHeadedAttentionSANMDecoder(
-                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
-                ),
-                MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
-                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            )
         if num_blocks - att_layer_num <= 0:
             self.decoders2 = None
         else:
@@ -214,7 +286,7 @@
                 lambda lnum: DecoderLayerSANM(
                     attention_dim,
                     MultiHeadedAttentionSANMDecoder(
-                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                     ),
                     None,
                     PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -236,16 +308,35 @@
                 concat_after,
             ),
         )
+        if concat_embeds:
+            self.embed_concat_ffn = repeat(
+                1,
+                lambda lnum: DecoderLayerSANM(
+                    attention_dim + encoder_output_size,
+                    None,
+                    None,
+                    PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
+                                                      adim=attention_dim),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                ),
+            )
+        else:
+            self.embed_concat_ffn = None
+        self.concat_embeds = concat_embeds
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+        self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
 
     def forward(
-        self,
-        hs_pad: torch.Tensor,
-        hlens: torch.Tensor,
-        ys_in_pad: torch.Tensor,
-        ys_in_lens: torch.Tensor,
-        contextual_info: torch.Tensor,
-        clas_scale: float = 1.0,
-        return_hidden: bool = False,
+            self,
+            hs_pad: torch.Tensor,
+            hlens: torch.Tensor,
+            ys_in_pad: torch.Tensor,
+            ys_in_lens: torch.Tensor,
+            chunk_mask: torch.Tensor = None,
+            pre_acoustic_embeds: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
 
@@ -269,46 +360,122 @@
 
         memory = hs_pad
         memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        if chunk_mask is not None:
+            memory_mask = memory_mask * chunk_mask
+            if tgt_mask.size(1) != memory_mask.size(1):
+                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
 
-        x = tgt
+        x = self.embed(tgt)
+
+        if pre_acoustic_embeds is not None and self.concat_embeds:
+            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+
         x, tgt_mask, memory, memory_mask, _ = self.decoders(
             x, tgt_mask, memory, memory_mask
         )
-        _, _, x_self_attn, x_src_attn = self.last_decoder(
-            x, tgt_mask, memory, memory_mask
-        )
-
-        # contextual paraformer related
-        contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
-        contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
-        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
-
-        if self.bias_output is not None:
-            x = torch.cat([x_src_attn, cx*clas_scale], dim=2)
-            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
-            x = x_self_attn + self.dropout(x)
-
         if self.decoders2 is not None:
             x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                 x, tgt_mask, memory, memory_mask
             )
-
         x, tgt_mask, memory, memory_mask, _ = self.decoders3(
             x, tgt_mask, memory, memory_mask
         )
         if self.normalize_before:
             x = self.after_norm(x)
-        olens = tgt_mask.sum(1)
-        if self.output_layer is not None and return_hidden is False:
+        if self.output_layer is not None:
             x = self.output_layer(x)
+
+        olens = tgt_mask.sum(1)
         return x, olens
 
-    def gen_tf2torch_map_dict(self):
+    def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
+        """Score."""
+        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+        logp, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
+            cache=state
+        )
+        return logp.squeeze(0), state
 
+    def forward_one_step(
+            self,
+            tgt: torch.Tensor,
+            tgt_mask: torch.Tensor,
+            memory: torch.Tensor,
+            memory_mask: torch.Tensor = None,
+            pre_acoustic_embeds: torch.Tensor = None,
+            cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward one step.
+
+        Args:
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask,  (batch, maxlen_out)
+                      dtype=torch.uint8 in PyTorch 1.2-
+                      dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+            memory: encoded memory, float32  (batch, maxlen_in, feat)
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+            `y.shape` is (batch, maxlen_out, token)
+        """
+
+        x = tgt[:, -1:]
+        tgt_mask = None
+        x = self.embed(x)
+
+        if pre_acoustic_embeds is not None and self.concat_embeds:
+            x = torch.cat((x, pre_acoustic_embeds), dim=-1)
+            x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
+
+        if cache is None:
+            cache_layer_num = len(self.decoders)
+            if self.decoders2 is not None:
+                cache_layer_num += len(self.decoders2)
+            cache = [None] * cache_layer_num
+        new_cache = []
+        # for c, decoder in zip(cache, self.decoders):
+        for i in range(self.att_layer_num):
+            decoder = self.decoders[i]
+            c = cache[i]
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                x, tgt_mask, memory, memory_mask, cache=c
+            )
+            new_cache.append(c_ret)
+
+        if self.num_blocks - self.att_layer_num >= 1:
+            for i in range(self.num_blocks - self.att_layer_num):
+                j = i + self.att_layer_num
+                decoder = self.decoders2[i]
+                c = cache[j]
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                    x, tgt_mask, memory, memory_mask, cache=c
+                )
+                new_cache.append(c_ret)
+
+        for decoder in self.decoders3:
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
+                x, tgt_mask, memory, None, cache=None
+            )
+
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.output_layer is not None:
+            y = self.output_layer(y)
+            y = torch.log_softmax(y, dim=-1)
+
+        return y, new_cache
+
+    def gen_tf2torch_map_dict(self):
+    
         tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
         tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
+        embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
         map_dict_local = {
-
+        
             ## decoder
             # ffn
             "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
@@ -346,7 +513,7 @@
                  "squeeze": 0,
                  "transpose": (1, 0),
                  },  # (256,1024),(1,1024,256)
-
+        
             # fsmn
             "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                 {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
@@ -443,7 +610,7 @@
                  "squeeze": 0,
                  "transpose": (1, 0),
                  },  # (256,1024),(1,1024,256)
-
+        
             # embed_concat_ffn
             "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                 {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
@@ -480,7 +647,7 @@
                  "squeeze": 0,
                  "transpose": (1, 0),
                  },  # (256,1024),(1,1024,256)
-
+        
             # out norm
             "{}.after_norm.weight".format(tensor_name_prefix_torch):
                 {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
@@ -492,17 +659,18 @@
                  "squeeze": None,
                  "transpose": None,
                  },  # (256,),(256,)
-
+        
             # in embed
             "{}.embed.0.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/w_embs".format(tensor_name_prefix_tf),
+                {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
                  "squeeze": None,
                  "transpose": None,
                  },  # (4235,256),(4235,256)
-
+        
             # out layer
             "{}.output_layer.weight".format(tensor_name_prefix_torch):
-                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
+                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
+                          "{}/w_embs".format(embed_tensor_name_prefix_tf)],
                  "squeeze": [None, None],
                  "transpose": [(1, 0), None],
                  },  # (4235,256),(256,4235)
@@ -512,56 +680,7 @@
                  "squeeze": [None, None],
                  "transpose": [None, None],
                  },  # (4235,),(4235,)
-
-            ## clas decoder
-            # src att
-            "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # dnn
-            "{}.bias_output.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": (2, 1, 0),
-                 },  # (1024,256),(1,256,1024)
-
+        
         }
         return map_dict_local
 
@@ -569,6 +688,7 @@
                          var_dict_tf,
                          var_dict_torch,
                          ):
+    
         map_dict = self.gen_tf2torch_map_dict()
         var_dict_torch_update = dict()
         decoder_layeridx_sets = set()
@@ -598,37 +718,13 @@
                         logging.info(
                             "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                           var_dict_tf[name_tf].shape))
-                elif names[1] == "last_decoder":
-                    layeridx = 15
-                    name_q = name.replace("last_decoder", "decoders.layeridx")
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    decoder_layeridx_sets.add(layeridx)
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-
-
+            
                 elif names[1] == "decoders2":
                     layeridx = int(names[2])
                     name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                     name_q = name_q.replace("decoders2", "decoders")
                     layeridx_bias = len(decoder_layeridx_sets)
-
+                
                     layeridx += layeridx_bias
                     if "decoders." in name:
                         decoder_layeridx_sets.add(layeridx)
@@ -649,11 +745,11 @@
                         logging.info(
                             "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                           var_dict_tf[name_tf].shape))
-
+            
                 elif names[1] == "decoders3":
                     layeridx = int(names[2])
                     name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
+                
                     layeridx_bias = 0
                     layeridx += layeridx_bias
                     if "decoders." in name:
@@ -675,29 +771,8 @@
                         logging.info(
                             "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                           var_dict_tf[name_tf].shape))
-                elif names[1] == "bias_decoder":
-                    name_q = name
-
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-
-
-                elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
+            
+                elif names[1] == "embed" or names[1] == "output_layer":
                     name_tf = map_dict[name]["name"]
                     if isinstance(name_tf, list):
                         idx_list = 0
@@ -720,7 +795,7 @@
                                                                                                    name_tf[idx_list],
                                                                                                    var_dict_tf[name_tf[
                                                                                                        idx_list]].shape))
-
+                
                     else:
                         data_tf = var_dict_tf[name_tf]
                         if map_dict[name]["squeeze"] is not None:
@@ -736,7 +811,7 @@
                         logging.info(
                             "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                           var_dict_tf[name_tf].shape))
-
+            
                 elif names[1] == "after_norm":
                     name_tf = map_dict[name]["name"]
                     data_tf = var_dict_tf[name_tf]
@@ -745,11 +820,11 @@
                     logging.info(
                         "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                       var_dict_tf[name_tf].shape))
-
+            
                 elif names[1] == "embed_concat_ffn":
                     layeridx = int(names[2])
                     name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
+                
                     layeridx_bias = 0
                     layeridx += layeridx_bias
                     if "decoders." in name:
@@ -771,5 +846,6 @@
                         logging.info(
                             "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
                                                                                           var_dict_tf[name_tf].shape))
-
+    
         return var_dict_torch_update
+
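Both the FsmnDecoderSCAMAOpt above and the chunk-optimized encoder added below restrict attention to the current chunk plus a configurable number of look-back chunks (chunk_size, encoder_att_look_back_factor, and related arguments). A simplified sketch of such a chunk-aware attention mask; the real masks are built by overlap_chunk in funasr/models/scama/chunk_utilis.py and differ in detail:

    import torch

    def chunk_attention_mask(length, chunk_size, look_back=1):
        # Frame i may attend to frames in its own chunk and up to `look_back` previous chunks.
        chunk_id = torch.arange(length) // chunk_size
        q = chunk_id[:, None]
        k = chunk_id[None, :]
        return (k <= q) & (k >= q - look_back)   # (length, length) boolean mask

    mask = chunk_attention_mask(8, chunk_size=4, look_back=1)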
diff --git a/funasr/models/scama/sanm_encoder.py b/funasr/models/scama/sanm_encoder.py
new file mode 100644
index 0000000..c89bfb3
--- /dev/null
+++ b/funasr/models/scama/sanm_encoder.py
@@ -0,0 +1,613 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.scama.chunk_utilis import overlap_chunk
+import numpy as np
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.positionwise_feed_forward import (
+    PositionwiseFeedForward,  # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
+
+from funasr.models.ctc.ctc import CTC
+
+from funasr.utils.register import register_class, registry_tables
+
+class EncoderLayerSANM(nn.Module):
+    def __init__(
+        self,
+        in_size,
+        size,
+        self_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        stochastic_depth_rate=0.0,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayerSANM, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(in_size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.in_size = in_size
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+        self.stochastic_depth_rate = stochastic_depth_rate
+        self.dropout_rate = dropout_rate
+
+    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        """Compute encoded features.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+            mask_shfit_chunk (torch.Tensor): Optional chunk-shift mask for streaming.
+            mask_att_chunk_encoder (torch.Tensor): Optional chunk-wise attention mask.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+            The cache and the two chunk masks are passed through unchanged.
+
+        """
+        skip_layer = False
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        stoch_layer_coeff = 1.0
+        if self.training and self.stochastic_depth_rate > 0:
+            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+        if skip_layer:
+            if cache is not None:
+                x = torch.cat([cache, x], dim=1)
+            return x, mask
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+            else:
+                x = stoch_layer_coeff * self.concat_linear(x_concat)
+        else:
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+            else:
+                x = stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
+    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+        """Compute encoded features.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+            cache (torch.Tensor): Attention cache carried over from the previous chunk.
+            chunk_size (int): Streaming chunk size.
+            look_back (int): Number of look-back chunks.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Updated attention cache.
+
+        """
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.in_size == self.size:
+            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+            x = residual + attn
+        else:
+            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.feed_forward(x)
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        return x, cache
+
+
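The stochastic-depth branch in EncoderLayerSANM.forward follows the standard rescaled-residual recipe: at training time a layer is skipped with probability p, and when it does run its contribution is scaled by 1/(1-p), so the expected output matches the deterministic x + f(x) used at evaluation. A minimal sketch of just that mechanism, with f standing in for the attention or feed-forward branch:

    import torch

    def stochastic_depth_residual(x, f, p, training):
        # Train: drop the branch with probability p, otherwise rescale it by 1/(1-p);
        # eval: plain residual x + f(x).
        if training and p > 0.0:
            if torch.rand(1).item() < p:
                return x
            return x + f(x) / (1.0 - p)
        return x + f(x)

    y = stochastic_depth_residual(torch.randn(2, 5, 8), lambda t: 0.1 * t, p=0.3, training=True)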
+@register_class("encoder_classes", "SANMEncoderChunkOpt")
+class SANMEncoderChunkOpt(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
+    https://arxiv.org/abs/2006.01713
+
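+    A minimal construction sketch (the argument values are illustrative
+    assumptions, not taken from any released configuration)::
+
+        encoder = SANMEncoderChunkOpt(
+            input_size=560, output_size=512, attention_heads=4,
+            linear_units=2048, num_blocks=6, input_layer="pe",
+            chunk_size=(16,), stride=(10,), pad_left=(0,),
+        )
+        assert encoder.output_size() == 512
+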
+    """
+
+    def __init__(
+            self,
+            input_size: int,
+            output_size: int = 256,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            attention_dropout_rate: float = 0.0,
+            input_layer: Optional[str] = "conv2d",
+            pos_enc_class=SinusoidalPositionEncoder,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            positionwise_layer_type: str = "linear",
+            positionwise_conv_kernel_size: int = 1,
+            padding_idx: int = -1,
+            interctc_layer_idx: List[int] = [],
+            interctc_use_conditioning: bool = False,
+            kernel_size: int = 11,
+            sanm_shfit: int = 0,
+            selfattention_layer_type: str = "sanm",
+            chunk_size: Union[int, Sequence[int]] = (16,),
+            stride: Union[int, Sequence[int]] = (10,),
+            pad_left: Union[int, Sequence[int]] = (0,),
+            encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+            decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+            tf2torch_tensor_name_prefix_torch: str = "encoder",
+            tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+    ):
+        super().__init__()
+        self._output_size = output_size
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        elif input_layer == "pe_online":
+            self.embed = StreamSinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+        elif selfattention_layer_type == "sanm":
+            encoder_selfattn_layer = MultiHeadedAttentionSANM
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.encoders = repeat(
+            num_blocks - 1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+        shfit_fsmn = (kernel_size - 1) // 2
+        self.overlap_chunk_cls = overlap_chunk(
+            chunk_size=chunk_size,
+            stride=stride,
+            pad_left=pad_left,
+            shfit_fsmn=shfit_fsmn,
+            encoder_att_look_back_factor=encoder_att_look_back_factor,
+            decoder_att_look_back_factor=decoder_att_look_back_factor,
+        )
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+            prev_states: torch.Tensor = None,
+            ctc: CTC = None,
+            ind: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        xs_pad *= self.output_size() ** 0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (
+                isinstance(self.embed, Conv2dSubsampling)
+                or isinstance(self.embed, Conv2dSubsampling2)
+                or isinstance(self.embed, Conv2dSubsampling6)
+                or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
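+        # The overlap_chunk helper splits the utterance into fixed-size chunks and
+        # produces two masks: mask_shfit_chunk confines the FSMN memory to the
+        # current chunk, and mask_att_chunk_encoder restricts self-attention to the
+        # current chunk plus a bounded look-back, which is what makes the encoder
+        # usable for streaming.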
+        mask_shfit_chunk, mask_att_chunk_encoder = None, None
+        if self.overlap_chunk_cls is not None:
+            ilens = masks.squeeze(1).sum(1)
+            chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
+            xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
+            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
+                                                                           dtype=xs_pad.dtype)
+            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
+                                                                                       xs_pad.size(0),
+                                                                                       dtype=xs_pad.dtype)
+
+        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+                xs_pad, masks = encoder_outs[0], encoder_outs[1]
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None
+
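+    # Streaming helper (sketch of intent): prepend the cached tail of the previous
+    # chunk to the new features, and keep the last chunk_size[0] + chunk_size[2]
+    # frames as the cache for the next call, so each chunk sees its left context.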
+    def _add_overlap_chunk(self, feats: torch.Tensor, cache: dict = None):
+        if cache is None or len(cache) == 0:
+            return feats
+        cache["feats"] = to_device(cache["feats"], device=feats.device)
+        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+        return overlap_feats
+
+    def forward_chunk(self,
+                      xs_pad: torch.Tensor,
+                      ilens: torch.Tensor,
+                      cache: dict = None,
+                      ):
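+        """Encode one streaming chunk.
+
+        ``cache`` is assumed to carry the streaming state consumed below:
+        "feats" (overlap features from the previous chunk), "chunk_size",
+        "encoder_chunk_look_back", "tail_chunk" and "opt" (per-layer
+        attention caches).
+        """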
+        xs_pad *= self.output_size() ** 0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        else:
+            xs_pad = self.embed(xs_pad, cache)
+        if cache["tail_chunk"]:
+            xs_pad = to_device(cache["feats"], device=xs_pad.device)
+        else:
+            xs_pad = self._add_overlap_chunk(xs_pad, cache)
+        if cache["opt"] is None:
+            cache_layer_num = len(self.encoders0) + len(self.encoders)
+            new_cache = [None] * cache_layer_num
+        else:
+            new_cache = cache["opt"]
+
+        for layer_idx, encoder_layer in enumerate(self.encoders0):
+            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
+            xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
+
+        for layer_idx, encoder_layer in enumerate(self.encoders):
+            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
+            xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+        if cache["encoder_chunk_look_back"] > 0 or cache["encoder_chunk_look_back"] == -1:
+            cache["opt"] = new_cache
+
+        return xs_pad, ilens, None
+
+    def gen_tf2torch_map_dict(self):
+        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
+        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
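+        # Each entry maps a torch parameter name (with "layeridx" as a placeholder
+        # for the layer index) to the matching TF variable name, plus optional
+        # squeeze/transpose specs that reconcile the tensor layouts; the trailing
+        # comments give the shapes as (torch),(tf).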
+        map_dict_local = {
+            ## encoder
+            # cicd
+            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (768,256),(1,256,768)
+            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (768,),(768,)
+            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 2, 0),
+                 },  # (256,1,31),(1,31,256,1)
+            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,256),(1,256,256)
+            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            # ffn
+            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (1024,256),(1,256,1024)
+            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (1024,),(1024,)
+            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": 0,
+                 "transpose": (1, 0),
+                 },  # (256,1024),(1,1024,256)
+            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            # out norm
+            "{}.after_norm.weight".format(tensor_name_prefix_torch):
+                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+            "{}.after_norm.bias".format(tensor_name_prefix_torch):
+                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (256,),(256,)
+        
+        }
+    
+        return map_dict_local
+
+    def convert_tf2torch(self,
+                         var_dict_tf,
+                         var_dict_torch,
+                         ):
+    
+        map_dict = self.gen_tf2torch_map_dict()
+    
+        var_dict_torch_update = dict()
+        for name in sorted(var_dict_torch.keys(), reverse=False):
+            names = name.split('.')
+            if names[0] == self.tf2torch_tensor_name_prefix_torch:
+                if names[1] == "encoders0":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                
+                    name_q = name_q.replace("encoders0", "encoders")
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+                elif names[1] == "encoders":
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    layeridx_bias = 1
+                    layeridx += layeridx_bias
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+            
+                elif names[1] == "after_norm":
+                    name_tf = map_dict[name]["name"]
+                    data_tf = var_dict_tf[name_tf]
+                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                    var_dict_torch_update[name] = data_tf
+                    logging.info(
+                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+                                                                                      var_dict_tf[name_tf].shape))
+    
+        return var_dict_torch_update
+
diff --git a/funasr/models/sond/attention.py b/funasr/models/sond/attention.py
new file mode 100644
index 0000000..290ab03
--- /dev/null
+++ b/funasr/models/sond/attention.py
@@ -0,0 +1,328 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+from typing import Optional, Tuple
+
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+import funasr.models.lora.layers as lora
+
+class MultiHeadedAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
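+    Example (illustrative; shapes and values are assumptions, not taken from
+    any model configuration)::
+
+        mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
+        x = torch.randn(2, 10, 256)
+        out = mha(x, x, x, mask=None)   # -> (2, 10, 256)
+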
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadedAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, query, key, value):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        n_batch = query.size(0)
+        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q, k, v
+
+    def forward_attention(self, value, scores, mask):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, query, key, value, mask):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask)
+
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with relative position encoding (new implementation).
+
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+    Paper: https://arxiv.org/abs/1901.02860
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+        zero_triu (bool): Whether to zero the upper triangular part of the attention matrix.
+
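+    Example (illustrative; shapes are assumptions)::
+
+        att = RelPositionMultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
+        x = torch.randn(2, 10, 256)
+        pos_emb = torch.randn(1, 2 * 10 - 1, 256)  # 2*time1-1 relative positions
+        out = att(x, x, x, pos_emb, mask=None)     # -> (2, 10, 256)
+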
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
+        """Construct an RelPositionMultiHeadedAttention object."""
+        super().__init__(n_head, n_feat, dropout_rate)
+        self.zero_triu = zero_triu
+        # linear transformation for positional encoding
+        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        torch.nn.init.xavier_uniform_(self.pos_bias_u)
+        torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+    def rel_shift(self, x):
+        """Compute relative positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+            time1 means the length of query vector.
+
+        Returns:
+            torch.Tensor: Output tensor.
+
+        """
+        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+
+        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+        x = x_padded[:, :, 1:].view_as(x)[
+            :, :, :, : x.size(-1) // 2 + 1
+            ]  # only keep the positions from 0 to time2
+
+        if self.zero_triu:
+            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
+            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+        return x
+
+    def forward(self, query, key, value, pos_emb, mask):
+        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            pos_emb (torch.Tensor): Positional embedding tensor
+                (#batch, 2*time1-1, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
+
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        # compute matrix b and matrix d
+        # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        matrix_bd = self.rel_shift(matrix_bd)
+
+        scores = (matrix_ac + matrix_bd) / math.sqrt(
+            self.d_k
+        )  # (batch, head, time1, time2)
+
+        return self.forward_attention(v, scores, mask)
+
+
+
+
+
+
+class MultiHeadSelfAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
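+    Example (illustrative; shapes are assumptions)::
+
+        attn = MultiHeadSelfAttention(n_head=4, in_feat=256, n_feat=256, dropout_rate=0.0)
+        x = torch.randn(2, 10, 256)
+        out = attn(x, mask=None)   # -> (2, 10, 256)
+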
+    """
+
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadSelfAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, x):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        b, t, d = x.size()
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q_h, k_h, v_h, v
+
+    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+            mask_att_chunk_encoder (torch.Tensor): Optional mask restricting
+                attention across chunk boundaries.
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            if mask_att_chunk_encoder is not None:
+                mask = mask * mask_att_chunk_encoder
+
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, mask, mask_att_chunk_encoder=None):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
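+        # scale the queries rather than the scores; mathematically equivalent to
+        # dividing the attention scores by sqrt(d_k)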
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        return att_outs
+
+
diff --git a/funasr/models/sond/e2e_diar_sond.py b/funasr/models/sond/e2e_diar_sond.py
index ff70502..21d1d59 100644
--- a/funasr/models/sond/e2e_diar_sond.py
+++ b/funasr/models/sond/e2e_diar_sond.py
@@ -18,7 +18,7 @@
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.specaug.abs_profileaug import AbsProfileAug
 from funasr.layers.abs_normalize import AbsNormalize
diff --git a/funasr/models/sond/encoder/conv_encoder.py b/funasr/models/sond/encoder/conv_encoder.py
index 4c345cb..3933c01 100644
--- a/funasr/models/sond/encoder/conv_encoder.py
+++ b/funasr/models/sond/encoder/conv_encoder.py
@@ -12,7 +12,7 @@
 from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.encoder.abs_encoder import AbsEncoder
 import math
-from funasr.models.transformer.repeat import repeat
+from funasr.models.transformer.utils.repeat import repeat
 
 
 class EncoderLayer(nn.Module):
diff --git a/funasr/models/sond/encoder/fsmn_encoder.py b/funasr/models/sond/encoder/fsmn_encoder.py
index e23f3f1..129a748 100644
--- a/funasr/models/sond/encoder/fsmn_encoder.py
+++ b/funasr/models/sond/encoder/fsmn_encoder.py
@@ -12,7 +12,7 @@
 from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.encoder.abs_encoder import AbsEncoder
 import math
-from funasr.models.transformer.repeat import repeat
+from funasr.models.transformer.utils.repeat import repeat
 from funasr.models.transformer.utils.multi_layer_conv import FsmnFeedForward
 
 
diff --git a/funasr/models/sond/encoder/self_attention_encoder.py b/funasr/models/sond/encoder/self_attention_encoder.py
index c620295..ea974c6 100644
--- a/funasr/models/sond/encoder/self_attention_encoder.py
+++ b/funasr/models/sond/encoder/self_attention_encoder.py
@@ -9,7 +9,7 @@
 from funasr.models.scama.chunk_utilis import overlap_chunk
 import numpy as np
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.transformer.attention import MultiHeadSelfAttention, MultiHeadedAttentionSANM
+from funasr.models.sond.attention import MultiHeadSelfAttention
 from funasr.models.transformer.embedding import SinusoidalPositionEncoder
 from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
@@ -17,13 +17,13 @@
 from funasr.models.transformer.positionwise_feed_forward import (
     PositionwiseFeedForward,  # noqa: H301
 )
-from funasr.models.transformer.repeat import repeat
-from funasr.models.transformer.subsampling import Conv2dSubsampling
-from funasr.models.transformer.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.subsampling import TooShortUttError
-from funasr.models.transformer.subsampling import check_short_utt
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
 from funasr.models.ctc import CTC
 from funasr.models.encoder.abs_encoder import AbsEncoder
 
diff --git a/funasr/models/specaug/specaug.py b/funasr/models/specaug/specaug.py
index e5da8e2..17f2657 100644
--- a/funasr/models/specaug/specaug.py
+++ b/funasr/models/specaug/specaug.py
@@ -7,9 +7,11 @@
 from funasr.models.specaug.mask_along_axis import MaskAlongAxisVariableMaxWidth
 from funasr.models.specaug.mask_along_axis import MaskAlongAxisLFR
 from funasr.models.specaug.time_warp import TimeWarp
+from funasr.utils.register import register_class
 
 import torch.nn as nn
 
+@register_class("specaug_classes", "SpecAug")
 class SpecAug(nn.Module):
     """Implementation of SpecAug.
 
@@ -99,7 +101,8 @@
             x, x_lengths = self.time_mask(x, x_lengths)
         return x, x_lengths
 
-class SpecAugLFR(AbsSpecAug):
+@register_class("specaug_classes", "SpecAugLFR")
+class SpecAugLFR(nn.Module):
     """Implementation of SpecAug.
     lfr_rate: low frame rate
     """
diff --git a/funasr/models/tp_aligner/e2e_tp.py b/funasr/models/tp_aligner/e2e_tp.py
index e6c4028..c675b0e 100644
--- a/funasr/models/tp_aligner/e2e_tp.py
+++ b/funasr/models/tp_aligner/e2e_tp.py
@@ -11,13 +11,13 @@
 import numpy as np
 
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.predictor.cif import mae_loss
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.frontends.abs_frontend import AbsFrontend
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.models.base_model import FunASRModel
-from funasr.models.predictor.cif import CifPredictorV3
+from funasr.models.paraformer.cif_predictor import CifPredictorV3
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
diff --git a/funasr/models/transducer/model.py b/funasr/models/transducer/model.py
index ec340f6..21de4ba 100644
--- a/funasr/models/transducer/model.py
+++ b/funasr/models/transducer/model.py
@@ -16,7 +16,6 @@
 import random
 import numpy as np
 import time
-# from funasr.layers.abs_normalize import AbsNormalize
 from funasr.losses.label_smoothing_loss import (
 	LabelSmoothingLoss,  # noqa: H301
 )
@@ -24,17 +23,17 @@
 # from funasr.models.decoder.abs_decoder import AbsDecoder
 # from funasr.models.e2e_asr_common import ErrorCalculator
 # from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
+# from funasr.frontends.abs_frontend import AbsFrontend
 # from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
+from funasr.models.paraformer.cif_predictor import mae_loss
 # from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 # from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
 from funasr.train_utils.device_funcs import force_gatherable
 # from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
+# from funasr.models.paraformer.cif_predictor import CifPredictorV3
 from funasr.models.paraformer.search import Hypothesis
 
 from funasr.models.model_class_factory import *
@@ -46,7 +45,7 @@
 	@contextmanager
 	def autocast(enabled=True):
 		yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
 from funasr.utils import postprocess_utils
 from funasr.utils.datadir_writer import DatadirWriter
 from funasr.models.transformer.utils.nets_utils import get_transducer_task_io
@@ -98,19 +97,19 @@
 		super().__init__()
 
 		if frontend is not None:
-			frontend_class = frontend_choices.get_class(frontend)
+			frontend_class = frontend_classes.get_class(frontend)
 			frontend = frontend_class(**frontend_conf)
 		if specaug is not None:
-			specaug_class = specaug_choices.get_class(specaug)
+			specaug_class = specaug_classes.get_class(specaug)
 			specaug = specaug_class(**specaug_conf)
 		if normalize is not None:
-			normalize_class = normalize_choices.get_class(normalize)
+			normalize_class = normalize_classes.get_class(normalize)
 			normalize = normalize_class(**normalize_conf)
-		encoder_class = encoder_choices.get_class(encoder)
+		encoder_class = encoder_classes.get_class(encoder)
 		encoder = encoder_class(input_size=input_size, **encoder_conf)
 		encoder_output_size = encoder.output_size()
 
-		decoder_class = decoder_choices.get_class(decoder)
+		decoder_class = decoder_classes.get_class(decoder)
 		decoder = decoder_class(
 			vocab_size=vocab_size,
 			encoder_output_size=encoder_output_size,
@@ -118,7 +117,7 @@
 		)
 		decoder_output_size = decoder.output_size
 
-		joint_network_class = joint_network_choices.get_class(decoder)
+		joint_network_class = joint_network_classes.get_class(decoder)
 		joint_network = joint_network_class(
 			vocab_size,
 			encoder_output_size,
@@ -521,7 +520,7 @@
 		audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
 		time2 = time.perf_counter()
 		meta_data["load_data"] = f"{time2 - time1:0.3f}"
-		speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"), frontend=self.frontend)
+		speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
 		time3 = time.perf_counter()
 		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
 		meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
diff --git a/funasr/models/transducer/rnn_decoder.py b/funasr/models/transducer/rnn_decoder.py
index 1743f99..204f0b1 100644
--- a/funasr/models/transducer/rnn_decoder.py
+++ b/funasr/models/transducer/rnn_decoder.py
@@ -2,13 +2,12 @@
 
 import numpy as np
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.transformer.utils.nets_utils import to_device
 from funasr.models.language_model.rnn.attentions import initial_att
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.utils.get_default_kwargs import get_default_kwargs
 
 
 def build_attention_list(
@@ -80,7 +79,7 @@
     return att_list
 
 
-class RNNDecoder(AbsDecoder):
+class RNNDecoder(nn.Module):
     def __init__(
         self,
         vocab_size: int,
@@ -93,7 +92,7 @@
         context_residual: bool = False,
         replace_sos: bool = False,
         num_encs: int = 1,
-        att_conf: dict = get_default_kwargs(build_attention_list),
+        att_conf: dict = None,
     ):
         # FIXME(kamo): The parts of num_spk should be refactored more more more
         if rnn_type not in {"lstm", "gru"}:
diff --git a/funasr/models/transformer/attention.py b/funasr/models/transformer/attention.py
index 04607c6..32e1e47 100644
--- a/funasr/models/transformer/attention.py
+++ b/funasr/models/transformer/attention.py
@@ -15,7 +15,7 @@
 
 import torch.nn.functional as F
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
-import funasr.modules.lora.layers as lora
+import funasr.models.lora.layers as lora
 
 class MultiHeadedAttention(nn.Module):
     """Multi-Head Attention layer.
@@ -312,780 +312,8 @@
         return self.forward_attention(v, scores, mask)
 
 
-class MultiHeadedAttentionSANM(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionSANM, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        # self.linear_q = nn.Linear(n_feat, n_feat)
-        # self.linear_k = nn.Linear(n_feat, n_feat)
-        # self.linear_v = nn.Linear(n_feat, n_feat)
-        if lora_list is not None:
-            if "o" in lora_list:
-                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
-            else:
-                self.linear_out = nn.Linear(n_feat, n_feat)
-            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
-            if lora_qkv_list == [False, False, False]:
-                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
-            else:
-                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
-        else:
-            self.linear_out = nn.Linear(n_feat, n_feat)
-            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
-        # padding
-        left_padding = (kernel_size - 1) // 2
-        if sanm_shfit > 0:
-            left_padding = left_padding + sanm_shfit
-        right_padding = kernel_size - 1 - left_padding
-        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
-
-    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
-        b, t, d = inputs.size()
-        if mask is not None:
-            mask = torch.reshape(mask, (b, -1, 1))
-            if mask_shfit_chunk is not None:
-                mask = mask * mask_shfit_chunk
-            inputs = inputs * mask
-
-        x = inputs.transpose(1, 2)
-        x = self.pad_fn(x)
-        x = self.fsmn_block(x)
-        x = x.transpose(1, 2)
-        x += inputs
-        x = self.dropout(x)
-        if mask is not None:
-            x = x * mask
-        return x
-
-    def forward_qkv(self, x):
-        """Transform query, key and value.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
-        """
-        b, t, d = x.size()
-        q_k_v = self.linear_q_k_v(x)
-        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
-        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-
-        return q_h, k_h, v_h, v
-
-    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            if mask_att_chunk_encoder is not None:
-                mask = mask * mask_att_chunk_encoder
-
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x)  # (batch, time1, d_model)
-
-    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
-        """Compute scaled dot product attention.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
-        return att_outs + fsmn_memory
-
-    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
-        """Compute scaled dot product attention.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        if chunk_size is not None and look_back > 0 or look_back == -1:
-            if cache is not None:
-                k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
-                v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
-                k_h = torch.cat((cache["k"], k_h), dim=2)
-                v_h = torch.cat((cache["v"], v_h), dim=2)
-
-                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
-                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
-                if look_back != -1:
-                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
-                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
-            else:
-                cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
-                             "v": v_h[:, :, :-(chunk_size[2]), :]}
-                cache = cache_tmp
-        fsmn_memory = self.forward_fsmn(v, None)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, None)
-        return att_outs + fsmn_memory, cache
-
-
-class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
-        return att_outs + fsmn_memory
-
-class MultiHeadedAttentionSANMDecoder(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionSANMDecoder, self).__init__()
-
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-        self.fsmn_block = nn.Conv1d(n_feat, n_feat,
-                                    kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
-        # padding
-        # padding
-        left_padding = (kernel_size - 1) // 2
-        if sanm_shfit > 0:
-            left_padding = left_padding + sanm_shfit
-        right_padding = kernel_size - 1 - left_padding
-        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
-        self.kernel_size = kernel_size
-
-    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
-        '''
-        :param x: (#batch, time1, size).
-        :param mask: Mask tensor (#batch, 1, time)
-        :return:
-        '''
-        # print("in fsmn, inputs", inputs.size())
-        b, t, d = inputs.size()
-        # logging.info(
-        #     "mask: {}".format(mask.size()))
-        if mask is not None:
-            mask = torch.reshape(mask, (b, -1, 1))
-            # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
-            if mask_shfit_chunk is not None:
-                # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
-                mask = mask * mask_shfit_chunk
-            # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
-            # print("in fsmn, mask", mask.size())
-            # print("in fsmn, inputs", inputs.size())
-            inputs = inputs * mask
-
-        x = inputs.transpose(1, 2)
-        b, d, t = x.size()
-        if cache is None:
-            # print("in fsmn, cache is None, x", x.size())
-
-            x = self.pad_fn(x)
-            if not self.training:
-                cache = x
-        else:
-            # print("in fsmn, cache is not None, x", x.size())
-            # x = torch.cat((x, cache), dim=2)[:, :, :-1]
-            # if t < self.kernel_size:
-            #     x = self.pad_fn(x)
-            x = torch.cat((cache[:, :, 1:], x), dim=2)
-            x = x[:, :, -(self.kernel_size+t-1):]
-            # print("in fsmn, cache is not None, x_cat", x.size())
-            cache = x
-        x = self.fsmn_block(x)
-        x = x.transpose(1, 2)
-        # print("in fsmn, fsmn_out", x.size())
-        if x.size(1) != inputs.size(1):
-            inputs = inputs[:, -1, :]
-
-        x = x + inputs
-        x = self.dropout(x)
-        if mask is not None:
-            x = x * mask
-        return x, cache
-
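-# A minimal streaming sketch for the FSMN memory block above (illustrative only;
-# sizes are not from any FunASR recipe). In eval mode the returned cache holds
-# the padded frames the depthwise Conv1d needs, so chunk-by-chunk decoding just
-# threads the cache back in:
-#
-#     fsmn = MultiHeadedAttentionSANMDecoder(n_feat=256, dropout_rate=0.1,
-#                                            kernel_size=11)
-#     fsmn.eval()
-#     cache = None
-#     for chunk in chunks:                        # each chunk: (B, t, 256)
-#         out, cache = fsmn(chunk, None, cache=cache)
-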
-class MultiHeadedAttentionCrossAtt(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionCrossAtt, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        if lora_list is not None:
-            if "q" in lora_list:
-                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
-            else:
-                self.linear_q = nn.Linear(n_feat, n_feat)
-            lora_kv_list = ["k" in lora_list, "v" in lora_list]
-            if lora_kv_list == [False, False]:
-                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
-            else:
-                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, 
-                                      r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
-            if "o" in lora_list:
-                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
-            else:
-                self.linear_out = nn.Linear(n_feat, n_feat)
-        else:
-            self.linear_q = nn.Linear(n_feat, n_feat)
-            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
-            self.linear_out = nn.Linear(n_feat, n_feat)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-    def forward_qkv(self, x, memory):
-        """Transform query, key and value.
-
-        Args:
-            x (torch.Tensor): Decoder input tensor (#batch, time1, size).
-            memory (torch.Tensor): Encoder memory tensor (#batch, time2, size).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
-        """
-
-        # print("in forward_qkv, x", x.size())
-        b = x.size(0)
-        q = self.linear_q(x)
-        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time1, d_k)
-
-        k_v = self.linear_k_v(memory)
-        k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
-        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
-
-
-        return q_h, k_h, v_h
-
-    def forward_attention(self, value, scores, mask):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
-            # logging.info(
-            #     "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x)  # (batch, time1, d_model)
-
-    def forward(self, x, memory, memory_mask):
-        """Compute scaled dot product attention.
-
-        Args:
-            x (torch.Tensor): Query-side input tensor (#batch, time1, size).
-            memory (torch.Tensor): Encoder memory tensor (#batch, time2, size).
-            memory_mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q_h, k_h, v_h = self.forward_qkv(x, memory)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        return self.forward_attention(v_h, scores, memory_mask)
-
-    def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
-        """Compute scaled dot product attention.
-
-        Args:
-            x (torch.Tensor): Query-side input tensor (#batch, time1, size).
-            memory (torch.Tensor): Encoder memory tensor (#batch, time2, size).
-            cache (dict): Cached key/value tensors from previous chunks, or None.
-            chunk_size (list): Chunk-size configuration used for the cache length.
-            look_back (int): Number of previous chunks kept in the cache.
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-            dict: Updated cache for the next chunk.
-
-        """
-        q_h, k_h, v_h = self.forward_qkv(x, memory)
-        if chunk_size is not None and look_back > 0:
-            if cache is not None:
-                k_h = torch.cat((cache["k"], k_h), dim=2)
-                v_h = torch.cat((cache["v"], v_h), dim=2)
-                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
-                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
-            else:
-                cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
-                             "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
-                cache = cache_tmp
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        return self.forward_attention(v_h, scores, None), cache
-
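-# A minimal configuration sketch for the LoRA options above (illustrative only;
-# values are not taken from any FunASR recipe). Listing "q" and "o" in lora_list
-# wraps only the query and output projections with LoRA adapters, while the fused
-# key/value projection stays a plain nn.Linear:
-#
-#     cross_att = MultiHeadedAttentionCrossAtt(
-#         n_head=4, n_feat=256, dropout_rate=0.1,
-#         lora_list=["q", "o"], lora_rank=8, lora_alpha=16, lora_dropout=0.1,
-#     )
-#     att_out = cross_att(x, memory, memory_mask)   # (#batch, time1, n_feat)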
-
-class MultiHeadSelfAttention(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        in_feat (int): The number of input features.
-        n_feat (int): The number of attention features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadSelfAttention, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-    def forward_qkv(self, x):
-        """Transform query, key and value.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, in_feat).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
-        """
-        b, t, d = x.size()
-        q_k_v = self.linear_q_k_v(x)
-        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
-        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-
-        return q_h, k_h, v_h, v
-
-    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            if mask_att_chunk_encoder is not None:
-                mask = mask * mask_att_chunk_encoder
-
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x)  # (batch, time1, d_model)
-
-    def forward(self, x, mask, mask_att_chunk_encoder=None):
-        """Compute scaled dot product attention.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, in_feat).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
-                (#batch, time, time).
-            mask_att_chunk_encoder (torch.Tensor): Optional chunk mask applied
-                to the attention scores.
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
-        return att_outs
-
-class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
-    """RelPositionMultiHeadedAttention definition.
-    Args:
-        num_heads: Number of attention heads.
-        embed_size: Embedding size.
-        dropout_rate: Dropout rate.
-    """
-
-    def __init__(
-        self,
-        num_heads: int,
-        embed_size: int,
-        dropout_rate: float = 0.0,
-        simplified_attention_score: bool = False,
-    ) -> None:
-        """Construct an MultiHeadedAttention object."""
-        super().__init__()
-
-        self.d_k = embed_size // num_heads
-        self.num_heads = num_heads
-
-        assert self.d_k * num_heads == embed_size, (
-            "embed_size (%d) must be divisible by num_heads (%d)",
-            (embed_size, num_heads),
-        )
-
-        self.linear_q = torch.nn.Linear(embed_size, embed_size)
-        self.linear_k = torch.nn.Linear(embed_size, embed_size)
-        self.linear_v = torch.nn.Linear(embed_size, embed_size)
-
-        self.linear_out = torch.nn.Linear(embed_size, embed_size)
-
-        if simplified_attention_score:
-            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
-
-            self.compute_att_score = self.compute_simplified_attention_score
-        else:
-            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
-
-            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
-            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
-            torch.nn.init.xavier_uniform_(self.pos_bias_u)
-            torch.nn.init.xavier_uniform_(self.pos_bias_v)
-
-            self.compute_att_score = self.compute_attention_score
-
-        self.dropout = torch.nn.Dropout(p=dropout_rate)
-        self.attn = None
-
-    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
-        """Compute relative positional encoding.
-        Args:
-            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
-            left_context: Number of frames in left context.
-        Returns:
-            x: Output sequence. (B, H, T_1, T_2)
-        """
-        batch_size, n_heads, time1, n = x.shape
-        time2 = time1 + left_context
-
-        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
-
-        return x.as_strided(
-            (batch_size, n_heads, time1, time2),
-            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
-            storage_offset=(n_stride * (time1 - 1)),
-        )
-
-    def compute_simplified_attention_score(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        pos_enc: torch.Tensor,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Simplified attention score computation.
-        Reference: https://github.com/k2-fsa/icefall/pull/458
-        Args:
-            query: Transformed query tensor. (B, H, T_1, d_k)
-            key: Transformed key tensor. (B, H, T_2, d_k)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            left_context: Number of frames in left context.
-        Returns:
-            : Attention score. (B, H, T_1, T_2)
-        """
-        pos_enc = self.linear_pos(pos_enc)
-
-        matrix_ac = torch.matmul(query, key.transpose(2, 3))
-
-        matrix_bd = self.rel_shift(
-            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
-            left_context=left_context,
-        )
-
-        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
-    def compute_attention_score(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        pos_enc: torch.Tensor,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Attention score computation.
-        Args:
-            query: Transformed query tensor. (B, H, T_1, d_k)
-            key: Transformed key tensor. (B, H, T_2, d_k)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            left_context: Number of frames in left context.
-        Returns:
-            : Attention score. (B, H, T_1, T_2)
-        """
-        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
-
-        query = query.transpose(1, 2)
-        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
-        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
-
-        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
-
-        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
-        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
-
-        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
-    def forward_qkv(
-        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Transform query, key and value.
-        Args:
-            query: Query tensor. (B, T_1, size)
-            key: Key tensor. (B, T_2, size)
-            value: Value tensor. (B, T_2, size)
-        Returns:
-            q: Transformed query tensor. (B, H, T_1, d_k)
-            k: Transformed key tensor. (B, H, T_2, d_k)
-            v: Transformed value tensor. (B, H, T_2, d_k)
-        """
-        n_batch = query.size(0)
-
-        q = (
-            self.linear_q(query)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-        k = (
-            self.linear_k(key)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-        v = (
-            self.linear_v(value)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-
-        return q, k, v
-
-    def forward_attention(
-        self,
-        value: torch.Tensor,
-        scores: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Compute attention context vector.
-        Args:
-            value: Transformed value. (B, H, T_2, d_k)
-            scores: Attention score. (B, H, T_1, T_2)
-            mask: Source mask. (B, T_2)
-            chunk_mask: Chunk mask. (T_1, T_1)
-        Returns:
-           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
-        """
-        batch_size = scores.size(0)
-        mask = mask.unsqueeze(1).unsqueeze(2)
-        if chunk_mask is not None:
-            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
-        scores = scores.masked_fill(mask, float("-inf"))
-        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
-
-        attn_output = self.dropout(self.attn)
-        attn_output = torch.matmul(attn_output, value)
-
-        attn_output = self.linear_out(
-            attn_output.transpose(1, 2)
-            .contiguous()
-            .view(batch_size, -1, self.num_heads * self.d_k)
-        )
-
-        return attn_output
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Compute scaled dot product attention with rel. positional encoding.
-        Args:
-            query: Query tensor. (B, T_1, size)
-            key: Key tensor. (B, T_2, size)
-            value: Value tensor. (B, T_2, size)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            mask: Source mask. (B, T_2)
-            chunk_mask: Chunk mask. (T_1, T_1)
-            left_context: Number of frames in left context.
-        Returns:
-            : Output tensor. (B, T_1, H * d_k)
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
-        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
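-
-# How rel_shift above reads the (B, H, T_1, 2 * T_1 - 1) score matrix
-# (illustrative note): for query position i and key position j it selects column
-# (T_1 - 1) + (j - i) of the relative-position axis, so the strided view is
-# equivalent to
-#
-#     out[b, h, i, j] == x[b, h, i, (T_1 - 1) - i + j]
-#
-# with T_2 = T_1 + left_context, but without materialising a shifted copy.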
 
 
-class CosineDistanceAttention(nn.Module):
-    """ Compute Cosine Distance between spk decoder output and speaker profile 
-    Args:
-        profile_path: speaker profile file path (.npy file)
-    """
 
-    def __init__(self):
-        super().__init__()
-        self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, spk_decoder_out, profile, profile_lens=None):
-        """
-        Args:
-            spk_decoder_out (torch.Tensor): Speaker decoder output (B, L, D).
-            profile (torch.Tensor): Speaker profiles (B, N, D).
-            profile_lens (torch.Tensor): Number of valid profiles per utterance (B,).
-        """
-        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
-        if profile_lens is not None:
-            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
-            )
-            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
-            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
-        else:
-            x = x[:, -1:, :, :]
-            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
-            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
-        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
 
-        return spk_embedding, weights
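-
-# A minimal usage sketch (illustrative shapes only): every decoder step is scored
-# against each speaker profile by cosine similarity, and the softmax over the N
-# profiles weights the profiles into a speaker embedding.
-#
-#     att = CosineDistanceAttention()
-#     spk_decoder_out = torch.randn(2, 10, 192)    # (B, L, D)
-#     profile = torch.randn(2, 4, 192)             # (B, N, D)
-#     profile_lens = torch.tensor([4, 3])
-#     spk_embedding, weights = att(spk_decoder_out, profile, profile_lens)
-#     # spk_embedding: (B, L, D), weights: (B, L, N)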
diff --git a/funasr/models/transformer/decoder.py b/funasr/models/transformer/decoder.py
new file mode 100644
index 0000000..3e8d224
--- /dev/null
+++ b/funasr/models/transformer/decoder.py
@@ -0,0 +1,647 @@
+# Copyright 2019 Shigeki Karita
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Decoder definition."""
+from typing import Any
+from typing import List
+from typing import Sequence
+from typing import Tuple
+
+import torch
+from torch import nn
+
+
+from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.transformer.utils.dynamic_conv import DynamicConvolution
+from funasr.models.transformer.utils.dynamic_conv2d import DynamicConvolution2D
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.lightconv import LightweightConvolution
+from funasr.models.transformer.utils.lightconv2d import LightweightConvolution2D
+from funasr.models.transformer.utils.mask import subsequent_mask
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.positionwise_feed_forward import (
+    PositionwiseFeedForward,  # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
+
+from funasr.utils.register import register_class, registry_tables
+
+class DecoderLayer(nn.Module):
+    """Single decoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Source-attention module instance
+            (attention over the encoder memory).
+            `MultiHeadedAttention` instance can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+    """
+
+    def __init__(
+            self,
+            size,
+            self_attn,
+            src_attn,
+            feed_forward,
+            dropout_rate,
+            normalize_before=True,
+            concat_after=False,
+    ):
+        """Construct an DecoderLayer object."""
+        super(DecoderLayer, self).__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.norm3 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        if cache is None:
+            tgt_q = tgt
+            tgt_q_mask = tgt_mask
+        else:
+            # compute only the last frame query keeping dim: max_time_out -> 1
+            assert cache.shape == (
+                tgt.shape[0],
+                tgt.shape[1] - 1,
+                self.size,
+            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+            tgt_q = tgt[:, -1:, :]
+            residual = residual[:, -1:, :]
+            tgt_q_mask = None
+            if tgt_mask is not None:
+                tgt_q_mask = tgt_mask[:, -1:, :]
+
+        if self.concat_after:
+            tgt_concat = torch.cat(
+                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
+            )
+            x = residual + self.concat_linear1(tgt_concat)
+        else:
+            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        if self.concat_after:
+            x_concat = torch.cat(
+                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
+            )
+            x = residual + self.concat_linear2(x_concat)
+        else:
+            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm3(x)
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm3(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        return x, tgt_mask, memory, memory_mask
+
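+# A minimal wiring sketch for a single DecoderLayer (module instances and shapes
+# below are illustrative, not taken from any FunASR config): masked self-attention
+# over the targets, source attention over the encoder memory, then a feed-forward
+# block, each wrapped in a pre-norm residual branch by default.
+def _decoder_layer_example():
+    size, heads, units = 256, 4, 2048
+    layer = DecoderLayer(
+        size,
+        MultiHeadedAttention(heads, size, 0.1),   # masked self-attention
+        MultiHeadedAttention(heads, size, 0.1),   # source attention over memory
+        PositionwiseFeedForward(size, units, 0.1),
+        dropout_rate=0.1,
+        normalize_before=True,
+        concat_after=False,
+    )
+    tgt = torch.randn(2, 10, size)                # (#batch, maxlen_out, size)
+    memory = torch.randn(2, 50, size)             # (#batch, maxlen_in, size)
+    tgt_mask = subsequent_mask(10).unsqueeze(0)   # (1, maxlen_out, maxlen_out)
+    x, tgt_mask, memory, memory_mask = layer(tgt, tgt_mask, memory, None)
+    return x                                      # (#batch, maxlen_out, size)
+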
+
+class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
+    """Base class of Transfomer decoder module.
+
+    Args:
+        vocab_size: output dim
+        encoder_output_size: dimension of attention
+        attention_heads: the number of heads of multi head attention
+        linear_units: the number of units of position-wise feed forward
+        num_blocks: the number of decoder blocks
+        dropout_rate: dropout rate
+        self_attention_dropout_rate: dropout rate for attention
+        input_layer: input layer type
+        use_output_layer: whether to use output layer
+        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+        normalize_before: whether to use layer_norm before the first block
+        concat_after: whether to concat attention layer's input and output
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied.
+            i.e. x -> x + att(x)
+    """
+
+    def __init__(
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            input_layer: str = "embed",
+            use_output_layer: bool = True,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+    ):
+        super().__init__()
+        attention_dim = encoder_output_size
+
+        if input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(vocab_size, attention_dim),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        elif input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(vocab_size, attention_dim),
+                torch.nn.LayerNorm(attention_dim),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        else:
+            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+        if use_output_layer:
+            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.output_layer = None
+
+        # Must set by the inheritance
+        self.decoders = None
+
+    def forward(
+            self,
+            hs_pad: torch.Tensor,
+            hlens: torch.Tensor,
+            ys_in_pad: torch.Tensor,
+            ys_in_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out) if input_layer == "embed";
+                otherwise an input tensor (batch, maxlen_out, #mels)
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token scores before softmax (batch, maxlen_out, token),
+                computed only if use_output_layer is True
+            olens: output lengths (batch,)
+        """
+        tgt = ys_in_pad
+        # tgt_mask: (B, 1, L)
+        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+        # m: (1, L, L)
+        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+        # tgt_mask: (B, L, L)
+        tgt_mask = tgt_mask & m
+
+        memory = hs_pad
+        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
+            memory.device
+        )
+        # Padding for Longformer
+        if memory_mask.shape[-1] != memory.shape[1]:
+            padlen = memory.shape[1] - memory_mask.shape[-1]
+            memory_mask = torch.nn.functional.pad(
+                memory_mask, (0, padlen), "constant", False
+            )
+
+        x = self.embed(tgt)
+        x, tgt_mask, memory, memory_mask = self.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        return x, olens
+
+    def forward_one_step(
+            self,
+            tgt: torch.Tensor,
+            tgt_mask: torch.Tensor,
+            memory: torch.Tensor,
+            cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward one step.
+
+        Args:
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask, (batch, maxlen_out)
+                      dtype=torch.uint8 before PyTorch 1.2,
+                      dtype=torch.bool in PyTorch 1.2 and later
+            memory: encoded memory, float32  (batch, maxlen_in, feat)
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+            `y.shape` is (batch, token)
+        """
+        x = self.embed(tgt)
+        if cache is None:
+            cache = [None] * len(self.decoders)
+        new_cache = []
+        for c, decoder in zip(cache, self.decoders):
+            x, tgt_mask, memory, memory_mask = decoder(
+                x, tgt_mask, memory, None, cache=c
+            )
+            new_cache.append(x)
+
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.output_layer is not None:
+            y = torch.log_softmax(self.output_layer(y), dim=-1)
+
+        return y, new_cache
+
+    def score(self, ys, state, x):
+        """Score."""
+        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
+        logp, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
+        )
+        return logp.squeeze(0), state
+
+    def batch_score(
+            self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+    ) -> Tuple[torch.Tensor, List[Any]]:
+        """Score new token batch.
+
+        Args:
+            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (torch.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[torch.Tensor, List[Any]]: Tuple of
+                batchfied scores for next token with shape of `(n_batch, n_vocab)`
+                and next state list for ys.
+
+        """
+        # merge states
+        n_batch = len(ys)
+        n_layers = len(self.decoders)
+        if states[0] is None:
+            batch_state = None
+        else:
+            # transpose state of [batch, layer] into [layer, batch]
+            batch_state = [
+                torch.stack([states[b][i] for b in range(n_batch)])
+                for i in range(n_layers)
+            ]
+
+        # batch decoding
+        ys_mask = subsequent_mask(ys.size(-1), device=xs.device).unsqueeze(0)
+        logp, states = self.forward_one_step(ys, ys_mask, xs, cache=batch_state)
+
+        # transpose state of [layer, batch] into [batch, layer]
+        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+        return logp, state_list
+
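+# A minimal greedy-decoding sketch built on forward_one_step (illustrative only;
+# it assumes `decoder` is any BaseTransformerDecoder subclass, `memory` is the
+# encoder output for a single utterance (1, T, D), and `sos`/`eos` are placeholder
+# token ids). The per-layer cache returned at each step is fed back in so only
+# the newest token position is recomputed.
+def _greedy_decode_example(decoder, memory, sos=1, eos=2, max_len=50):
+    ys = torch.tensor([[sos]], dtype=torch.long, device=memory.device)  # (1, 1)
+    cache = None
+    for _ in range(max_len):
+        ys_mask = subsequent_mask(ys.size(1), device=ys.device).unsqueeze(0)
+        logp, cache = decoder.forward_one_step(ys, ys_mask, memory, cache=cache)
+        next_token = logp.argmax(dim=-1, keepdim=True)                  # (1, 1)
+        ys = torch.cat([ys, next_token], dim=1)
+        if next_token.item() == eos:
+            break
+    return ys
+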
+@register_class("decoder_classes", "TransformerDecoder")
+class TransformerDecoder(BaseTransformerDecoder):
+    def __init__(
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            self_attention_dropout_rate: float = 0.0,
+            src_attention_dropout_rate: float = 0.0,
+            input_layer: str = "embed",
+            use_output_layer: bool = True,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+    ):
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+        self.decoders = repeat(
+            num_blocks,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
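+# A minimal construction sketch for the registered TransformerDecoder above
+# (sizes are illustrative; vocab_size and encoder_output_size would normally
+# come from the model config and tokenizer).
+def _transformer_decoder_example():
+    decoder = TransformerDecoder(
+        vocab_size=5000,
+        encoder_output_size=256,
+        attention_heads=4,
+        linear_units=2048,
+        num_blocks=6,
+        dropout_rate=0.1,
+    )
+    hs_pad = torch.randn(2, 50, 256)             # encoder output (batch, maxlen_in, feat)
+    hlens = torch.tensor([50, 42])
+    ys_in_pad = torch.randint(0, 5000, (2, 12))  # <sos>-prefixed target ids
+    ys_in_lens = torch.tensor([12, 9])
+    scores, olens = decoder(hs_pad, hlens, ys_in_pad, ys_in_lens)
+    return scores                                # (batch, maxlen_out, vocab_size)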
+
+@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
+class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
+    def __init__(
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            self_attention_dropout_rate: float = 0.0,
+            src_attention_dropout_rate: float = 0.0,
+            input_layer: str = "embed",
+            use_output_layer: bool = True,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            conv_wshare: int = 4,
+            conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+            conv_usebias: bool = False,
+    ):
+        if len(conv_kernel_length) != num_blocks:
+            raise ValueError(
+                "conv_kernel_length must have equal number of values to num_blocks: "
+                f"{len(conv_kernel_length)} != {num_blocks}"
+            )
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+        self.decoders = repeat(
+            num_blocks,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                LightweightConvolution(
+                    wshare=conv_wshare,
+                    n_feat=attention_dim,
+                    dropout_rate=self_attention_dropout_rate,
+                    kernel_size=conv_kernel_length[lnum],
+                    use_kernel_mask=True,
+                    use_bias=conv_usebias,
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
+class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
+    def __init__(
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            self_attention_dropout_rate: float = 0.0,
+            src_attention_dropout_rate: float = 0.0,
+            input_layer: str = "embed",
+            use_output_layer: bool = True,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            conv_wshare: int = 4,
+            conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+            conv_usebias: bool = False,
+    ):
+        if len(conv_kernel_length) != num_blocks:
+            raise ValueError(
+                "conv_kernel_length must have equal number of values to num_blocks: "
+                f"{len(conv_kernel_length)} != {num_blocks}"
+            )
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+        self.decoders = repeat(
+            num_blocks,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                LightweightConvolution2D(
+                    wshare=conv_wshare,
+                    n_feat=attention_dim,
+                    dropout_rate=self_attention_dropout_rate,
+                    kernel_size=conv_kernel_length[lnum],
+                    use_kernel_mask=True,
+                    use_bias=conv_usebias,
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+
+@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
+class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
+    def __init__(
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            self_attention_dropout_rate: float = 0.0,
+            src_attention_dropout_rate: float = 0.0,
+            input_layer: str = "embed",
+            use_output_layer: bool = True,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            conv_wshare: int = 4,
+            conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+            conv_usebias: bool = False,
+    ):
+        if len(conv_kernel_length) != num_blocks:
+            raise ValueError(
+                "conv_kernel_length must have equal number of values to num_blocks: "
+                f"{len(conv_kernel_length)} != {num_blocks}"
+            )
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+        attention_dim = encoder_output_size
+
+        self.decoders = repeat(
+            num_blocks,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                DynamicConvolution(
+                    wshare=conv_wshare,
+                    n_feat=attention_dim,
+                    dropout_rate=self_attention_dropout_rate,
+                    kernel_size=conv_kernel_length[lnum],
+                    use_kernel_mask=True,
+                    use_bias=conv_usebias,
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
+class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
+    def __init__(
+            self,
+            vocab_size: int,
+            encoder_output_size: int,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            self_attention_dropout_rate: float = 0.0,
+            src_attention_dropout_rate: float = 0.0,
+            input_layer: str = "embed",
+            use_output_layer: bool = True,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            conv_wshare: int = 4,
+            conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+            conv_usebias: bool = False,
+    ):
+        if len(conv_kernel_length) != num_blocks:
+            raise ValueError(
+                "conv_kernel_length must have equal number of values to num_blocks: "
+                f"{len(conv_kernel_length)} != {num_blocks}"
+            )
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+        attention_dim = encoder_output_size
+
+        self.decoders = repeat(
+            num_blocks,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                DynamicConvolution2D(
+                    wshare=conv_wshare,
+                    n_feat=attention_dim,
+                    dropout_rate=self_attention_dropout_rate,
+                    kernel_size=conv_kernel_length[lnum],
+                    use_kernel_mask=True,
+                    use_bias=conv_usebias,
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
diff --git a/funasr/models/transformer/encoder.py b/funasr/models/transformer/encoder.py
new file mode 100644
index 0000000..a3d5249
--- /dev/null
+++ b/funasr/models/transformer/encoder.py
@@ -0,0 +1,332 @@
+# Copyright 2019 Shigeki Karita
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Transformer encoder definition."""
+
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import torch
+from torch import nn
+import logging
+
+from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.ctc.ctc import CTC
+
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+
+from funasr.utils.register import register_class
+
+class EncoderLayer(nn.Module):
+    """Encoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+            can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+        stochastic_depth_rate (float): Probability to skip this layer.
+            During training, the layer may skip residual computation and return input
+            as-is with given probability.
+    """
+
+    def __init__(
+            self,
+            size,
+            self_attn,
+            feed_forward,
+            dropout_rate,
+            normalize_before=True,
+            concat_after=False,
+            stochastic_depth_rate=0.0,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayer, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+        self.stochastic_depth_rate = stochastic_depth_rate
+
+    def forward(self, x, mask, cache=None):
+        """Compute encoded features.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+        skip_layer = False
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        stoch_layer_coeff = 1.0
+        if self.training and self.stochastic_depth_rate > 0:
+            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+        if skip_layer:
+            if cache is not None:
+                x = torch.cat([cache, x], dim=1)
+            return x, mask
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if cache is None:
+            x_q = x
+        else:
+            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+            x_q = x[:, -1:, :]
+            residual = residual[:, -1:, :]
+            mask = None if mask is None else mask[:, -1:, :]
+
+        if self.concat_after:
+            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
+            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+        else:
+            x = residual + stoch_layer_coeff * self.dropout(
+                self.self_attn(x_q, x, x, mask)
+            )
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        return x, mask
+
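+# A minimal sketch of the stochastic-depth behaviour documented above
+# (illustrative sizes only): with stochastic_depth_rate = p the whole layer is
+# skipped with probability p during training, and when it is not skipped its
+# residual branch is rescaled by 1 / (1 - p); at eval time it always runs.
+def _encoder_layer_stochastic_depth_example():
+    size, heads, units, p = 256, 4, 2048, 0.1
+    layer = EncoderLayer(
+        size,
+        MultiHeadedAttention(heads, size, 0.1),
+        PositionwiseFeedForward(size, units, 0.1),
+        dropout_rate=0.1,
+        stochastic_depth_rate=p,
+    )
+    x = torch.randn(2, 30, size)   # (#batch, time, size)
+    layer.train()                  # layer may be skipped with probability p
+    x_train, _ = layer(x, None)
+    layer.eval()                   # never skipped at inference time
+    x_eval, _ = layer(x, None)
+    return x_train, x_eval
+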
+@register_class("encoder_classes", "TransformerEncoder")
+class TransformerEncoder(nn.Module):
+    """Transformer encoder module.
+
+    Args:
+        input_size: input dim
+        output_size: dimension of attention
+        attention_heads: the number of heads of multi head attention
+        linear_units: the number of units of position-wise feed forward
+        num_blocks: the number of encoder blocks
+        dropout_rate: dropout rate
+        attention_dropout_rate: dropout rate in attention
+        positional_dropout_rate: dropout rate after adding positional encoding
+        input_layer: input layer type
+        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+        normalize_before: whether to use layer_norm before the first block
+        concat_after: whether to concat attention layer's input and output
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied.
+            i.e. x -> x + att(x)
+        positionwise_layer_type: linear or conv1d
+        positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
+        padding_idx: padding_idx for input_layer=embed
+    """
+
+    def __init__(
+            self,
+            input_size: int,
+            output_size: int = 256,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            attention_dropout_rate: float = 0.0,
+            input_layer: Optional[str] = "conv2d",
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            positionwise_layer_type: str = "linear",
+            positionwise_conv_kernel_size: int = 1,
+            padding_idx: int = -1,
+            interctc_layer_idx: List[int] = [],
+            interctc_use_conditioning: bool = False,
+    ):
+        super().__init__()
+        self._output_size = output_size
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear, conv1d, or conv1d-linear.")
+        self.encoders = repeat(
+            num_blocks,
+            lambda lnum: EncoderLayer(
+                output_size,
+                MultiHeadedAttention(
+                    attention_heads, output_size, attention_dropout_rate
+                ),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+            prev_states: torch.Tensor = None,
+            ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+
+        if self.embed is None:
+            pass  # no input projection; use the features as-is
+        elif (
+                isinstance(self.embed, Conv2dSubsampling)
+                or isinstance(self.embed, Conv2dSubsampling2)
+                or isinstance(self.embed, Conv2dSubsampling6)
+                or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            xs_pad, masks = self.encoders(xs_pad, masks)
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                xs_pad, masks = encoder_layer(xs_pad, masks)
+
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None
+
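For orientation, a minimal usage sketch of the encoder added above (not part of the patch); the feature dimension, lengths, and hyper-parameters are illustrative, and the import path assumes the class lives in funasr/models/transformer/encoder.py:

import torch
from funasr.models.transformer.encoder import TransformerEncoder

# 80-dim fbank features; the conv2d front-end subsamples the time axis by roughly 4x
encoder = TransformerEncoder(input_size=80, output_size=256, num_blocks=6, input_layer="conv2d")
feats = torch.randn(2, 200, 80)           # (B, L, D)
feat_lens = torch.tensor([200, 160])      # valid frames per utterance
enc_out, enc_lens, _ = encoder(feats, feat_lens)
print(enc_out.shape, enc_lens)            # roughly (2, 49, 256) and the subsampled lengths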
diff --git a/funasr/models/transformer/model.py b/funasr/models/transformer/model.py
index 107d62f..e4eae10 100644
--- a/funasr/models/transformer/model.py
+++ b/funasr/models/transformer/model.py
@@ -1,56 +1,23 @@
 import logging
-from contextlib import contextmanager
-from distutils.version import LooseVersion
-from typing import Dict
-from typing import List
-from typing import Optional
-from typing import Tuple
-from typing import Union
-import tempfile
-import codecs
-import requests
-import re
-import copy
+from typing import Union, Dict, List, Tuple, Optional
+
+import time
 import torch
 import torch.nn as nn
-import random
-import numpy as np
-import time
-# from funasr.layers.abs_normalize import AbsNormalize
-from funasr.losses.label_smoothing_loss import (
-	LabelSmoothingLoss,  # noqa: H301
-)
-# from funasr.models.ctc import CTC
-# from funasr.models.decoder.abs_decoder import AbsDecoder
+from torch.cuda.amp import autocast
+
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
 # from funasr.models.e2e_asr_common import ErrorCalculator
-# from funasr.models.encoder.abs_encoder import AbsEncoder
-# from funasr.models.frontend.abs_frontend import AbsFrontend
-# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.predictor.cif import mae_loss
-# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-# from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
 from funasr.train_utils.device_funcs import force_gatherable
-# from funasr.models.base_model import FunASRModel
-# from funasr.models.predictor.cif import CifPredictorV3
-from funasr.models.paraformer.search import Hypothesis
-
-from funasr.models.model_class_factory import *
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-	from torch.cuda.amp import autocast
-else:
-	# Nothing to do if torch<1.6.0
-	@contextmanager
-	def autocast(enabled=True):
-		yield
-from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio, extract_fbank
 from funasr.utils import postprocess_utils
 from funasr.utils.datadir_writer import DatadirWriter
+from funasr.utils.register import register_class, registry_tables
 
-
+@register_class("model_classes", "Transformer")
 class Transformer(nn.Module):
 	"""CTC-attention hybrid Encoder-Decoder model"""
 
@@ -93,19 +60,19 @@
 		super().__init__()
 
 		if frontend is not None:
-			frontend_class = frontend_choices.get_class(frontend)
+			frontend_class = registry_tables.frontend_classes.get_class(frontend.lower())
 			frontend = frontend_class(**frontend_conf)
 		if specaug is not None:
-			specaug_class = specaug_choices.get_class(specaug)
+			specaug_class = registry_tables.specaug_classes.get_class(specaug.lower())
 			specaug = specaug_class(**specaug_conf)
 		if normalize is not None:
-			normalize_class = normalize_choices.get_class(normalize)
+			normalize_class = registry_tables.normalize_classes.get_class(normalize.lower())
 			normalize = normalize_class(**normalize_conf)
-		encoder_class = encoder_choices.get_class(encoder)
+		encoder_class = registry_tables.encoder_classes.get_class(encoder.lower())
 		encoder = encoder_class(input_size=input_size, **encoder_conf)
 		encoder_output_size = encoder.output_size()
 		if decoder is not None:
-			decoder_class = decoder_choices.get_class(decoder)
+			decoder_class = registry_tables.decoder_classes.get_class(decoder.lower())
 			decoder = decoder_class(
 				vocab_size=vocab_size,
 				encoder_output_size=encoder_output_size,
@@ -239,7 +206,7 @@
 			           ) * loss_ctc + self.interctc_weight * loss_interctc
 		
 		# decoder: Attention decoder branch
-		loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
+		loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
 			encoder_out, encoder_out_lens, text, text_lengths
 		)
 		
@@ -428,7 +395,7 @@
 		audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
 		time2 = time.perf_counter()
 		meta_data["load_data"] = f"{time2 - time1:0.3f}"
-		speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"), frontend=self.frontend)
+		speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend)
 		time3 = time.perf_counter()
 		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
 		meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
diff --git a/funasr/models/transformer/positionwise_feed_forward.py b/funasr/models/transformer/positionwise_feed_forward.py
index ffa0f4e..7ca55cb 100644
--- a/funasr/models/transformer/positionwise_feed_forward.py
+++ b/funasr/models/transformer/positionwise_feed_forward.py
@@ -34,25 +34,3 @@
         return self.w_2(self.dropout(self.activation(self.w_1(x))))
 
 
-class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
-    """Positionwise feed forward layer.
-
-    Args:
-        idim (int): Input dimenstion.
-        hidden_units (int): The number of hidden units.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, idim, hidden_units, dropout_rate, adim=None, activation=torch.nn.ReLU()):
-        """Construct an PositionwiseFeedForward object."""
-        super(PositionwiseFeedForwardDecoderSANM, self).__init__()
-        self.w_1 = torch.nn.Linear(idim, hidden_units)
-        self.w_2 = torch.nn.Linear(hidden_units, idim if adim is None else adim, bias=False)
-        self.dropout = torch.nn.Dropout(dropout_rate)
-        self.activation = activation
-        self.norm = LayerNorm(hidden_units)
-
-    def forward(self, x):
-        """Forward function."""
-        return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x)))))
diff --git a/funasr/models/transformer/utils/nets_utils.py b/funasr/models/transformer/utils/nets_utils.py
index 0beb083..ce151a0 100644
--- a/funasr/models/transformer/utils/nets_utils.py
+++ b/funasr/models/transformer/utils/nets_utils.py
@@ -342,27 +342,6 @@
     return ret
 
 
-def th_accuracy(pad_outputs, pad_targets, ignore_label):
-    """Calculate accuracy.
-
-    Args:
-        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
-        pad_targets (LongTensor): Target label tensors (B, Lmax).
-        ignore_label (int): Ignore label id.
-
-    Returns:
-        float: Accuracy value (0.0 - 1.0).
-
-    """
-    pad_pred = pad_outputs.view(
-        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
-    ).argmax(2)
-    mask = pad_targets != ignore_label
-    numerator = torch.sum(
-        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
-    )
-    denominator = torch.sum(mask)
-    return float(numerator) / float(denominator)
 
 
 def to_torch_tensor(x):
diff --git a/funasr/models/uniasr/e2e_uni_asr.py b/funasr/models/uniasr/e2e_uni_asr.py
index 0fb4039..46c5832 100644
--- a/funasr/models/uniasr/e2e_uni_asr.py
+++ b/funasr/models/uniasr/e2e_uni_asr.py
@@ -10,15 +10,15 @@
 import torch
 
 from funasr.models.e2e_asr_common import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.losses.label_smoothing_loss import (
     LabelSmoothingLoss,  # noqa: H301
 )
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 from funasr.models.specaug.abs_specaug import AbsSpecAug
@@ -26,7 +26,7 @@
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.models.base_model import FunASRModel
 from funasr.models.scama.chunk_utilis import sequence_mask
-from funasr.models.predictor.cif import mae_loss
+from funasr.models.paraformer.cif_predictor import mae_loss
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
diff --git a/funasr/models/xvector/e2e_sv.py b/funasr/models/xvector/e2e_sv.py
index 3eac9ef..fce0dfe 100644
--- a/funasr/models/xvector/e2e_sv.py
+++ b/funasr/models/xvector/e2e_sv.py
@@ -20,13 +20,13 @@
 from funasr.models.ctc import CTC
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.frontends.abs_frontend import AbsFrontend
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.transformer.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.metrics import ErrorCalculator
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.models.base_model import FunASRModel
 
diff --git a/funasr/optimizers/__init__.py b/funasr/optimizers/__init__.py
index b4dfe5d..177f89e 100644
--- a/funasr/optimizers/__init__.py
+++ b/funasr/optimizers/__init__.py
@@ -2,7 +2,7 @@
 from funasr.optimizers.fairseq_adam import FairseqAdam
 from funasr.optimizers.sgd import SGD
 
-optim_choices = dict(
+optim_classes = dict(
 	adam=torch.optim.Adam,
 	fairseq_adam=FairseqAdam,
 	adamw=torch.optim.AdamW,
diff --git a/funasr/schedulers/__init__.py b/funasr/schedulers/__init__.py
index 7bb8118..2ee3a9e 100644
--- a/funasr/schedulers/__init__.py
+++ b/funasr/schedulers/__init__.py
@@ -7,7 +7,7 @@
 from funasr.schedulers.tri_stage_scheduler import TriStageLR
 from funasr.schedulers.warmup_lr import WarmupLR
 
-scheduler_choices = dict(
+scheduler_classes = dict(
 	ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
 	lambdalr=torch.optim.lr_scheduler.LambdaLR,
 	steplr=torch.optim.lr_scheduler.StepLR,
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index bbfd173..349ebc0 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -7,6 +7,7 @@
 from typing import Iterable
 from typing import List
 from typing import Union
+import json
 
 import numpy as np
 
@@ -27,7 +28,7 @@
                  ):
         
         if token_list is not None:
-            if isinstance(token_list, (Path, str)):
+            if isinstance(token_list, (Path, str)) and str(token_list).endswith(".txt"):
                 token_list = Path(token_list)
                 self.token_list_repr = str(token_list)
                 self.token_list: List[str] = []
@@ -36,7 +37,14 @@
                     for idx, line in enumerate(f):
                         line = line.rstrip()
                         self.token_list.append(line)
-            
+            elif isinstance(token_list, (Path, str)) and str(token_list).endswith(".json"):
+                token_list = Path(token_list)
+                self.token_list_repr = str(token_list)
+                self.token_list: List[str] = []
+
+                with open(token_list, 'r', encoding='utf-8') as f:
+                    self.token_list = json.loads(f.read())
+
             else:
                 self.token_list: List[str] = list(token_list)
                 self.token_list_repr = ""
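With the branch added above, a token list can be supplied either as a plain-text file (one token per line) or as a JSON file holding an array of tokens. A minimal sketch of the JSON case, with an illustrative file name:

import json
from pathlib import Path

token_list = Path("tokens.json")   # e.g. contains ["<blank>", "<s>", "</s>", "a", "b"]
with open(token_list, "r", encoding="utf-8") as f:
    tokens = json.loads(f.read())  # a list of token strings
assert isinstance(tokens, list)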
diff --git a/funasr/tokenizer/char_tokenizer.py b/funasr/tokenizer/char_tokenizer.py
index 80528a2..23ff743 100644
--- a/funasr/tokenizer/char_tokenizer.py
+++ b/funasr/tokenizer/char_tokenizer.py
@@ -4,10 +4,10 @@
 from typing import Union
 import warnings
 
-
-from funasr.tokenizer.abs_tokenizer import AbsTokenizer
 from funasr.tokenizer.abs_tokenizer import BaseTokenizer
+from funasr.utils.register import register_class
 
+@register_class("tokenizer_classes", "CharTokenizer")
 class CharTokenizer(BaseTokenizer):
     def __init__(
         self,
diff --git a/funasr/tokenizer/phoneme_tokenizer.py b/funasr/tokenizer/phoneme_tokenizer.py
index 04b423b..f1f7168 100644
--- a/funasr/tokenizer/phoneme_tokenizer.py
+++ b/funasr/tokenizer/phoneme_tokenizer.py
@@ -13,7 +13,7 @@
 from funasr.tokenizer.abs_tokenizer import AbsTokenizer
 
 
-g2p_choices = [
+g2p_classes = [
     None,
     "g2p_en",
     "g2p_en_no_space",
diff --git a/funasr/train_utils/collect_stats.py b/funasr/train_utils/collect_stats.py
deleted file mode 100644
index 7454ccf..0000000
--- a/funasr/train_utils/collect_stats.py
+++ /dev/null
@@ -1,124 +0,0 @@
-from collections import defaultdict
-import logging
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Optional
-from typing import Tuple
-
-import numpy as np
-import torch
-from torch.nn.parallel import data_parallel
-from torch.utils.data import DataLoader
-
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.fileio.npy_scp import NpyScpWriter
-from funasr.train_utils.device_funcs import to_device
-from funasr.train_utils.forward_adaptor import ForwardAdaptor
-from funasr.models.base_model import FunASRModel
-
-
-@torch.no_grad()
-def collect_stats(
-    model: FunASRModel,
-    train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
-    valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
-    output_dir: Path,
-    ngpu: Optional[int],
-    log_interval: Optional[int],
-    write_collected_feats: bool,
-) -> None:
-    """Perform on collect_stats mode.
-
-    Running for deriving the shape information from data
-    and gathering statistics.
-    This method is used before executing train().
-
-    """
-
-    npy_scp_writers = {}
-    for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
-        if log_interval is None:
-            try:
-                log_interval = max(len(itr) // 20, 10)
-            except TypeError:
-                log_interval = 100
-
-        sum_dict = defaultdict(lambda: 0)
-        sq_dict = defaultdict(lambda: 0)
-        count_dict = defaultdict(lambda: 0)
-
-        with DatadirWriter(output_dir / mode) as datadir_writer:
-            for iiter, (keys, batch) in enumerate(itr, 1):
-                batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-
-                # 1. Write shape file
-                for name in batch:
-                    if name.endswith("_lengths"):
-                        continue
-                    for i, (key, data) in enumerate(zip(keys, batch[name])):
-                        if f"{name}_lengths" in batch:
-                            lg = int(batch[f"{name}_lengths"][i])
-                            data = data[:lg]
-                        datadir_writer[f"{name}_shape"][key] = ",".join(
-                            map(str, data.shape)
-                        )
-
-                # 2. Extract feats
-                if ngpu <= 1:
-                    data = model.collect_feats(**batch)
-                else:
-                    # Note that data_parallel can parallelize only "forward()"
-                    data = data_parallel(
-                        ForwardAdaptor(model, "collect_feats"),
-                        (),
-                        range(ngpu),
-                        module_kwargs=batch,
-                    )
-
-                # 3. Calculate sum and square sum
-                for key, v in data.items():
-                    for i, (uttid, seq) in enumerate(zip(keys, v.cpu().numpy())):
-                        # Truncate zero-padding region
-                        if f"{key}_lengths" in data:
-                            length = data[f"{key}_lengths"][i]
-                            # seq: (Length, Dim, ...)
-                            seq = seq[:length]
-                        else:
-                            # seq: (Dim, ...) -> (1, Dim, ...)
-                            seq = seq[None]
-                        # Accumulate value, its square, and count
-                        sum_dict[key] += seq.sum(0)
-                        sq_dict[key] += (seq**2).sum(0)
-                        count_dict[key] += len(seq)
-
-                        # 4. [Option] Write derived features as npy format file.
-                        if write_collected_feats:
-                            # Instantiate NpyScpWriter for the first iteration
-                            if (key, mode) not in npy_scp_writers:
-                                p = output_dir / mode / "collect_feats"
-                                npy_scp_writers[(key, mode)] = NpyScpWriter(
-                                    p / f"data_{key}", p / f"{key}.scp"
-                                )
-                            # Save array as npy file
-                            npy_scp_writers[(key, mode)][uttid] = seq
-
-                if iiter % log_interval == 0:
-                    logging.info(f"Niter: {iiter}")
-
-        for key in sum_dict:
-            np.savez(
-                output_dir / mode / f"{key}_stats.npz",
-                count=count_dict[key],
-                sum=sum_dict[key],
-                sum_square=sq_dict[key],
-            )
-
-        # batch_keys and stats_keys are used by aggregate_stats_dirs.py
-        with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f:
-            f.write(
-                "\n".join(filter(lambda x: not x.endswith("_lengths"), batch)) + "\n"
-            )
-        with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f:
-            f.write("\n".join(sum_dict) + "\n")
diff --git a/funasr/train_utils/model_summary.py b/funasr/train_utils/model_summary.py
index 8d7f14f..1001160 100644
--- a/funasr/train_utils/model_summary.py
+++ b/funasr/train_utils/model_summary.py
@@ -1,4 +1,3 @@
-import humanfriendly
 import numpy as np
 import torch
 
@@ -59,12 +58,7 @@
     message += (
         f"    Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
     )
-    num_bytes = humanfriendly.format_size(
-        sum(
-            p.numel() * to_bytes(p.dtype) for p in model.parameters() if p.requires_grad
-        )
-    )
-    message += f"    Size: {num_bytes}\n"
+
     dtype = next(iter(model.parameters())).dtype
     message += f"    Type: {dtype}"
     return message
diff --git a/funasr/train_utils/pack_funcs.py b/funasr/train_utils/pack_funcs.py
deleted file mode 100644
index fe365d8..0000000
--- a/funasr/train_utils/pack_funcs.py
+++ /dev/null
@@ -1,302 +0,0 @@
-from datetime import datetime
-from io import BytesIO
-from io import TextIOWrapper
-import os
-from pathlib import Path
-import sys
-import tarfile
-from typing import Dict
-from typing import Iterable
-from typing import Optional
-from typing import Union
-import zipfile
-
-import yaml
-
-
-class Archiver:
-    def __init__(self, file, mode="r"):
-        if Path(file).suffix == ".tar":
-            self.type = "tar"
-        elif Path(file).suffix == ".tgz" or Path(file).suffixes == [".tar", ".gz"]:
-            self.type = "tar"
-            if mode == "w":
-                mode = "w:gz"
-        elif Path(file).suffix == ".tbz2" or Path(file).suffixes == [".tar", ".bz2"]:
-            self.type = "tar"
-            if mode == "w":
-                mode = "w:bz2"
-        elif Path(file).suffix == ".txz" or Path(file).suffixes == [".tar", ".xz"]:
-            self.type = "tar"
-            if mode == "w":
-                mode = "w:xz"
-        elif Path(file).suffix == ".zip":
-            self.type = "zip"
-        else:
-            raise ValueError(f"Cannot detect archive format: type={file}")
-
-        if self.type == "tar":
-            self.fopen = tarfile.open(file, mode=mode)
-        elif self.type == "zip":
-
-            self.fopen = zipfile.ZipFile(file, mode=mode)
-        else:
-            raise ValueError(f"Not supported: type={type}")
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.fopen.close()
-
-    def close(self):
-        self.fopen.close()
-
-    def __iter__(self):
-        if self.type == "tar":
-            return iter(self.fopen)
-        elif self.type == "zip":
-            return iter(self.fopen.infolist())
-        else:
-            raise ValueError(f"Not supported: type={self.type}")
-
-    def add(self, filename, arcname=None, recursive: bool = True):
-        if arcname is not None:
-            print(f"adding: {arcname}")
-        else:
-            print(f"adding: {filename}")
-
-        if recursive and Path(filename).is_dir():
-            for f in Path(filename).glob("**/*"):
-                if f.is_dir():
-                    continue
-
-                if arcname is not None:
-                    _arcname = Path(arcname) / f
-                else:
-                    _arcname = None
-
-                self.add(f, _arcname)
-            return
-
-        if self.type == "tar":
-            return self.fopen.add(filename, arcname)
-        elif self.type == "zip":
-            return self.fopen.write(filename, arcname)
-        else:
-            raise ValueError(f"Not supported: type={self.type}")
-
-    def addfile(self, info, fileobj):
-        print(f"adding: {self.get_name_from_info(info)}")
-
-        if self.type == "tar":
-            return self.fopen.addfile(info, fileobj)
-        elif self.type == "zip":
-            return self.fopen.writestr(info, fileobj.read())
-        else:
-            raise ValueError(f"Not supported: type={self.type}")
-
-    def generate_info(self, name, size) -> Union[tarfile.TarInfo, zipfile.ZipInfo]:
-        """Generate TarInfo using system information"""
-        if self.type == "tar":
-            tarinfo = tarfile.TarInfo(str(name))
-            if os.name == "posix":
-                tarinfo.gid = os.getgid()
-                tarinfo.uid = os.getuid()
-            tarinfo.mtime = datetime.now().timestamp()
-            tarinfo.size = size
-            # Keep mode as default
-            return tarinfo
-        elif self.type == "zip":
-            zipinfo = zipfile.ZipInfo(str(name), datetime.now().timetuple()[:6])
-            zipinfo.file_size = size
-            return zipinfo
-        else:
-            raise ValueError(f"Not supported: type={self.type}")
-
-    def get_name_from_info(self, info):
-        if self.type == "tar":
-            assert isinstance(info, tarfile.TarInfo), type(info)
-            return info.name
-        elif self.type == "zip":
-            assert isinstance(info, zipfile.ZipInfo), type(info)
-            return info.filename
-        else:
-            raise ValueError(f"Not supported: type={self.type}")
-
-    def extract(self, info, path=None):
-        if self.type == "tar":
-            return self.fopen.extract(info, path)
-        elif self.type == "zip":
-            return self.fopen.extract(info, path)
-        else:
-            raise ValueError(f"Not supported: type={self.type}")
-
-    def extractfile(self, info, mode="r"):
-        if self.type == "tar":
-            f = self.fopen.extractfile(info)
-            if mode == "r":
-                return TextIOWrapper(f)
-            else:
-                return f
-        elif self.type == "zip":
-            if mode == "rb":
-                mode = "r"
-            return self.fopen.open(info, mode)
-        else:
-            raise ValueError(f"Not supported: type={self.type}")
-
-
-def find_path_and_change_it_recursive(value, src: str, tgt: str):
-    if isinstance(value, dict):
-        return {
-            k: find_path_and_change_it_recursive(v, src, tgt) for k, v in value.items()
-        }
-    elif isinstance(value, (list, tuple)):
-        return [find_path_and_change_it_recursive(v, src, tgt) for v in value]
-    elif isinstance(value, str) and Path(value) == Path(src):
-        return tgt
-    else:
-        return value
-
-
-def get_dict_from_cache(meta: Union[Path, str]) -> Optional[Dict[str, str]]:
-    meta = Path(meta)
-    outpath = meta.parent.parent
-    if not meta.exists():
-        return None
-
-    with meta.open("r", encoding="utf-8") as f:
-        d = yaml.safe_load(f)
-        assert isinstance(d, dict), type(d)
-        yaml_files = d["yaml_files"]
-        files = d["files"]
-        assert isinstance(yaml_files, dict), type(yaml_files)
-        assert isinstance(files, dict), type(files)
-
-        retval = {}
-        for key, value in list(yaml_files.items()) + list(files.items()):
-            if not (outpath / value).exists():
-                return None
-            retval[key] = str(outpath / value)
-        return retval
-
-
-def unpack(
-    input_archive: Union[Path, str],
-    outpath: Union[Path, str],
-    use_cache: bool = True,
-) -> Dict[str, str]:
-    """Scan all files in the archive file and return as a dict of files.
-
-    Examples:
-        tarfile:
-           model.pb
-           some1.file
-           some2.file
-
-        >>> unpack("tarfile", "out")
-        {'asr_model_file': 'out/model.pb'}
-    """
-    input_archive = Path(input_archive)
-    outpath = Path(outpath)
-
-    with Archiver(input_archive) as archive:
-        for info in archive:
-            if Path(archive.get_name_from_info(info)).name == "meta.yaml":
-                if (
-                    use_cache
-                    and (outpath / Path(archive.get_name_from_info(info))).exists()
-                ):
-                    retval = get_dict_from_cache(
-                        outpath / Path(archive.get_name_from_info(info))
-                    )
-                    if retval is not None:
-                        return retval
-                d = yaml.safe_load(archive.extractfile(info))
-                assert isinstance(d, dict), type(d)
-                yaml_files = d["yaml_files"]
-                files = d["files"]
-                assert isinstance(yaml_files, dict), type(yaml_files)
-                assert isinstance(files, dict), type(files)
-                break
-        else:
-            raise RuntimeError("Format error: not found meta.yaml")
-
-        for info in archive:
-            fname = archive.get_name_from_info(info)
-            outname = outpath / fname
-            outname.parent.mkdir(parents=True, exist_ok=True)
-            if fname in set(yaml_files.values()):
-                d = yaml.safe_load(archive.extractfile(info))
-                # Rewrite yaml
-                for info2 in archive:
-                    name = archive.get_name_from_info(info2)
-                    d = find_path_and_change_it_recursive(d, name, str(outpath / name))
-                with outname.open("w", encoding="utf-8") as f:
-                    yaml.safe_dump(d, f)
-            else:
-                archive.extract(info, path=outpath)
-
-        retval = {}
-        for key, value in list(yaml_files.items()) + list(files.items()):
-            retval[key] = str(outpath / value)
-        return retval
-
-
-def _to_relative_or_resolve(f):
-    # Resolve to avoid symbolic link
-    p = Path(f).resolve()
-    try:
-        # Change to relative if it can
-        p = p.relative_to(Path(".").resolve())
-    except ValueError:
-        pass
-    return str(p)
-
-
-def pack(
-    files: Dict[str, Union[str, Path]],
-    yaml_files: Dict[str, Union[str, Path]],
-    outpath: Union[str, Path],
-    option: Iterable[Union[str, Path]] = (),
-):
-    for v in list(files.values()) + list(yaml_files.values()) + list(option):
-        if not Path(v).exists():
-            raise FileNotFoundError(f"No such file or directory: {v}")
-
-    files = {k: _to_relative_or_resolve(v) for k, v in files.items()}
-    yaml_files = {k: _to_relative_or_resolve(v) for k, v in yaml_files.items()}
-    option = [_to_relative_or_resolve(v) for v in option]
-
-    meta_objs = dict(
-        files=files,
-        yaml_files=yaml_files,
-        timestamp=datetime.now().timestamp(),
-        python=sys.version,
-    )
-
-    try:
-        import torch
-
-        meta_objs.update(torch=str(torch.__version__))
-    except ImportError:
-        pass
-    try:
-        import espnet
-
-        meta_objs.update(espnet=espnet.__version__)
-    except ImportError:
-        pass
-
-    Path(outpath).parent.mkdir(parents=True, exist_ok=True)
-    with Archiver(outpath, mode="w") as archive:
-        # Write packed/meta.yaml
-        fileobj = BytesIO(yaml.safe_dump(meta_objs).encode())
-        info = archive.generate_info("meta.yaml", fileobj.getbuffer().nbytes)
-        archive.addfile(info, fileobj=fileobj)
-
-        for f in list(yaml_files.values()) + list(files.values()) + list(option):
-            archive.add(f)
-
-    print(f"Generate: {outpath}")
diff --git a/funasr/train_utils/pytorch_version.py b/funasr/train_utils/pytorch_version.py
deleted file mode 100644
index 01f17cc..0000000
--- a/funasr/train_utils/pytorch_version.py
+++ /dev/null
@@ -1,16 +0,0 @@
-import torch
-
-
-def pytorch_cudnn_version() -> str:
-    message = (
-        f"pytorch.version={torch.__version__}, "
-        f"cuda.available={torch.cuda.is_available()}, "
-    )
-
-    if torch.backends.cudnn.enabled:
-        message += (
-            f"cudnn.version={torch.backends.cudnn.version()}, "
-            f"cudnn.benchmark={torch.backends.cudnn.benchmark}, "
-            f"cudnn.deterministic={torch.backends.cudnn.deterministic}"
-        )
-    return message
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 425b79f..ea502f7 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -55,7 +55,7 @@
 		self.dataloader_val = dataloader_val
 		self.output_dir = kwargs.get('output_dir', './')
 		self.resume = kwargs.get('resume', True)
-		self.start_epoch = 1
+		self.start_epoch = 0
 		self.max_epoch = kwargs.get('max_epoch', 100)
 		self.local_rank = local_rank
 		self.use_ddp = use_ddp
@@ -123,7 +123,7 @@
 		for epoch in range(self.start_epoch, self.max_epoch + 1):
 			self._train_epoch(epoch)
 			# self._validate_epoch(epoch)
-			if dist.get_rank() == 0:
+			if self.rank == 0:
 				self._save_checkpoint(epoch)
 			self.scheduler.step()
 			break
@@ -201,21 +201,22 @@
 				speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
 	
 				speed_stats["total_time"] = total_time
-				
 			
 			pbar.update(1)
 			if self.local_rank == 0:
 				description = (
 					f"Epoch: {epoch + 1}/{self.max_epoch}, "
 					f"step {batch_idx}/{len(self.dataloader_train)}, "
 					f"{speed_stats}, "
-					f"(loss: {loss.detach().float():.3f}), "
+					f"(loss: {loss.detach().cpu().item():.3f}), "
 					f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
 				)
 				pbar.set_description(description)
 			
-			if batch_idx == 2:
-				break
+			# if batch_idx == 2:
+			# 	break
 		pbar.close()
 
 	def _validate_epoch(self, epoch):
diff --git a/funasr/utils/class_choices.py b/funasr/utils/class_choices.py
deleted file mode 100644
index 1ffb97a..0000000
--- a/funasr/utils/class_choices.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from typing import Mapping
-from typing import Optional
-from typing import Tuple
-
-
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import str_or_none
-
-
-class ClassChoices:
-    """Helper class to manage the options for variable objects and its configuration.
-
-    Example:
-
-    >>> class A:
-    ...     def __init__(self, foo=3):  pass
-    >>> class B:
-    ...     def __init__(self, bar="aaaa"):  pass
-    >>> choices = ClassChoices("var", dict(a=A, b=B), default="a")
-    >>> import argparse
-    >>> parser = argparse.ArgumentParser()
-    >>> choices.add_arguments(parser)
-    >>> args = parser.parse_args(["--var", "a", "--var_conf", "foo=4")
-    >>> args.var
-    a
-    >>> args.var_conf
-    {"foo": 4}
-    >>> class_obj = choices.get_class(args.var)
-    >>> a_object = class_obj(**args.var_conf)
-
-    """
-
-    def __init__(
-        self,
-        name: str,
-        classes: Mapping[str, type],
-        type_check: type = None,
-        default: str = None,
-        optional: bool = False,
-    ):
-        self.name = name
-        self.base_type = type_check
-        self.classes = {k.lower(): v for k, v in classes.items()}
-        if "none" in self.classes or "nil" in self.classes or "null" in self.classes:
-            raise ValueError('"none", "nil", and "null" are reserved.')
-        if type_check is not None:
-            for v in self.classes.values():
-                if not issubclass(v, type_check):
-                    raise ValueError(f"must be {type_check.__name__}, but got {v}")
-
-        self.optional = optional
-        self.default = default
-        if default is None:
-            self.optional = True
-
-    def choices(self) -> Tuple[Optional[str], ...]:
-        retval = tuple(self.classes)
-        if self.optional:
-            return retval + (None,)
-        else:
-            return retval
-
-    def get_class(self, name: Optional[str]) -> Optional[type]:
-        if name is None or (self.optional and name.lower() == ("none", "null", "nil")):
-            retval = None
-        elif name.lower() in self.classes:
-            class_obj = self.classes[name]
-            retval = class_obj
-        else:
-            raise ValueError(
-                f"--{self.name} must be one of {self.choices()}: "
-                f"--{self.name} {name.lower()}"
-            )
-
-        return retval
-
-    def add_arguments(self, parser):
-        parser.add_argument(
-            f"--{self.name}",
-            type=lambda x: str_or_none(x.lower()),
-            default=self.default,
-            choices=self.choices(),
-            help=f"The {self.name} type",
-        )
-        parser.add_argument(
-            f"--{self.name}_conf",
-            action=NestedDictAction,
-            default=dict(),
-            help=f"The keyword arguments for {self.name}",
-        )
diff --git a/funasr/utils/cli_utils.py b/funasr/utils/cli_utils.py
deleted file mode 100644
index c4a4cd1..0000000
--- a/funasr/utils/cli_utils.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from collections.abc import Sequence
-from distutils.util import strtobool as dist_strtobool
-import sys
-
-import numpy
-
-
-def strtobool(x):
-    # distutils.util.strtobool returns integer, but it's confusing,
-    return bool(dist_strtobool(x))
-
-
-def get_commandline_args():
-    extra_chars = [
-        " ",
-        ";",
-        "&",
-        "(",
-        ")",
-        "|",
-        "^",
-        "<",
-        ">",
-        "?",
-        "*",
-        "[",
-        "]",
-        "$",
-        "`",
-        '"',
-        "\\",
-        "!",
-        "{",
-        "}",
-    ]
-
-    # Escape the extra characters for shell
-    argv = [
-        arg.replace("'", "'\\''")
-        if all(char not in arg for char in extra_chars)
-        else "'" + arg.replace("'", "'\\''") + "'"
-        for arg in sys.argv
-    ]
-
-    return sys.executable + " " + " ".join(argv)
-
-
-def is_scipy_wav_style(value):
-    # If Tuple[int, numpy.ndarray] or not
-    return (
-        isinstance(value, Sequence)
-        and len(value) == 2
-        and isinstance(value[0], int)
-        and isinstance(value[1], numpy.ndarray)
-    )
-
-
-def assert_scipy_wav_style(value):
-    assert is_scipy_wav_style(
-        value
-    ), "Must be Tuple[int, numpy.ndarray], but got {}".format(
-        type(value)
-        if not isinstance(value, Sequence)
-        else "{}[{}]".format(type(value), ", ".join(str(type(v)) for v in value))
-    )
diff --git a/funasr/utils/dynamic_import.py b/funasr/utils/dynamic_import.py
deleted file mode 100644
index 2830cb2..0000000
--- a/funasr/utils/dynamic_import.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import importlib
-
-
-def dynamic_import(import_path):
-    """dynamic import module and class
-
-    :param str import_path: syntax 'module_name:class_name'
-    :return: imported class
-    """
-
-    module_name, objname = import_path.split(":")
-    m = importlib.import_module(module_name)
-    return getattr(m, objname)
diff --git a/funasr/utils/load_fr_tf.py b/funasr/utils/load_fr_tf.py
deleted file mode 100644
index 5c8c275..0000000
--- a/funasr/utils/load_fr_tf.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import numpy as np
-np.set_printoptions(threshold=np.inf)
-import logging
-
-def load_ckpt(checkpoint_path):
-	import tensorflow as tf
-	if tf.__version__.startswith('2'):
-		import tensorflow.compat.v1 as tf
-		tf.disable_v2_behavior()
-		reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)
-	else:
-		from tensorflow.python import pywrap_tensorflow
-		reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
-	var_to_shape_map = reader.get_variable_to_shape_map()
-
-	var_dict = dict()
-	for var_name in sorted(var_to_shape_map):
-		if "Adam" in var_name:
-			continue
-		tensor = reader.get_tensor(var_name)
-		# print("in ckpt: {}, {}".format(var_name, tensor.shape))
-		# print(tensor)
-		var_dict[var_name] = tensor
-
-	return var_dict
-
-
-
-def load_tf_pb_dict(pb_model):
-	import tensorflow as tf
-	if tf.__version__.startswith('2'):
-		import tensorflow.compat.v1 as tf
-		tf.disable_v2_behavior()
-		# import tensorflow_addons as tfa
-		# from tensorflow_addons.seq2seq.python.ops import beam_search_ops
-	else:
-		from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
-	from tensorflow.python.ops import lookup_ops as lookup
-	from tensorflow.python.framework import tensor_util
-	from tensorflow.python.platform import gfile
-	
-	sess = tf.Session()
-	with gfile.FastGFile(pb_model, 'rb') as f:
-		graph_def = tf.GraphDef()
-		graph_def.ParseFromString(f.read())
-		sess.graph.as_default()
-		tf.import_graph_def(graph_def, name='')
-	
-	var_dict = dict()
-	for node in sess.graph_def.node:
-		if node.op == 'Const':
-			value = tensor_util.MakeNdarray(node.attr['value'].tensor)
-			if len(value.shape) >= 1:
-				var_dict[node.name] = value
-	return var_dict
-
-def load_tf_dict(pb_model):
-	if "model.ckpt-" in pb_model:
-		var_dict = load_ckpt(pb_model)
-	else:
-		var_dict = load_tf_pb_dict(pb_model)
-	return var_dict
diff --git a/funasr/utils/register.py b/funasr/utils/register.py
new file mode 100644
index 0000000..0dfcdab
--- /dev/null
+++ b/funasr/utils/register.py
@@ -0,0 +1,72 @@
+import logging
+import inspect
+from dataclasses import dataclass, fields
+
+
+@dataclass
+class ClassRegistryTables:
+    model_classes = {}
+    frontend_classes = {}
+    specaug_classes = {}
+    normalize_classes = {}
+    encoder_classes = {}
+    decoder_classes = {}
+    joint_network_classes = {}
+    predictor_classes = {}
+    stride_conv_classes = {}
+    tokenizer_classes = {}
+    batch_sampler_classes = {}
+    dataset_classes = {}
+    index_ds_classes = {}
+
+    def print_register_tables(self):
+        print("\nregister_tables: \n")
+        table_dicts = vars(self)
+        for classes_key, classes_dict in table_dicts.items():
+            print(f"-----------    ** {classes_key.replace('_meta', '')} **    --------------")
+        
+            if classes_key.endswith("_meta"):
+                headers = ["class name", "register name", "class location"]
+                metas = []
+                for register_key, meta in classes_dict.items():
+                    metas.append(meta)
+                metas.sort(key=lambda x: x[0])
+                data = [headers] + metas
+                col_widths = [max(len(str(item)) for item in col) for col in zip(*data)]
+            
+                for row in data:
+                    print("| " + " | ".join(str(item).ljust(width) for item, width in zip(row, col_widths)) + " |")
+        print("\n")
+
+registry_tables = ClassRegistryTables()
+
+def register_class(registry_tables_key:str, key=None):
+    def decorator(target_class):
+        
+        if not hasattr(registry_tables, registry_tables_key):
+            setattr(registry_tables, registry_tables_key, {})
+            logging.info("new registry table has been added: {}".format(registry_tables_key))
+
+        registry = getattr(registry_tables, registry_tables_key)
+        registry_key = key if key is not None else target_class.__name__
+        registry_key = registry_key.lower()
+        assert registry_key not in registry, "(key: {} / class: {}) has been registered already, in {}".format(registry_key, target_class, registry_tables_key)
+
+        registry[registry_key] = target_class
+        
+        # meta: headers = ["class name", "register name", "class location"]
+        registry_tables_key_meta = registry_tables_key + "_meta"
+        if not hasattr(registry_tables, registry_tables_key_meta):
+            setattr(registry_tables, registry_tables_key_meta, {})
+        registry_meta = getattr(registry_tables, registry_tables_key_meta)
+        class_file = inspect.getfile(target_class)
+        class_line = inspect.getsourcelines(target_class)[1]
+        meta_data = [f"{target_class.__name__}", f"{registry_key}", f"{class_file}:{class_line}"]
+        registry_meta[registry_key] = meta_data
+        # print(f"Registering class: {class_file}:{class_line} - {target_class.__name__} as {registry_key}")
+        return target_class
+    return decorator
+
+import funasr
+
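A minimal sketch of how the new registry is meant to be used; MyEncoder and its registration key are hypothetical, not part of this patch:

from funasr.utils.register import register_class, registry_tables

@register_class("encoder_classes", "MyEncoder")
class MyEncoder:
    def __init__(self, input_size: int = 80):
        self.input_size = input_size

# keys are lower-cased at registration time
encoder_cls = registry_tables.encoder_classes["myencoder"]
encoder = encoder_cls(input_size=80)
registry_tables.print_register_tables()   # prints the *_meta tables for inspection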
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index c463f0c..8186dff 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -269,7 +269,7 @@
 
 
 def convert_external_alphas(alphas_file, text_file, output_file):
-    from funasr.models.predictor.cif import cif_wo_hidden
+    from funasr.models.paraformer.cif_predictor import cif_wo_hidden
     with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
         for line1, line2 in zip(f1.readlines(), f2.readlines()):
             line1 = line1.rstrip()
diff --git a/funasr/utils/yaml_no_alias_safe_dump.py b/funasr/utils/yaml_no_alias_safe_dump.py
deleted file mode 100644
index 70a7b0e..0000000
--- a/funasr/utils/yaml_no_alias_safe_dump.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import yaml
-
-
-class NoAliasSafeDumper(yaml.SafeDumper):
-    # Disable anchor/alias in yaml because looks ugly
-    def ignore_aliases(self, data):
-        return True
-
-
-def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
-    """Safe-dump in yaml with no anchor/alias"""
-    return yaml.dump(
-        data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
-    )

--
Gitblit v1.9.1