From 806a03609df033d61f824f1ab8527eb88fe837ad Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 13 Dec 2023 19:43:13 +0800
Subject: [PATCH] funasr2 paraformer biciparaformer contextuaparaformer

---
 funasr/models/paraformer/model.py                              | 1760 +++++++++++++++++++++++++++++++++++++++
 funasr/models/paraformer/search.py                             |  453 ++++++++++
 .gitignore                                                     |    1 
 examples/industrial_data_pretraining/paraformer-large/infer.sh |   15 
 funasr/models/frontend/wav_frontend.py                         |    2 
 examples/industrial_data_pretraining/paraformer-large/run.sh   |    2 
 funasr/bin/inference.py                                        |  170 +++
 funasr/datasets/fun_datasets/load_audio_extract_fbank.py       |   75 +
 /dev/null                                                      |   81 -
 funasr/tokenizer/abs_tokenizer.py                              |    4 
 funasr/datasets/fun_datasets/__init__.py                       |    0 
 funasr/utils/download_from_hub.py                              |   24 
 funasr/cli/train_cli.py                                        |    8 
 funasr/datasets/dataset_jsonl.py                               |   33 
 funasr/models/paraformer/__init__.py                           |    0 
 funasr/__init__.py                                             |    2 
 funasr/bin/asr_inference_launch.py                             |   31 
 17 files changed, 2,501 insertions(+), 160 deletions(-)

diff --git a/.gitignore b/.gitignore
index 37f39fe..dea4634 100644
--- a/.gitignore
+++ b/.gitignore
@@ -21,3 +21,4 @@
 modelscope
 samples
 .ipynb_checkpoints
+outputs*
diff --git a/examples/industrial_data_pretraining/paraformer-large/infer.sh b/examples/industrial_data_pretraining/paraformer-large/infer.sh
new file mode 100644
index 0000000..b7fbe75
--- /dev/null
+++ b/examples/industrial_data_pretraining/paraformer-large/infer.sh
@@ -0,0 +1,15 @@
+
+cmd="funasr/bin/inference.py"
+
+python $cmd \
++model="/Users/zhifu/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
++input="/Users/zhifu/Downloads/asr_example.wav" \
++output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
++device="cpu" \
++"hotword='达摩院 魔搭'"
+
+#+input="/Users/zhifu/funasr_github/test_local/asr_example.wav" \
+#+input="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
+#+model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+
+#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/paraformer-large/run.sh b/examples/industrial_data_pretraining/paraformer-large/run.sh
index 9b40b81..8571974 100644
--- a/examples/industrial_data_pretraining/paraformer-large/run.sh
+++ b/examples/industrial_data_pretraining/paraformer-large/run.sh
@@ -2,7 +2,7 @@
 cmd="funasr/cli/train_cli.py"
 
 python $cmd \
-+model_pretrain="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
++model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
 +token_list="/Users/zhifu/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/tokens.txt" \
 +train_data_set_list="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
 +output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
diff --git a/funasr/__init__.py b/funasr/__init__.py
index d0b7aa5..ac1591d 100644
--- a/funasr/__init__.py
+++ b/funasr/__init__.py
@@ -7,4 +7,4 @@
 with open(version_file, "r") as f:
     __version__ = f.read().strip()
 
-from funasr.bin.inference_cli import infer
\ No newline at end of file
+from funasr.bin.inference import infer
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index f34bfb2..6151d28 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1254,37 +1254,6 @@
 
         return cache
 
-    #def _prepare_cache(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
-    #    if len(cache) > 0:
-    #        return cache
-    #    config = _read_yaml(asr_train_config)
-    #    enc_output_size = config["encoder_conf"]["output_size"]
-    #    feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
-    #    cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
-    #                "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
-    #                "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)), "tail_chunk": False}
-    #    cache["encoder"] = cache_en
-
-    #    cache_de = {"decode_fsmn": None}
-    #    cache["decoder"] = cache_de
-
-    #    return cache
-
-    #def _cache_reset(cache: dict = {}, chunk_size=[5, 10, 5], batch_size=1):
-    #    if len(cache) > 0:
-    #        config = _read_yaml(asr_train_config)
-    #        enc_output_size = config["encoder_conf"]["output_size"]
-    #        feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
-    #        cache_en = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
-    #                    "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size, "last_chunk": False,
-    #                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
-    #                    "tail_chunk": False}
-    #        cache["encoder"] = cache_en
-
-    #        cache_de = {"decode_fsmn": None}
-    #        cache["decoder"] = cache_de
-
-    #    return cache
 
     def _forward(
             data_path_and_name_and_type,
diff --git a/funasr/bin/asr_train.py b/funasr/bin/asr_train.py
deleted file mode 100755
index 8161e7b..0000000
--- a/funasr/bin/asr_train.py
+++ /dev/null
@@ -1,67 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import os
-
-from funasr.tasks.asr import ASRTask
-
-
-# for ASR Training
-def parse_args():
-    parser = ASRTask.get_parser()
-    parser.add_argument(
-        "--mode",
-        type=str,
-        default="asr",
-        help=" ",
-    )
-    parser.add_argument(
-        "--gpu_id",
-        type=int,
-        default=0,
-        help="local gpu id.",
-    )
-    args = parser.parse_args()
-    return args
-
-
-def main(args=None, cmd=None):
-    
-    # for ASR Training
-    if args.mode == "asr":
-        from funasr.tasks.asr import ASRTask
-    if args.mode == "paraformer":
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-    if args.mode == "uniasr":
-        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
-    if args.mode == "rnnt":
-        from funasr.tasks.asr import ASRTransducerTask as ASRTask    
-
-    ASRTask.main(args=args, cmd=cmd)
-
-
-if __name__ == '__main__':
-    args = parse_args()
-
-    # setup local gpu_id
-    if args.ngpu > 0:
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-
-    # DDP settings
-    if args.ngpu > 1:
-        args.distributed = True
-    else:
-        args.distributed = False
-    assert args.num_worker_count == 1
-
-    # re-compute batch size: when dataset type is small
-    if args.dataset_type == "small":
-        if args.batch_size is not None and args.ngpu > 0:
-            args.batch_size = args.batch_size * args.ngpu
-        if args.batch_bins is not None and args.ngpu > 0:
-            args.batch_bins = args.batch_bins * args.ngpu
-
-    main(args=args)
-
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
new file mode 100644
index 0000000..4e140e1
--- /dev/null
+++ b/funasr/bin/inference.py
@@ -0,0 +1,168 @@
+import os.path
+
+import torch
+import numpy as np
+import hydra
+import json
+from omegaconf import DictConfig, OmegaConf
+from funasr.utils.dynamic_import import dynamic_import
+import logging
+from funasr.utils.download_from_hub import download_model
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.tokenizer.funtoken import build_tokenizer
+from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_bytes
+from funasr.torch_utils.device_funcs import to_device
+from tqdm import tqdm
+from funasr.torch_utils.load_pretrained_model import load_pretrained_model
+import time
+import random
+import string
+
+@hydra.main(config_name=None, version_base=None)
+def main_hydra(kwargs: DictConfig):
+	assert "model" in kwargs
+
+	pipeline = infer(**kwargs)
+	res = pipeline(input=kwargs["input"])
+	print(res)
+	
+def infer(**kwargs):
+	
+	if ":" not in kwargs["model"]:
+		logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+		kwargs = download_model(**kwargs)
+	
+	set_all_random_seed(kwargs.get("seed", 0))
+
+	
+	device = kwargs.get("device", "cuda")
+	batch_size = kwargs.get("batch_size", 1)
+	if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
+		device, batch_size = "cpu", 1
+	kwargs["device"] = device
+	
+	# build_tokenizer
+	tokenizer = build_tokenizer(
+		token_type=kwargs.get("token_type", "char"),
+		bpemodel=kwargs.get("bpemodel", None),
+		delimiter=kwargs.get("delimiter", None),
+		space_symbol=kwargs.get("space_symbol", "<space>"),
+		non_linguistic_symbols=kwargs.get("non_linguistic_symbols", None),
+		g2p_type=kwargs.get("g2p_type", None),
+		token_list=kwargs.get("token_list", None),
+		unk_symbol=kwargs.get("unk_symbol", "<unk>"),
+	)
+
+	# build model
+	model_class = dynamic_import(kwargs.get("model"))
+	model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list))
+	model.eval()
+	model.to(device)
+	frontend = model.frontend
+	kwargs["token_list"] = tokenizer.token_list
+	
+	
+	# init_param
+	init_param = kwargs.get("init_param", None)
+	if init_param is not None:
+		logging.info(f"Loading pretrained params from {init_param}")
+		load_pretrained_model(
+			model=model,
+			init_param=init_param,
+			ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
+			oss_bucket=kwargs.get("oss_bucket", None),
+		)
+	
+	def _forward(input, input_len=None, **cfg):
+		cfg = OmegaConf.merge(kwargs, cfg)
+		date_type = cfg.get("date_type", "sound")
+		
+		key_list, data_list = build_iter_for_infer(input, input_len=input_len, date_type=date_type, frontend=frontend)
+		
+		speed_stats = {}
+		asr_result_list = []
+		num_samples = len(data_list)
+		pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+		for beg_idx in range(0, num_samples, batch_size):
+
+			end_idx = min(num_samples, beg_idx + batch_size)
+			data_batch = data_list[beg_idx:end_idx]
+			key_batch = key_list[beg_idx:end_idx]
+			batch = {"data_in": data_batch, "key": key_batch}
+			
+			time1 = time.perf_counter()
+			results, meta_data = model.generate(**batch, tokenizer=tokenizer, **cfg)
+			time2 = time.perf_counter()
+			
+			asr_result_list.append(results)
+			pbar.update(1)
+			
+			# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
+			batch_data_time = meta_data.get("batch_data_time", -1)
+			speed_stats["load_data"] = meta_data["load_data"]
+			speed_stats["extract_feat"] = meta_data["extract_feat"]
+			speed_stats["forward"] = f"{time2 - time1:0.3f}"
+			speed_stats["rtf"] = f"{(time2 - time1)/batch_data_time:0.3f}"
+			description = (
+				f"{speed_stats}, "
+			)
+			pbar.set_description(description)
+		
+		torch.cuda.empty_cache()
+		return asr_result_list
+	
+	return _forward
+	
+
+def build_iter_for_infer(data_in, input_len=None, date_type="sound", frontend=None):
+	"""
+	
+	:param input:
+	:param input_len:
+	:param date_type:
+	:param frontend:
+	:return:
+	"""
+	data_list = []
+	key_list = []
+	filelist = [".scp", ".txt", ".json", ".jsonl"]
+	
+	chars = string.ascii_letters + string.digits
+	
+	if isinstance(data_in, str) and os.path.exists(data_in): # wav_pat; filelist: wav.scp, file.jsonl;text.txt;
+		_, file_extension = os.path.splitext(data_in)
+		file_extension = file_extension.lower()
+		if file_extension in filelist: #filelist: wav.scp, file.jsonl;text.txt;
+			with open(data_in, encoding='utf-8') as fin:
+				for line in fin:
+					key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
+					if data_in.endswith(".jsonl"): #file.jsonl: json.dumps({"source": data})
+						lines = json.loads(line.strip())
+						data = lines["source"]
+						key = data["key"] if "key" in data else key
+					else: # filelist, wav.scp, text.txt: id \t data or data
+						lines = line.strip().split()
+						data = lines[1] if len(lines)>1 else lines[0]
+						key = lines[0] if len(lines)>1 else key
+					
+					data_list.append(data)
+					key_list.append(key)
+		else:
+			key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
+			data_list = [data_in]
+			key_list = [key]
+	elif isinstance(data_in, (list, tuple)): # [audio sample point, fbank, wav_path]
+		data_list = data_in
+		key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
+	else: # raw text; audio sample point, fbank
+		if isinstance(data_in, bytes): # audio bytes
+			data_in = load_bytes(data_in)
+		key = "rand_key_" + ''.join(random.choice(chars) for _ in range(13))
+		data_list = [data_in]
+		key_list = [key]
+	
+	return key_list, data_list
+
+
+if __name__ == '__main__':
+	main_hydra()
\ No newline at end of file
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
deleted file mode 100755
index 6aebf8a..0000000
--- a/funasr/bin/train.py
+++ /dev/null
@@ -1,595 +0,0 @@
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import argparse
-import logging
-import os
-import sys
-from io import BytesIO
-
-import torch
-
-from funasr.build_utils.build_args import build_args
-from funasr.build_utils.build_dataloader import build_dataloader
-from funasr.build_utils.build_distributed import build_distributed
-from funasr.build_utils.build_model import build_model
-from funasr.build_utils.build_optimizer import build_optimizer
-from funasr.build_utils.build_scheduler import build_scheduler
-from funasr.build_utils.build_trainer import build_trainer
-from funasr.tokenizer.phoneme_tokenizer import g2p_choices
-from funasr.torch_utils.load_pretrained_model import load_pretrained_model
-from funasr.torch_utils.model_summary import model_summary
-from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.prepare_data import prepare_data
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
-from funasr.modules.lora.utils import mark_only_lora_as_trainable
-
-
-def get_parser():
-    parser = argparse.ArgumentParser(
-        description="FunASR Common Training Parser",
-    )
-
-    # common configuration
-    parser.add_argument("--output_dir", help="model save path")
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
-
-    # ddp related
-    parser.add_argument(
-        "--dist_backend",
-        default="nccl",
-        type=str,
-        help="distributed backend",
-    )
-    parser.add_argument(
-        "--dist_init_method",
-        type=str,
-        default="env://",
-        help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
-             '"WORLD_SIZE", and "RANK" are referred.',
-    )
-    parser.add_argument(
-        "--dist_world_size",
-        type=int,
-        default=1,
-        help="number of nodes for distributed training",
-    )
-    parser.add_argument(
-        "--dist_rank",
-        type=int,
-        default=None,
-        help="node rank for distributed training",
-    )
-    parser.add_argument(
-        "--local_rank",
-        type=int,
-        default=None,
-        help="local rank for distributed training",
-    )
-    parser.add_argument(
-        "--dist_master_addr",
-        default=None,
-        type=str_or_none,
-        help="The master address for distributed training. "
-             "This value is used when dist_init_method == 'env://'",
-    )
-    parser.add_argument(
-        "--dist_master_port",
-        default=None,
-        type=int_or_none,
-        help="The master port for distributed training"
-             "This value is used when dist_init_method == 'env://'",
-    )
-    parser.add_argument(
-        "--dist_launcher",
-        default=None,
-        type=str_or_none,
-        choices=["slurm", "mpi", None],
-        help="The launcher type for distributed training",
-    )
-    parser.add_argument(
-        "--multiprocessing_distributed",
-        default=True,
-        type=str2bool,
-        help="Use multi-processing distributed training to launch "
-             "N processes per node, which has N GPUs. This is the "
-             "fastest way to use PyTorch for either single node or "
-             "multi node data parallel training",
-    )
-    parser.add_argument(
-        "--unused_parameters",
-        type=str2bool,
-        default=False,
-        help="Whether to use the find_unused_parameters in "
-             "torch.nn.parallel.DistributedDataParallel ",
-    )
-    parser.add_argument(
-        "--gpu_id",
-        type=int,
-        default=0,
-        help="local gpu id.",
-    )
-
-    # cudnn related
-    parser.add_argument(
-        "--cudnn_enabled",
-        type=str2bool,
-        default=torch.backends.cudnn.enabled,
-        help="Enable CUDNN",
-    )
-    parser.add_argument(
-        "--cudnn_benchmark",
-        type=str2bool,
-        default=torch.backends.cudnn.benchmark,
-        help="Enable cudnn-benchmark mode",
-    )
-    parser.add_argument(
-        "--cudnn_deterministic",
-        type=str2bool,
-        default=True,
-        help="Enable cudnn-deterministic mode",
-    )
-
-    # trainer related
-    parser.add_argument(
-        "--max_epoch",
-        type=int,
-        default=40,
-        help="The maximum number epoch to train",
-    )
-    parser.add_argument(
-        "--max_update",
-        type=int,
-        default=sys.maxsize,
-        help="The maximum number update step to train",
-    )
-    parser.add_argument(
-        "--batch_interval",
-        type=int,
-        default=10000,
-        help="The batch interval for saving model.",
-    )
-    parser.add_argument(
-        "--patience",
-        type=int_or_none,
-        default=None,
-        help="Number of epochs to wait without improvement "
-             "before stopping the training",
-    )
-    parser.add_argument(
-        "--val_scheduler_criterion",
-        type=str,
-        nargs=2,
-        default=("valid", "loss"),
-        help="The criterion used for the value given to the lr scheduler. "
-             'Give a pair referring the phase, "train" or "valid",'
-             'and the criterion name. The mode specifying "min" or "max" can '
-             "be changed by --scheduler_conf",
-    )
-    parser.add_argument(
-        "--early_stopping_criterion",
-        type=str,
-        nargs=3,
-        default=("valid", "loss", "min"),
-        help="The criterion used for judging of early stopping. "
-             'Give a pair referring the phase, "train" or "valid",'
-             'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
-    )
-    parser.add_argument(
-        "--best_model_criterion",
-        nargs="+",
-        default=[
-            ("train", "loss", "min"),
-            ("valid", "loss", "min"),
-            ("train", "acc", "max"),
-            ("valid", "acc", "max"),
-        ],
-        help="The criterion used for judging of the best model. "
-             'Give a pair referring the phase, "train" or "valid",'
-             'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
-    )
-    parser.add_argument(
-        "--keep_nbest_models",
-        type=int,
-        nargs="+",
-        default=[10],
-        help="Remove previous snapshots excluding the n-best scored epochs",
-    )
-    parser.add_argument(
-        "--nbest_averaging_interval",
-        type=int,
-        default=0,
-        help="The epoch interval to apply model averaging and save nbest models",
-    )
-    parser.add_argument(
-        "--grad_clip",
-        type=float,
-        default=5.0,
-        help="Gradient norm threshold to clip",
-    )
-    parser.add_argument(
-        "--grad_clip_type",
-        type=float,
-        default=2.0,
-        help="The type of the used p-norm for gradient clip. Can be inf",
-    )
-    parser.add_argument(
-        "--grad_noise",
-        type=str2bool,
-        default=False,
-        help="The flag to switch to use noise injection to "
-             "gradients during training",
-    )
-    parser.add_argument(
-        "--accum_grad",
-        type=int,
-        default=1,
-        help="The number of gradient accumulation",
-    )
-    parser.add_argument(
-        "--resume",
-        type=str2bool,
-        default=False,
-        help="Enable resuming if checkpoint is existing",
-    )
-    parser.add_argument(
-        "--train_dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type for training.",
-    )
-    parser.add_argument(
-        "--use_amp",
-        type=str2bool,
-        default=False,
-        help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
-    )
-    parser.add_argument(
-        "--log_interval",
-        default=None,
-        help="Show the logs every the number iterations in each epochs at the "
-             "training phase. If None is given, it is decided according the number "
-             "of training samples automatically .",
-    )
-    parser.add_argument(
-        "--use_tensorboard",
-        type=str2bool,
-        default=True,
-        help="Enable tensorboard logging",
-    )
-
-    # pretrained model related
-    parser.add_argument(
-        "--init_param",
-        type=str,
-        action="append",
-        default=[],
-        help="Specify the file path used for initialization of parameters. "
-             "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
-             "where file_path is the model file path, "
-             "src_key specifies the key of model states to be used in the model file, "
-             "dst_key specifies the attribute of the model to be initialized, "
-             "and exclude_keys excludes keys of model states for the initialization."
-             "e.g.\n"
-             "  # Load all parameters"
-             "  --init_param some/where/model.pb\n"
-             "  # Load only decoder parameters"
-             "  --init_param some/where/model.pb:decoder:decoder\n"
-             "  # Load only decoder parameters excluding decoder.embed"
-             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
-             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
-    )
-    parser.add_argument(
-        "--ignore_init_mismatch",
-        type=str2bool,
-        default=False,
-        help="Ignore size mismatch when loading pre-trained model",
-    )
-    parser.add_argument(
-        "--freeze_param",
-        type=str,
-        default=[],
-        action="append",
-        help="Freeze parameters",
-    )
-
-    # dataset related
-    parser.add_argument(
-        "--dataset_type",
-        type=str,
-        default="small",
-        help="whether to use dataloader for large dataset",
-    )
-    parser.add_argument(
-        "--dataset_conf",
-        action=NestedDictAction,
-        default=dict(),
-        help=f"The keyword arguments for dataset",
-    )
-    parser.add_argument(
-        "--data_dir",
-        type=str,
-        default=None,
-        help="root path of data",
-    )
-    parser.add_argument(
-        "--train_set",
-        type=str,
-        default="train",
-        help="train dataset",
-    )
-    parser.add_argument(
-        "--valid_set",
-        type=str,
-        default="validation",
-        help="dev dataset",
-    )
-    parser.add_argument(
-        "--data_file_names",
-        type=str,
-        default="wav.scp,text",
-        help="input data files",
-    )
-    parser.add_argument(
-        "--speed_perturb",
-        type=float,
-        nargs="+",
-        default=None,
-        help="speed perturb",
-    )
-    parser.add_argument(
-        "--use_preprocessor",
-        type=str2bool,
-        default=True,
-        help="Apply preprocessing to data or not",
-    )
-
-    # optimization related
-    parser.add_argument(
-        "--optim",
-        type=lambda x: x.lower(),
-        default="adam",
-        help="The optimizer type",
-    )
-    parser.add_argument(
-        "--optim_conf",
-        action=NestedDictAction,
-        default=dict(),
-        help="The keyword arguments for optimizer",
-    )
-    parser.add_argument(
-        "--scheduler",
-        type=lambda x: str_or_none(x.lower()),
-        default=None,
-        help="The lr scheduler type",
-    )
-    parser.add_argument(
-        "--scheduler_conf",
-        action=NestedDictAction,
-        default=dict(),
-        help="The keyword arguments for lr scheduler",
-    )
-
-    # most task related
-    parser.add_argument(
-        "--init",
-        type=lambda x: str_or_none(x.lower()),
-        default=None,
-        help="The initialization method",
-        choices=[
-            "chainer",
-            "xavier_uniform",
-            "xavier_normal",
-            "kaiming_uniform",
-            "kaiming_normal",
-            None,
-        ],
-    )
-    parser.add_argument(
-        "--token_list",
-        type=str_or_none,
-        default=None,
-        help="A text mapping int-id to token",
-    )
-    parser.add_argument(
-        "--token_type",
-        type=str,
-        default="bpe",
-        choices=["bpe", "char", "word"],
-        help="",
-    )
-    parser.add_argument(
-        "--bpemodel",
-        type=str_or_none,
-        default=None,
-        help="The model file fo sentencepiece",
-    )
-    parser.add_argument(
-        "--cleaner",
-        type=str_or_none,
-        choices=[None, "tacotron", "jaconv", "vietnamese"],
-        default=None,
-        help="Apply text cleaning",
-    )
-    parser.add_argument(
-        "--g2p",
-        type=str_or_none,
-        choices=g2p_choices,
-        default=None,
-        help="Specify g2p method if --token_type=phn",
-    )
-
-    # pai related
-    parser.add_argument(
-        "--use_pai",
-        type=str2bool,
-        default=False,
-        help="flag to indicate whether training on PAI",
-    )
-    parser.add_argument(
-        "--simple_ddp",
-        type=str2bool,
-        default=False,
-    )
-    parser.add_argument(
-        "--num_worker_count",
-        type=int,
-        default=1,
-        help="The number of machines on PAI.",
-    )
-    parser.add_argument(
-        "--access_key_id",
-        type=str,
-        default=None,
-        help="The username for oss.",
-    )
-    parser.add_argument(
-        "--access_key_secret",
-        type=str,
-        default=None,
-        help="The password for oss.",
-    )
-    parser.add_argument(
-        "--endpoint",
-        type=str,
-        default=None,
-        help="The endpoint for oss.",
-    )
-    parser.add_argument(
-        "--bucket_name",
-        type=str,
-        default=None,
-        help="The bucket name for oss.",
-    )
-    parser.add_argument(
-        "--oss_bucket",
-        default=None,
-        help="oss bucket.",
-    )
-    parser.add_argument(
-        "--enable_lora",
-        type=str2bool,
-        default=False,
-        help="Apply lora for finetuning.",
-    )
-    parser.add_argument(
-        "--lora_bias",
-        type=str,
-        default="none",
-        help="lora bias.",
-    )
-
-    return parser
-
-
-if __name__ == '__main__':
-    parser = get_parser()
-    args, extra_task_params = parser.parse_known_args()
-    if extra_task_params:
-        args = build_args(args, parser, extra_task_params)
-
-    # set random seed
-    set_all_random_seed(args.seed)
-    torch.backends.cudnn.enabled = args.cudnn_enabled
-    torch.backends.cudnn.benchmark = args.cudnn_benchmark
-    torch.backends.cudnn.deterministic = args.cudnn_deterministic
-
-    # ddp init
-    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
-    args.distributed = args.ngpu > 1 or args.dist_world_size > 1
-    distributed_option = build_distributed(args)
-
-    # for logging
-    if not distributed_option.distributed or distributed_option.dist_rank == 0:
-        logging.basicConfig(
-            level="INFO",
-            format=f"[{os.uname()[1].split('.')[0]}]"
-                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-        )
-    else:
-        logging.basicConfig(
-            level="ERROR",
-            format=f"[{os.uname()[1].split('.')[0]}]"
-                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-        )
-
-    # prepare files for dataloader
-    prepare_data(args, distributed_option)
-
-    model = build_model(args)
-    model = model.to(
-        dtype=getattr(torch, args.train_dtype),
-        device="cuda" if args.ngpu > 0 else "cpu",
-    )
-    if args.enable_lora:
-        mark_only_lora_as_trainable(model, args.lora_bias)
-    for t in args.freeze_param:
-        for k, p in model.named_parameters():
-            if k.startswith(t + ".") or k == t:
-                logging.info(f"Setting {k}.requires_grad = False")
-                p.requires_grad = False
-
-    optimizers = build_optimizer(args, model=model)
-    schedulers = build_scheduler(args, optimizers)
-
-    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
-                                                                   distributed_option.dist_rank,
-                                                                   distributed_option.local_rank))
-    logging.info(pytorch_cudnn_version())
-    logging.info("Args: {}".format(args))
-    logging.info(model_summary(model))
-    logging.info("Optimizer: {}".format(optimizers))
-    logging.info("Scheduler: {}".format(schedulers))
-
-    # dump args to config.yaml
-    if not distributed_option.distributed or distributed_option.dist_rank == 0:
-        os.makedirs(args.output_dir, exist_ok=True)
-        with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
-            logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
-            if args.use_pai:
-                buffer = BytesIO()
-                torch.save({"config": vars(args)}, buffer)
-                args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
-            else:
-                yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
-
-    for p in args.init_param:
-        logging.info(f"Loading pretrained params from {p}")
-        load_pretrained_model(
-            model=model,
-            init_param=p,
-            ignore_init_mismatch=args.ignore_init_mismatch,
-            map_location=f"cuda:{torch.cuda.current_device()}"
-            if args.ngpu > 0
-            else "cpu",
-            oss_bucket=args.oss_bucket,
-        )
-
-    # dataloader for training/validation
-    train_dataloader, valid_dataloader = build_dataloader(args)
-
-    # Trainer, including model, optimizers, etc.
-    trainer = build_trainer(
-        args=args,
-        model=model,
-        optimizers=optimizers,
-        schedulers=schedulers,
-        train_dataloader=train_dataloader,
-        valid_dataloader=valid_dataloader,
-        distributed_option=distributed_option
-    )
-
-    trainer.run()
diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py
deleted file mode 100644
index 08018a7..0000000
--- a/funasr/build_utils/build_args.py
+++ /dev/null
@@ -1,122 +0,0 @@
-from funasr.models.ctc import CTC
-from funasr.utils import config_argparse
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-
-
-def build_args(args, parser, extra_task_params):
-    task_parser = config_argparse.ArgumentParser("Task related config")
-    if args.task_name == "asr":
-        from funasr.build_utils.build_asr_model import class_choices_list
-        for class_choices in class_choices_list:
-            class_choices.add_arguments(task_parser)
-        task_parser.add_argument(
-            "--split_with_space",
-            type=str2bool,
-            default=True,
-            help="whether to split text using <space>",
-        )
-        task_parser.add_argument(
-            "--seg_dict_file",
-            type=str,
-            default=None,
-            help="seg_dict_file for text processing",
-        )
-        task_parser.add_argument(
-            "--input_size",
-            type=int_or_none,
-            default=None,
-            help="The number of input dimension of the feature",
-        )
-        task_parser.add_argument(
-            "--ctc_conf",
-            action=NestedDictAction,
-            default=get_default_kwargs(CTC),
-            help="The keyword arguments for CTC class.",
-        )
-        task_parser.add_argument(
-            "--cmvn_file",
-            type=str_or_none,
-            default=None,
-            help="The path of cmvn file.",
-        )
-
-    elif args.task_name == "pretrain":
-        from funasr.build_utils.build_pretrain_model import class_choices_list
-        for class_choices in class_choices_list:
-            class_choices.add_arguments(task_parser)
-        task_parser.add_argument(
-            "--input_size",
-            type=int_or_none,
-            default=None,
-            help="The number of input dimension of the feature",
-        )
-        task_parser.add_argument(
-            "--cmvn_file",
-            type=str_or_none,
-            default=None,
-            help="The path of cmvn file.",
-        )
-
-    elif args.task_name == "lm":
-        from funasr.build_utils.build_lm_model import class_choices_list
-        for class_choices in class_choices_list:
-            class_choices.add_arguments(task_parser)
-
-    elif args.task_name == "punc":
-        from funasr.build_utils.build_punc_model import class_choices_list
-        for class_choices in class_choices_list:
-            class_choices.add_arguments(task_parser)
-
-    elif args.task_name == "vad":
-        from funasr.build_utils.build_vad_model import class_choices_list
-        for class_choices in class_choices_list:
-            class_choices.add_arguments(task_parser)
-        task_parser.add_argument(
-            "--input_size",
-            type=int_or_none,
-            default=None,
-            help="The number of input dimension of the feature",
-        )
-        task_parser.add_argument(
-            "--cmvn_file",
-            type=str_or_none,
-            default=None,
-            help="The path of cmvn file.",
-        )
-
-    elif args.task_name == "diar":
-        from funasr.build_utils.build_diar_model import class_choices_list
-        for class_choices in class_choices_list:
-            class_choices.add_arguments(task_parser)
-        task_parser.add_argument(
-            "--input_size",
-            type=int_or_none,
-            default=None,
-            help="The number of input dimension of the feature",
-        )
-
-    elif args.task_name == "sv":
-        from funasr.build_utils.build_sv_model import class_choices_list
-        for class_choices in class_choices_list:
-            class_choices.add_arguments(task_parser)
-        task_parser.add_argument(
-            "--input_size",
-            type=int_or_none,
-            default=None,
-            help="The number of input dimension of the feature",
-        )
-
-    else:
-        raise NotImplementedError("Not supported task: {}".format(args.task_name))
-
-    for action in parser._actions:
-        if not any(action.dest == a.dest for a in task_parser._actions):
-            task_parser._add_action(action)
-
-    task_parser.set_defaults(**vars(args))
-    task_args = task_parser.parse_args(extra_task_params)
-    return task_args
diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
deleted file mode 100644
index fd47bd3..0000000
--- a/funasr/build_utils/build_asr_model.py
+++ /dev/null
@@ -1,559 +0,0 @@
-import logging
-
-from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
-from funasr.models.decoder.rnn_decoder import RNNDecoder
-from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
-from funasr.models.decoder.transformer_decoder import (
-    DynamicConvolution2DTransformerDecoder,  # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
-from funasr.models.decoder.transformer_decoder import (
-    LightweightConvolution2DTransformerDecoder,  # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import (
-    LightweightConvolutionTransformerDecoder,  # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
-from funasr.models.decoder.transformer_decoder import TransformerDecoder
-from funasr.models.decoder.rnnt_decoder import RNNTDecoder
-from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
-from funasr.models.e2e_asr import ASRModel
-from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
-from funasr.models.e2e_asr_mfcca import MFCCA
-
-from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
-from funasr.models.e2e_asr_bat import BATModel
-
-from funasr.models.e2e_sa_asr import SAASRModel
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
-
-from funasr.models.e2e_tp import TimestampPredictor
-from funasr.models.e2e_uni_asr import UniASR
-from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
-from funasr.models.encoder.branchformer_encoder import BranchformerEncoder
-from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder
-from funasr.models.encoder.transformer_encoder import TransformerEncoder
-from funasr.models.encoder.rwkv_encoder import RWKVEncoder
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.default import MultiChannelFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.joint_net.joint_network import JointNetwork
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3, BATPredictor
-from funasr.models.specaug.specaug import SpecAug
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.modules.subsampling import Conv1dSubsampling
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-frontend_choices = ClassChoices(
-    name="frontend",
-    classes=dict(
-        default=DefaultFrontend,
-        sliding_window=SlidingWindow,
-        s3prl=S3prlFrontend,
-        fused=FusedFrontends,
-        wav_frontend=WavFrontend,
-        multichannelfrontend=MultiChannelFrontend,
-    ),
-    default="default",
-)
-specaug_choices = ClassChoices(
-    name="specaug",
-    classes=dict(
-        specaug=SpecAug,
-        specaug_lfr=SpecAugLFR,
-    ),
-    default=None,
-    optional=True,
-)
-normalize_choices = ClassChoices(
-    "normalize",
-    classes=dict(
-        global_mvn=GlobalMVN,
-        utterance_mvn=UtteranceMVN,
-    ),
-    default=None,
-    optional=True,
-)
-model_choices = ClassChoices(
-    "model",
-    classes=dict(
-        asr=ASRModel,
-        uniasr=UniASR,
-        paraformer=Paraformer,
-        paraformer_online=ParaformerOnline,
-        paraformer_bert=ParaformerBert,
-        bicif_paraformer=BiCifParaformer,
-        contextual_paraformer=ContextualParaformer,
-        neatcontextual_paraformer=NeatContextualParaformer,
-        mfcca=MFCCA,
-        timestamp_prediction=TimestampPredictor,
-        rnnt=TransducerModel,
-        rnnt_unified=UnifiedTransducerModel,
-        sa_asr=SAASRModel,
-        bat=BATModel,
-    ),
-    default="asr",
-)
-encoder_choices = ClassChoices(
-    "encoder",
-    classes=dict(
-        conformer=ConformerEncoder,
-        transformer=TransformerEncoder,
-        rnn=RNNEncoder,
-        sanm=SANMEncoder,
-        sanm_chunk_opt=SANMEncoderChunkOpt,
-        data2vec_encoder=Data2VecEncoder,
-        branchformer=BranchformerEncoder,
-        e_branchformer=EBranchformerEncoder,
-        mfcca_enc=MFCCAEncoder,
-        chunk_conformer=ConformerChunkEncoder,
-        rwkv=RWKVEncoder,
-    ),
-    default="rnn",
-)
-asr_encoder_choices = ClassChoices(
-    "asr_encoder",
-    classes=dict(
-        conformer=ConformerEncoder,
-        transformer=TransformerEncoder,
-        rnn=RNNEncoder,
-        sanm=SANMEncoder,
-        sanm_chunk_opt=SANMEncoderChunkOpt,
-        data2vec_encoder=Data2VecEncoder,
-        mfcca_enc=MFCCAEncoder,
-    ),
-    default="rnn",
-)
-
-spk_encoder_choices = ClassChoices(
-    "spk_encoder",
-    classes=dict(
-        resnet34_diar=ResNet34Diar,
-    ),
-    default="resnet34_diar",
-)
-encoder_choices2 = ClassChoices(
-    "encoder2",
-    classes=dict(
-        conformer=ConformerEncoder,
-        transformer=TransformerEncoder,
-        rnn=RNNEncoder,
-        sanm=SANMEncoder,
-        sanm_chunk_opt=SANMEncoderChunkOpt,
-    ),
-    default="rnn",
-)
-decoder_choices = ClassChoices(
-    "decoder",
-    classes=dict(
-        transformer=TransformerDecoder,
-        lightweight_conv=LightweightConvolutionTransformerDecoder,
-        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
-        dynamic_conv=DynamicConvolutionTransformerDecoder,
-        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
-        rnn=RNNDecoder,
-        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
-        paraformer_decoder_sanm=ParaformerSANMDecoder,
-        paraformer_decoder_san=ParaformerDecoderSAN,
-        contextual_paraformer_decoder=ContextualParaformerDecoder,
-        sa_decoder=SAAsrTransformerDecoder,
-    ),
-    default="rnn",
-)
-decoder_choices2 = ClassChoices(
-    "decoder2",
-    classes=dict(
-        transformer=TransformerDecoder,
-        lightweight_conv=LightweightConvolutionTransformerDecoder,
-        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
-        dynamic_conv=DynamicConvolutionTransformerDecoder,
-        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
-        rnn=RNNDecoder,
-        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
-        paraformer_decoder_sanm=ParaformerSANMDecoder,
-    ),
-    type_check=AbsDecoder,
-    default="rnn",
-)
-predictor_choices = ClassChoices(
-    name="predictor",
-    classes=dict(
-        cif_predictor=CifPredictor,
-        ctc_predictor=None,
-        cif_predictor_v2=CifPredictorV2,
-        cif_predictor_v3=CifPredictorV3,
-        bat_predictor=BATPredictor,
-    ),
-    default="cif_predictor",
-    optional=True,
-)
-predictor_choices2 = ClassChoices(
-    name="predictor2",
-    classes=dict(
-        cif_predictor=CifPredictor,
-        ctc_predictor=None,
-        cif_predictor_v2=CifPredictorV2,
-    ),
-    default="cif_predictor",
-    optional=True,
-)
-stride_conv_choices = ClassChoices(
-    name="stride_conv",
-    classes=dict(
-        stride_conv1d=Conv1dSubsampling
-    ),
-    default="stride_conv1d",
-    optional=True,
-)
-rnnt_decoder_choices = ClassChoices(
-    name="rnnt_decoder",
-    classes=dict(
-        rnnt=RNNTDecoder,
-    ),
-    default="rnnt",
-    optional=True,
-)
-joint_network_choices = ClassChoices(
-    name="joint_network",
-    classes=dict(
-        joint_network=JointNetwork,
-    ),
-    default="joint_network",
-    optional=True,
-)
-
-class_choices_list = [
-    # --frontend and --frontend_conf
-    frontend_choices,
-    # --specaug and --specaug_conf
-    specaug_choices,
-    # --normalize and --normalize_conf
-    normalize_choices,
-    # --model and --model_conf
-    model_choices,
-    # --encoder and --encoder_conf
-    encoder_choices,
-    # --decoder and --decoder_conf
-    decoder_choices,
-    # --predictor and --predictor_conf
-    predictor_choices,
-    # --encoder2 and --encoder2_conf
-    encoder_choices2,
-    # --decoder2 and --decoder2_conf
-    decoder_choices2,
-    # --predictor2 and --predictor2_conf
-    predictor_choices2,
-    # --stride_conv and --stride_conv_conf
-    stride_conv_choices,
-    # --rnnt_decoder and --rnnt_decoder_conf
-    rnnt_decoder_choices,
-    # --joint_network and --joint_network_conf
-    joint_network_choices,
-    # --asr_encoder and --asr_encoder_conf
-    asr_encoder_choices,
-    # --spk_encoder and --spk_encoder_conf
-    spk_encoder_choices,
-]
-
-
-def build_asr_model(args):
-    # token_list
-    if isinstance(args.token_list, str):
-        with open(args.token_list, encoding="utf-8") as f:
-            token_list = [line.rstrip() for line in f]
-        args.token_list = list(token_list)
-        vocab_size = len(token_list)
-        logging.info(f"Vocabulary size: {vocab_size}")
-    elif isinstance(args.token_list, (tuple, list)):
-        token_list = list(args.token_list)
-        vocab_size = len(token_list)
-        logging.info(f"Vocabulary size: {vocab_size}")
-    else:
-        token_list = None
-        vocab_size = None
-
-    # frontend
-    if hasattr(args, "input_size") and args.input_size is None:
-        frontend_class = frontend_choices.get_class(args.frontend)
-        if args.frontend == 'wav_frontend' or args.frontend == 'multichannelfrontend':
-            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
-        else:
-            frontend = frontend_class(**args.frontend_conf)
-        input_size = frontend.output_size()
-    else:
-        args.frontend = None
-        args.frontend_conf = {}
-        frontend = None
-        input_size = args.input_size if hasattr(args, "input_size") else None
-
-    # data augmentation for spectrogram
-    if args.specaug is not None:
-        specaug_class = specaug_choices.get_class(args.specaug)
-        specaug = specaug_class(**args.specaug_conf)
-    else:
-        specaug = None
-
-    # normalization layer
-    if args.normalize is not None:
-        normalize_class = normalize_choices.get_class(args.normalize)
-        if args.model == "mfcca":
-            normalize = normalize_class(stats_file=args.cmvn_file, **args.normalize_conf)
-        else:
-            normalize = normalize_class(**args.normalize_conf)
-    else:
-        normalize = None
-
-    # encoder
-    encoder_class = encoder_choices.get_class(args.encoder)
-    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
-
-    # decoder
-    if hasattr(args, "decoder") and args.decoder is not None:
-        decoder_class = decoder_choices.get_class(args.decoder)
-        decoder = decoder_class(
-            vocab_size=vocab_size,
-            encoder_output_size=encoder.output_size(),
-            **args.decoder_conf,
-        )
-    else:
-        decoder = None
-
-    # ctc
-    ctc = CTC(
-        odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
-    )
-
-    if args.model in ["asr", "mfcca"]:
-        model_class = model_choices.get_class(args.model)
-        model = model_class(
-            vocab_size=vocab_size,
-            frontend=frontend,
-            specaug=specaug,
-            normalize=normalize,
-            encoder=encoder,
-            decoder=decoder,
-            ctc=ctc,
-            token_list=token_list,
-            **args.model_conf,
-        )
-    elif args.model in ["paraformer", "paraformer_online", "paraformer_bert", "bicif_paraformer",
-                        "contextual_paraformer", "neatcontextual_paraformer"]:
-        # predictor
-        predictor_class = predictor_choices.get_class(args.predictor)
-        predictor = predictor_class(**args.predictor_conf)
-
-        model_class = model_choices.get_class(args.model)
-        model = model_class(
-            vocab_size=vocab_size,
-            frontend=frontend,
-            specaug=specaug,
-            normalize=normalize,
-            encoder=encoder,
-            decoder=decoder,
-            ctc=ctc,
-            token_list=token_list,
-            predictor=predictor,
-            **args.model_conf,
-        )
-    elif args.model == "uniasr":
-        # stride_conv
-        stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
-        stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
-                                        odim=input_size + encoder.output_size())
-        stride_conv_output_size = stride_conv.output_size()
-
-        # encoder2
-        encoder_class2 = encoder_choices2.get_class(args.encoder2)
-        encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
-
-        # decoder2
-        decoder_class2 = decoder_choices2.get_class(args.decoder2)
-        decoder2 = decoder_class2(
-            vocab_size=vocab_size,
-            encoder_output_size=encoder2.output_size(),
-            **args.decoder2_conf,
-        )
-
-        # ctc2
-        ctc2 = CTC(
-            odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
-        )
-
-        # predictor
-        predictor_class = predictor_choices.get_class(args.predictor)
-        predictor = predictor_class(**args.predictor_conf)
-
-        # predictor2
-        predictor_class = predictor_choices2.get_class(args.predictor2)
-        predictor2 = predictor_class(**args.predictor2_conf)
-
-        model_class = model_choices.get_class(args.model)
-        model = model_class(
-            vocab_size=vocab_size,
-            frontend=frontend,
-            specaug=specaug,
-            normalize=normalize,
-            encoder=encoder,
-            decoder=decoder,
-            ctc=ctc,
-            token_list=token_list,
-            predictor=predictor,
-            ctc2=ctc2,
-            encoder2=encoder2,
-            decoder2=decoder2,
-            predictor2=predictor2,
-            stride_conv=stride_conv,
-            **args.model_conf,
-        )
-    elif args.model == "timestamp_prediction":
-        # predictor
-        predictor_class = predictor_choices.get_class(args.predictor)
-        predictor = predictor_class(**args.predictor_conf)
-        
-        model_class = model_choices.get_class(args.model)
-        model = model_class(
-            frontend=frontend,
-            encoder=encoder,
-            predictor=predictor,
-            token_list=token_list,
-            **args.model_conf,
-        )
-    elif args.model == "rnnt" or args.model == "rnnt_unified":
-        # 5. Decoder
-        encoder_output_size = encoder.output_size()
-
-        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
-        decoder = rnnt_decoder_class(
-            vocab_size,
-            **args.rnnt_decoder_conf,
-        )
-        decoder_output_size = decoder.output_size
-
-        if getattr(args, "decoder", None) is not None:
-            att_decoder_class = decoder_choices.get_class(args.decoder)
-
-            att_decoder = att_decoder_class(
-                vocab_size=vocab_size,
-                encoder_output_size=encoder_output_size,
-                **args.decoder_conf,
-            )
-        else:
-            att_decoder = None
-        # 6. Joint Network
-        joint_network = JointNetwork(
-            vocab_size,
-            encoder_output_size,
-            decoder_output_size,
-            **args.joint_network_conf,
-        )
-
-        model_class = model_choices.get_class(args.model)
-        # 7. Build model
-        model = model_class(
-            vocab_size=vocab_size,
-            token_list=token_list,
-            frontend=frontend,
-            specaug=specaug,
-            normalize=normalize,
-            encoder=encoder,
-            decoder=decoder,
-            att_decoder=att_decoder,
-            joint_network=joint_network,
-            **args.model_conf,
-        )
-    elif args.model == "bat":
-        # 5. Decoder
-        encoder_output_size = encoder.output_size()
-
-        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
-        decoder = rnnt_decoder_class(
-            vocab_size,
-            **args.rnnt_decoder_conf,
-        )
-        decoder_output_size = decoder.output_size
-
-        if getattr(args, "decoder", None) is not None:
-            att_decoder_class = decoder_choices.get_class(args.decoder)
-
-            att_decoder = att_decoder_class(
-                vocab_size=vocab_size,
-                encoder_output_size=encoder_output_size,
-                **args.decoder_conf,
-            )
-        else:
-            att_decoder = None
-        # 6. Joint Network
-        joint_network = JointNetwork(
-            vocab_size,
-            encoder_output_size,
-            decoder_output_size,
-            **args.joint_network_conf,
-        )
-
-        predictor_class = predictor_choices.get_class(args.predictor)
-        predictor = predictor_class(**args.predictor_conf)
-
-        model_class = model_choices.get_class(args.model)
-        # 7. Build model
-        model = model_class(
-            vocab_size=vocab_size,
-            token_list=token_list,
-            frontend=frontend,
-            specaug=specaug,
-            normalize=normalize,
-            encoder=encoder,
-            decoder=decoder,
-            att_decoder=att_decoder,
-            joint_network=joint_network,
-            predictor=predictor,
-            **args.model_conf,
-        )
-    elif args.model == "sa_asr":
-        asr_encoder_class = asr_encoder_choices.get_class(args.asr_encoder)
-        asr_encoder = asr_encoder_class(input_size=input_size, **args.asr_encoder_conf)
-        spk_encoder_class = spk_encoder_choices.get_class(args.spk_encoder)
-        spk_encoder = spk_encoder_class(input_size=input_size, **args.spk_encoder_conf)
-        decoder = decoder_class(
-            vocab_size=vocab_size,
-            encoder_output_size=asr_encoder.output_size(),
-            **args.decoder_conf,
-        )
-        ctc = CTC(
-            odim=vocab_size, encoder_output_size=asr_encoder.output_size(), **args.ctc_conf
-        )
-
-        model_class = model_choices.get_class(args.model)
-        model = model_class(
-            vocab_size=vocab_size,
-            frontend=frontend,
-            specaug=specaug,
-            normalize=normalize,
-            asr_encoder=asr_encoder,
-            spk_encoder=spk_encoder,
-            decoder=decoder,
-            ctc=ctc,
-            token_list=token_list,
-            **args.model_conf,
-        )
-
-    else:
-        raise NotImplementedError("Not supported model: {}".format(args.model))
-
-    # initialize
-    if args.init is not None:
-        initialize(model, args.init)
-
-    return model
diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
deleted file mode 100644
index 473097e..0000000
--- a/funasr/build_utils/build_dataloader.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
-from funasr.datasets.small_datasets.sequence_iter_factory import SequenceIterFactory
-
-
-def build_dataloader(args):
-    if args.dataset_type == "small":
-        if args.task_name == "diar" and args.model == "eend_ola":
-            from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
-            train_iter_factory = EENDOLADataLoader(
-                data_file=args.train_data_path_and_name_and_type[0][0],
-                batch_size=args.dataset_conf["batch_conf"]["batch_size"],
-                num_workers=args.dataset_conf["num_workers"],
-                shuffle=True)
-            valid_iter_factory = EENDOLADataLoader(
-                data_file=args.valid_data_path_and_name_and_type[0][0],
-                batch_size=args.dataset_conf["batch_conf"]["batch_size"],
-                num_workers=0,
-                shuffle=False)
-        else:
-            train_iter_factory = SequenceIterFactory(args, mode="train")
-            valid_iter_factory = SequenceIterFactory(args, mode="valid")
-    elif args.dataset_type == "large":
-        train_iter_factory = LargeDataLoader(args, mode="train")
-        valid_iter_factory = LargeDataLoader(args, mode="valid")
-    else:
-        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
-
-    return train_iter_factory, valid_iter_factory
diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py
deleted file mode 100644
index 1be04c7..0000000
--- a/funasr/build_utils/build_diar_model.py
+++ /dev/null
@@ -1,326 +0,0 @@
-import logging
-
-import torch
-
-from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.label_aggregation import LabelAggregate, LabelAggregateMaxPooling
-from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
-from funasr.models.e2e_diar_sond import DiarSondModel
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
-from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
-from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
-from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
-from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
-from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
-from funasr.models.encoder.transformer_encoder import TransformerEncoder
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.frontend.wav_frontend import WavFrontendMel23
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.specaug.specaug import SpecAug
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.models.specaug.abs_profileaug import AbsProfileAug
-from funasr.models.specaug.profileaug import ProfileAug
-from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
-from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-frontend_choices = ClassChoices(
-    name="frontend",
-    classes=dict(
-        default=DefaultFrontend,
-        sliding_window=SlidingWindow,
-        s3prl=S3prlFrontend,
-        fused=FusedFrontends,
-        wav_frontend=WavFrontend,
-        wav_frontend_mel23=WavFrontendMel23,
-    ),
-    default="default",
-)
-specaug_choices = ClassChoices(
-    name="specaug",
-    classes=dict(
-        specaug=SpecAug,
-        specaug_lfr=SpecAugLFR,
-    ),
-    default=None,
-    optional=True,
-)
-profileaug_choices = ClassChoices(
-    name="profileaug",
-    classes=dict(
-        profileaug=ProfileAug,
-    ),
-    type_check=AbsProfileAug,
-    default=None,
-    optional=True,
-)
-normalize_choices = ClassChoices(
-    "normalize",
-    classes=dict(
-        global_mvn=GlobalMVN,
-        utterance_mvn=UtteranceMVN,
-    ),
-    default=None,
-    optional=True,
-)
-label_aggregator_choices = ClassChoices(
-    "label_aggregator",
-    classes=dict(
-        label_aggregator=LabelAggregate,
-        label_aggregator_max_pool=LabelAggregateMaxPooling,
-    ),
-    default=None,
-    optional=True,
-)
-model_choices = ClassChoices(
-    "model",
-    classes=dict(
-        sond=DiarSondModel,
-        eend_ola=DiarEENDOLAModel,
-    ),
-    default="sond",
-)
-encoder_choices = ClassChoices(
-    "encoder",
-    classes=dict(
-        conformer=ConformerEncoder,
-        transformer=TransformerEncoder,
-        rnn=RNNEncoder,
-        sanm=SANMEncoder,
-        san=SelfAttentionEncoder,
-        fsmn=FsmnEncoder,
-        conv=ConvEncoder,
-        resnet34=ResNet34Diar,
-        resnet34_sp_l2reg=ResNet34SpL2RegDiar,
-        sanm_chunk_opt=SANMEncoderChunkOpt,
-        data2vec_encoder=Data2VecEncoder,
-        ecapa_tdnn=ECAPA_TDNN,
-        eend_ola_transformer=EENDOLATransformerEncoder,
-    ),
-    default="resnet34",
-)
-speaker_encoder_choices = ClassChoices(
-    "speaker_encoder",
-    classes=dict(
-        conformer=ConformerEncoder,
-        transformer=TransformerEncoder,
-        rnn=RNNEncoder,
-        sanm=SANMEncoder,
-        san=SelfAttentionEncoder,
-        fsmn=FsmnEncoder,
-        conv=ConvEncoder,
-        sanm_chunk_opt=SANMEncoderChunkOpt,
-        data2vec_encoder=Data2VecEncoder,
-    ),
-    default=None,
-    optional=True
-)
-cd_scorer_choices = ClassChoices(
-    "cd_scorer",
-    classes=dict(
-        san=SelfAttentionEncoder,
-    ),
-    default=None,
-    optional=True,
-)
-ci_scorer_choices = ClassChoices(
-    "ci_scorer",
-    classes=dict(
-        dot=DotScorer,
-        cosine=CosScorer,
-        conv=ConvEncoder,
-    ),
-    type_check=torch.nn.Module,
-    default=None,
-    optional=True,
-)
-# decoder is used for output (e.g. post_net in SOND)
-decoder_choices = ClassChoices(
-    "decoder",
-    classes=dict(
-        rnn=RNNEncoder,
-        fsmn=FsmnEncoder,
-    ),
-    type_check=torch.nn.Module,
-    default="fsmn",
-)
-# encoder_decoder_attractor is used for EEND-OLA
-encoder_decoder_attractor_choices = ClassChoices(
-    "encoder_decoder_attractor",
-    classes=dict(
-        eda=EncoderDecoderAttractor,
-    ),
-    type_check=torch.nn.Module,
-    default="eda",
-)
-class_choices_list = [
-    # --frontend and --frontend_conf
-    frontend_choices,
-    # --specaug and --specaug_conf
-    specaug_choices,
-    # --profileaug and --profileaug_conf
-    profileaug_choices,
-    # --normalize and --normalize_conf
-    normalize_choices,
-    # --label_aggregator and --label_aggregator_conf
-    label_aggregator_choices,
-    # --model and --model_conf
-    model_choices,
-    # --encoder and --encoder_conf
-    encoder_choices,
-    # --speaker_encoder and --speaker_encoder_conf
-    speaker_encoder_choices,
-    # --cd_scorer and cd_scorer_conf
-    cd_scorer_choices,
-    # --ci_scorer and ci_scorer_conf
-    ci_scorer_choices,
-    # --decoder and --decoder_conf
-    decoder_choices,
-    # --eda and --eda_conf
-    encoder_decoder_attractor_choices,
-]
-
-
-def build_diar_model(args):
-    # token_list
-    if args.token_list is not None:
-        if isinstance(args.token_list, str):
-            with open(args.token_list, encoding="utf-8") as f:
-                token_list = [line.rstrip() for line in f]
-
-            # Overwriting token_list to keep it as "portable".
-            args.token_list = list(token_list)
-        elif isinstance(args.token_list, (tuple, list)):
-            token_list = list(args.token_list)
-        else:
-            raise RuntimeError("token_list must be str or list")
-        vocab_size = len(token_list)
-        logging.info(f"Vocabulary size: {vocab_size}")
-    else:
-        token_list = None
-        vocab_size = None
-
-    # frontend
-    if args.input_size is None:
-        frontend_class = frontend_choices.get_class(args.frontend)
-        if args.frontend == 'wav_frontend':
-            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
-        else:
-            frontend = frontend_class(**args.frontend_conf)
-        input_size = frontend.output_size()
-    else:
-        args.frontend = None
-        args.frontend_conf = {}
-        frontend = None
-        input_size = args.input_size
-
-    if args.model == "sond":
-        # encoder
-        encoder_class = encoder_choices.get_class(args.encoder)
-        encoder = encoder_class(input_size=input_size ,**args.encoder_conf)
-
-        # data augmentation for spectrogram
-        if args.specaug is not None:
-            specaug_class = specaug_choices.get_class(args.specaug)
-            specaug = specaug_class(**args.specaug_conf)
-        else:
-            specaug = None
-
-        # Data augmentation for Profiles
-        if hasattr(args, "profileaug") and args.profileaug is not None:
-            profileaug_class = profileaug_choices.get_class(args.profileaug)
-            profileaug = profileaug_class(**args.profileaug_conf)
-        else:
-            profileaug = None
-
-        # normalization layer
-        if args.normalize is not None:
-            normalize_class = normalize_choices.get_class(args.normalize)
-            normalize = normalize_class(**args.normalize_conf)
-        else:
-            normalize = None
-
-        # speaker encoder
-        if getattr(args, "speaker_encoder", None) is not None:
-            speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
-            speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
-        else:
-            speaker_encoder = None
-
-        # ci scorer
-        if getattr(args, "ci_scorer", None) is not None:
-            ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
-            ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
-        else:
-            ci_scorer = None
-
-        # cd scorer
-        if getattr(args, "cd_scorer", None) is not None:
-            cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
-            cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
-        else:
-            cd_scorer = None
-
-        # decoder
-        decoder_class = decoder_choices.get_class(args.decoder)
-        decoder = decoder_class(**args.decoder_conf)
-
-        # logger aggregator
-        if getattr(args, "label_aggregator", None) is not None:
-            label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
-            label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
-        else:
-            label_aggregator = None
-
-        model_class = model_choices.get_class(args.model)
-        model = model_class(
-            vocab_size=vocab_size,
-            frontend=frontend,
-            specaug=specaug,
-            profileaug=profileaug,
-            normalize=normalize,
-            label_aggregator=label_aggregator,
-            encoder=encoder,
-            speaker_encoder=speaker_encoder,
-            ci_scorer=ci_scorer,
-            cd_scorer=cd_scorer,
-            decoder=decoder,
-            token_list=token_list,
-            **args.model_conf,
-        )
-
-    elif args.model == "eend_ola":
-        # encoder
-        encoder_class = encoder_choices.get_class(args.encoder)
-        encoder = encoder_class(**args.encoder_conf)
-
-        # encoder-decoder attractor
-        encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
-        encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
-
-        # 9. Build model
-        model_class = model_choices.get_class(args.model)
-        model = model_class(
-            frontend=frontend,
-            encoder=encoder,
-            encoder_decoder_attractor=encoder_decoder_attractor,
-            **args.model_conf,
-        )
-
-    else:
-        raise NotImplementedError("Not supported model: {}".format(args.model))
-
-    # 10. Initialize
-    if args.init is not None:
-        initialize(model, args.init)
-
-    return model
diff --git a/funasr/build_utils/build_distributed.py b/funasr/build_utils/build_distributed.py
deleted file mode 100644
index b64b4c0..0000000
--- a/funasr/build_utils/build_distributed.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import logging
-import os
-
-import torch
-
-from funasr.train.distributed_utils import DistributedOption
-from funasr.utils.build_dataclass import build_dataclass
-
-
-def build_distributed(args):
-    distributed_option = build_dataclass(DistributedOption, args)
-    if args.use_pai:
-        distributed_option.init_options_pai()
-        distributed_option.init_torch_distributed_pai(args)
-    elif not args.simple_ddp:
-        distributed_option.init_torch_distributed(args)
-    elif args.distributed and args.simple_ddp:
-        distributed_option.init_torch_distributed_pai(args)
-        args.ngpu = torch.distributed.get_world_size()
-
-    for handler in logging.root.handlers[:]:
-        logging.root.removeHandler(handler)
-    if not distributed_option.distributed or distributed_option.dist_rank == 0:
-        logging.basicConfig(
-            level="INFO",
-            format=f"[{os.uname()[1].split('.')[0]}]"
-                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-        )
-    else:
-        logging.basicConfig(
-            level="ERROR",
-            format=f"[{os.uname()[1].split('.')[0]}]"
-                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-        )
-    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
-                                                                   distributed_option.dist_rank,
-                                                                   distributed_option.local_rank))
-    return distributed_option
diff --git a/funasr/build_utils/build_lm_model.py b/funasr/build_utils/build_lm_model.py
deleted file mode 100644
index f78a20e..0000000
--- a/funasr/build_utils/build_lm_model.py
+++ /dev/null
@@ -1,62 +0,0 @@
-import logging
-
-from funasr.train.abs_model import AbsLM
-from funasr.train.abs_model import LanguageModel
-from funasr.models.seq_rnn_lm import SequentialRNNLM
-from funasr.models.transformer_lm import TransformerLM
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-lm_choices = ClassChoices(
-    "lm",
-    classes=dict(
-        seq_rnn=SequentialRNNLM,
-        transformer=TransformerLM,
-    ),
-    type_check=AbsLM,
-    default="seq_rnn",
-)
-model_choices = ClassChoices(
-    "model",
-    classes=dict(
-        lm=LanguageModel,
-    ),
-    default="lm",
-)
-
-class_choices_list = [
-    # --lm and --lm_conf
-    lm_choices,
-    # --model and --model_conf
-    model_choices
-]
-
-
-def build_lm_model(args):
-    # token_list
-    if isinstance(args.token_list, str):
-        with open(args.token_list, encoding="utf-8") as f:
-            token_list = [line.rstrip() for line in f]
-        args.token_list = list(token_list)
-        vocab_size = len(token_list)
-        logging.info(f"Vocabulary size: {vocab_size}")
-    elif isinstance(args.token_list, (tuple, list)):
-        token_list = list(args.token_list)
-        vocab_size = len(token_list)
-        logging.info(f"Vocabulary size: {vocab_size}")
-    else:
-        vocab_size = None
-
-    # lm
-    lm_class = lm_choices.get_class(args.lm)
-    lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
-
-    args.model = args.model if hasattr(args, "model") else "lm"
-    model_class = model_choices.get_class(args.model)
-    model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
-
-    # initialize
-    if args.init is not None:
-        initialize(model, args.init)
-
-    return model
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
deleted file mode 100644
index 66fdfd0..0000000
--- a/funasr/build_utils/build_model.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from funasr.build_utils.build_asr_model import build_asr_model
-from funasr.build_utils.build_diar_model import build_diar_model
-from funasr.build_utils.build_lm_model import build_lm_model
-from funasr.build_utils.build_pretrain_model import build_pretrain_model
-from funasr.build_utils.build_punc_model import build_punc_model
-from funasr.build_utils.build_sv_model import build_sv_model
-from funasr.build_utils.build_vad_model import build_vad_model
-from funasr.build_utils.build_ss_model import build_ss_model
-
-
-def build_model(args):
-    if args.task_name == "asr":
-        model = build_asr_model(args)
-    elif args.task_name == "pretrain":
-        model = build_pretrain_model(args)
-    elif args.task_name == "lm":
-        model = build_lm_model(args)
-    elif args.task_name == "punc":
-        model = build_punc_model(args)
-    elif args.task_name == "vad":
-        model = build_vad_model(args)
-    elif args.task_name == "diar":
-        model = build_diar_model(args)
-    elif args.task_name == "sv":
-        model = build_sv_model(args)
-    elif args.task_name == "ss":
-        model = build_ss_model(args)
-    else:
-        raise NotImplementedError("Not supported task: {}".format(args.task_name))
-
-    return model
diff --git a/funasr/build_utils/build_model_from_file.py b/funasr/build_utils/build_model_from_file.py
deleted file mode 100644
index 65e0d5f..0000000
--- a/funasr/build_utils/build_model_from_file.py
+++ /dev/null
@@ -1,193 +0,0 @@
-import argparse
-import logging
-import os
-from pathlib import Path
-from typing import Union
-
-import torch
-import yaml
-
-from funasr.build_utils.build_model import build_model
-from funasr.models.base_model import FunASRModel
-
-
-def build_model_from_file(
-        config_file: Union[Path, str] = None,
-        model_file: Union[Path, str] = None,
-        cmvn_file: Union[Path, str] = None,
-        device: str = "cpu",
-        task_name: str = "asr",
-        mode: str = "paraformer",
-):
-    """Build model from the files.
-
-    This method is used for inference or fine-tuning.
-
-    Args:
-        config_file: The yaml file saved when training.
-        model_file: The model file saved when training.
-        device: Device type, "cpu", "cuda", or "cuda:N".
-
-    """
-    if config_file is None:
-        assert model_file is not None, (
-            "The argument 'model_file' must be provided "
-            "if the argument 'config_file' is not specified."
-        )
-        config_file = Path(model_file).parent / "config.yaml"
-    else:
-        config_file = Path(config_file)
-
-    with config_file.open("r", encoding="utf-8") as f:
-        args = yaml.safe_load(f)
-    if cmvn_file is not None:
-        args["cmvn_file"] = cmvn_file
-    args = argparse.Namespace(**args)
-    args.task_name = task_name
-    model = build_model(args)
-    if not isinstance(model, FunASRModel):
-        raise RuntimeError(
-            f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
-        )
-    model.to(device)
-    model_dict = dict()
-    model_name_pth = None
-    if model_file is not None:
-        logging.info("model_file is {}".format(model_file))
-        if device == "cuda":
-            device = f"cuda:{torch.cuda.current_device()}"
-        model_dir = os.path.dirname(model_file)
-        model_name = os.path.basename(model_file)
-        if "model.ckpt-" in model_name or ".bin" in model_name:
-            model_name_pth = os.path.join(model_dir, model_name.replace('.bin',
-                                                                        '.pb')) if ".bin" in model_name else os.path.join(
-                model_dir, "{}.pb".format(model_name))
-            if os.path.exists(model_name_pth):
-                logging.info("model_file is load from pth: {}".format(model_name_pth))
-                model_dict = torch.load(model_name_pth, map_location=device)
-            else:
-                model_dict = convert_tf2torch(model, model_file, mode)
-            model.load_state_dict(model_dict)
-        else:
-            model_dict = torch.load(model_file, map_location=device)
-    if task_name == "ss":
-        model_dict = model_dict['model']
-    if task_name == "diar" and mode == "sond":
-        model_dict = fileter_model_dict(model_dict, model.state_dict())
-    if task_name == "vad":
-        model.encoder.load_state_dict(model_dict)
-    else:
-        model.load_state_dict(model_dict)
-    if model_name_pth is not None and not os.path.exists(model_name_pth):
-        torch.save(model_dict, model_name_pth)
-        logging.info("model_file is saved to pth: {}".format(model_name_pth))
-
-    return model, args
-
-
-def convert_tf2torch(
-        model,
-        ckpt,
-        mode,
-):
-    assert mode == "paraformer" or mode == "uniasr" or mode == "sond" or mode == "sv" or mode == "tp"
-    logging.info("start convert tf model to torch model")
-    from funasr.modules.streaming_utils.load_fr_tf import load_tf_dict
-    var_dict_tf = load_tf_dict(ckpt)
-    var_dict_torch = model.state_dict()
-    var_dict_torch_update = dict()
-    if mode == "uniasr":
-        # encoder
-        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # predictor
-        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # decoder
-        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # encoder2
-        var_dict_torch_update_local = model.encoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # predictor2
-        var_dict_torch_update_local = model.predictor2.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # decoder2
-        var_dict_torch_update_local = model.decoder2.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # stride_conv
-        var_dict_torch_update_local = model.stride_conv.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-    elif mode == "paraformer":
-        # encoder
-        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # predictor
-        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # decoder
-        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # bias_encoder
-        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-    elif "mode" == "sond":
-        if model.encoder is not None:
-            var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-            var_dict_torch_update.update(var_dict_torch_update_local)
-        # speaker encoder
-        if model.speaker_encoder is not None:
-            var_dict_torch_update_local = model.speaker_encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-            var_dict_torch_update.update(var_dict_torch_update_local)
-        # cd scorer
-        if model.cd_scorer is not None:
-            var_dict_torch_update_local = model.cd_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
-            var_dict_torch_update.update(var_dict_torch_update_local)
-        # ci scorer
-        if model.ci_scorer is not None:
-            var_dict_torch_update_local = model.ci_scorer.convert_tf2torch(var_dict_tf, var_dict_torch)
-            var_dict_torch_update.update(var_dict_torch_update_local)
-        # decoder
-        if model.decoder is not None:
-            var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-            var_dict_torch_update.update(var_dict_torch_update_local)
-    elif "mode" == "sv":
-        # speech encoder
-        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # pooling layer
-        var_dict_torch_update_local = model.pooling_layer.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # decoder
-        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-    else:
-        # encoder
-        var_dict_torch_update_local = model.encoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # predictor
-        var_dict_torch_update_local = model.predictor.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # decoder
-        var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        # bias_encoder
-        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
-        var_dict_torch_update.update(var_dict_torch_update_local)
-        return var_dict_torch_update
-
-    return var_dict_torch_update
-
-
-def fileter_model_dict(src_dict: dict, dest_dict: dict):
-    from collections import OrderedDict
-    new_dict = OrderedDict()
-    for key, value in src_dict.items():
-        if key in dest_dict:
-            new_dict[key] = value
-        else:
-            logging.info("{} is no longer needed in this model.".format(key))
-    for key, value in dest_dict.items():
-        if key not in new_dict:
-            logging.warning("{} is missed in checkpoint.".format(key))
-    return new_dict
diff --git a/funasr/build_utils/build_optimizer.py b/funasr/build_utils/build_optimizer.py
deleted file mode 100644
index bd0b73d..0000000
--- a/funasr/build_utils/build_optimizer.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import torch
-
-from funasr.optimizers.fairseq_adam import FairseqAdam
-from funasr.optimizers.sgd import SGD
-
-
-def build_optimizer(args, model):
-    optim_classes = dict(
-        adam=torch.optim.Adam,
-        fairseq_adam=FairseqAdam,
-        adamw=torch.optim.AdamW,
-        sgd=SGD,
-        adadelta=torch.optim.Adadelta,
-        adagrad=torch.optim.Adagrad,
-        adamax=torch.optim.Adamax,
-        asgd=torch.optim.ASGD,
-        lbfgs=torch.optim.LBFGS,
-        rmsprop=torch.optim.RMSprop,
-        rprop=torch.optim.Rprop,
-    )
-
-    optim_class = optim_classes.get(args.optim)
-    if optim_class is None:
-        raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
-    optimizer = optim_class(model.parameters(), **args.optim_conf)
-
-    optimizers = [optimizer]
-    return optimizers
\ No newline at end of file
diff --git a/funasr/build_utils/build_pretrain_model.py b/funasr/build_utils/build_pretrain_model.py
deleted file mode 100644
index 0784fb2..0000000
--- a/funasr/build_utils/build_pretrain_model.py
+++ /dev/null
@@ -1,112 +0,0 @@
-from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.data2vec import Data2VecPretrainModel
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.specaug.specaug import SpecAug
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-frontend_choices = ClassChoices(
-    name="frontend",
-    classes=dict(
-        default=DefaultFrontend,
-        sliding_window=SlidingWindow,
-        wav_frontend=WavFrontend,
-    ),
-    default="default",
-)
-specaug_choices = ClassChoices(
-    name="specaug",
-    classes=dict(specaug=SpecAug),
-    default=None,
-    optional=True,
-)
-normalize_choices = ClassChoices(
-    "normalize",
-    classes=dict(
-        global_mvn=GlobalMVN,
-        utterance_mvn=UtteranceMVN,
-    ),
-    default=None,
-    optional=True,
-)
-encoder_choices = ClassChoices(
-    "encoder",
-    classes=dict(
-        data2vec_encoder=Data2VecEncoder,
-    ),
-    default="data2vec_encoder",
-)
-model_choices = ClassChoices(
-    "model",
-    classes=dict(
-        data2vec=Data2VecPretrainModel,
-    ),
-    default="data2vec",
-)
-class_choices_list = [
-    # --frontend and --frontend_conf
-    frontend_choices,
-    # --specaug and --specaug_conf
-    specaug_choices,
-    # --normalize and --normalize_conf
-    normalize_choices,
-    # --encoder and --encoder_conf
-    encoder_choices,
-    # --model and --model_conf
-    model_choices,
-]
-
-
-def build_pretrain_model(args):
-    # frontend
-    if args.input_size is None:
-        frontend_class = frontend_choices.get_class(args.frontend)
-        frontend = frontend_class(**args.frontend_conf)
-        input_size = frontend.output_size()
-    else:
-        args.frontend = None
-        args.frontend_conf = {}
-        frontend = None
-        input_size = args.input_size
-
-    # data augmentation for spectrogram
-    if args.specaug is not None:
-        specaug_class = specaug_choices.get_class(args.specaug)
-        specaug = specaug_class(**args.specaug_conf)
-    else:
-        specaug = None
-
-    # normalization layer
-    if args.normalize is not None:
-        normalize_class = normalize_choices.get_class(args.normalize)
-        normalize = normalize_class(**args.normalize_conf)
-    else:
-        normalize = None
-
-    # encoder
-    encoder_class = encoder_choices.get_class(args.encoder)
-    encoder = encoder_class(
-        input_size=input_size,
-        **args.encoder_conf,
-    )
-
-    if args.model == "data2vec":
-        model_class = model_choices.get_class("data2vec")
-        model = model_class(
-            frontend=frontend,
-            specaug=specaug,
-            normalize=normalize,
-            encoder=encoder,
-        )
-    else:
-        raise NotImplementedError("Not supported model: {}".format(args.model))
-
-    # initialize
-    if args.init is not None:
-        initialize(model, args.init)
-
-    return model
diff --git a/funasr/build_utils/build_punc_model.py b/funasr/build_utils/build_punc_model.py
deleted file mode 100644
index 62ccaf2..0000000
--- a/funasr/build_utils/build_punc_model.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import logging
-
-from funasr.models.target_delay_transformer import TargetDelayTransformer
-from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
-from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_model import PunctuationModel
-from funasr.train.class_choices import ClassChoices
-
-punc_choices = ClassChoices(
-    "punctuation",
-    classes=dict(
-        target_delay=TargetDelayTransformer,
-        vad_realtime=VadRealtimeTransformer
-    ),
-    default="target_delay",
-)
-model_choices = ClassChoices(
-    "model",
-    classes=dict(
-        punc=PunctuationModel,
-    ),
-    default="punc",
-)
-class_choices_list = [
-    # --punc and --punc_conf
-    punc_choices,
-    # --model and --model_conf
-    model_choices
-]
-
-
-def build_punc_model(args):
-    # token_list and punc list
-    if isinstance(args.token_list, str):
-        with open(args.token_list, encoding="utf-8") as f:
-            token_list = [line.rstrip() for line in f]
-        args.token_list = token_list.copy()
-    if isinstance(args.punc_list, str):
-        with open(args.punc_list, encoding="utf-8") as f2:
-            pairs = [line.rstrip().split(":") for line in f2]
-        punc_list = [pair[0] for pair in pairs]
-        punc_weight_list = [float(pair[1]) for pair in pairs]
-        args.punc_list = punc_list.copy()
-    elif isinstance(args.punc_list, list):
-        punc_list = args.punc_list.copy()
-        punc_weight_list = [1] * len(punc_list)
-    if isinstance(args.token_list, (tuple, list)):
-        token_list = args.token_list.copy()
-    else:
-        raise RuntimeError("token_list must be str or dict")
-
-    vocab_size = len(token_list)
-    punc_size = len(punc_list)
-    logging.info(f"Vocabulary size: {vocab_size}")
-
-    # punc
-    punc_class = punc_choices.get_class(args.punctuation)
-    punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
-
-    if "punc_weight" in args.model_conf:
-        args.model_conf.pop("punc_weight")
-    model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
-
-    # initialize
-    if args.init is not None:
-        initialize(model, args.init)
-
-    return model
diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py
deleted file mode 100644
index 4b9990e..0000000
--- a/funasr/build_utils/build_scheduler.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import torch
-import torch.multiprocessing
-import torch.nn
-import torch.optim
-
-from funasr.schedulers.noam_lr import NoamLR
-from funasr.schedulers.tri_stage_scheduler import TriStageLR
-from funasr.schedulers.warmup_lr import WarmupLR
-
-
-def build_scheduler(args, optimizers):
-    scheduler_classes = dict(
-        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
-        lambdalr=torch.optim.lr_scheduler.LambdaLR,
-        steplr=torch.optim.lr_scheduler.StepLR,
-        multisteplr=torch.optim.lr_scheduler.MultiStepLR,
-        exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
-        CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
-        noamlr=NoamLR,
-        warmuplr=WarmupLR,
-        tri_stage=TriStageLR,
-        cycliclr=torch.optim.lr_scheduler.CyclicLR,
-        onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
-        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
-    )
-
-    schedulers = []
-    for i, optim in enumerate(optimizers, 1):
-        suf = "" if i == 1 else str(i)
-        name = getattr(args, f"scheduler{suf}")
-        conf = getattr(args, f"scheduler{suf}_conf")
-        if name is not None:
-            cls_ = scheduler_classes.get(name)
-            if cls_ is None:
-                raise ValueError(
-                    f"must be one of {list(scheduler_classes)}: {name}"
-                )
-            scheduler = cls_(optim, **conf)
-        else:
-            scheduler = None
-
-        schedulers.append(scheduler)
-
-    return schedulers
\ No newline at end of file
diff --git a/funasr/build_utils/build_ss_model.py b/funasr/build_utils/build_ss_model.py
deleted file mode 100644
index a6b5209..0000000
--- a/funasr/build_utils/build_ss_model.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from funasr.models.e2e_ss import MossFormer
-
-def build_ss_model(args):
-    model = MossFormer(
-        in_channels=args.encoder_embedding_dim,
-        out_channels=args.mossformer_sequence_dim,
-        num_blocks=args.num_mossformer_layer,
-        kernel_size=args.encoder_kernel_size,
-        norm=args.norm,
-        num_spks=args.num_spks,
-        skip_around_intra=args.skip_around_intra,
-        use_global_pos_enc=args.use_global_pos_enc,
-        max_length=args.max_length)
-
-    return model
diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py
deleted file mode 100644
index 02fc263..0000000
--- a/funasr/build_utils/build_streaming_iterator.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import numpy as np
-from torch.utils.data import DataLoader
-
-from funasr.datasets.iterable_dataset import IterableESPnetDataset
-from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
-from funasr.datasets.small_datasets.preprocessor import build_preprocess
-
-
-def build_streaming_iterator(
-        task_name,
-        preprocess_args,
-        data_path_and_name_and_type,
-        key_file: str = None,
-        batch_size: int = 1,
-        fs: dict = None,
-        mc: bool = False,
-        dtype: str = np.float32,
-        num_workers: int = 1,
-        use_collate_fn: bool = True,
-        preprocess_fn=None,
-        ngpu: int = 0,
-        train: bool = False,
-) -> DataLoader:
-    """Build DataLoader using iterable dataset"""
-
-    # preprocess
-    if preprocess_fn is not None:
-        preprocess_fn = preprocess_fn
-    elif preprocess_args is not None:
-        preprocess_args.task_name = task_name
-        preprocess_fn = build_preprocess(preprocess_args, train)
-    else:
-        preprocess_fn = None
-
-    # collate
-    if not use_collate_fn:
-        collate_fn = None
-    elif task_name in ["punc", "lm"]:
-        collate_fn = CommonCollateFn(int_pad_value=0)
-    else:
-        collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
-    if collate_fn is not None:
-        kwargs = dict(collate_fn=collate_fn)
-    else:
-        kwargs = {}
-
-    dataset = IterableESPnetDataset(
-        data_path_and_name_and_type,
-        float_dtype=dtype,
-        fs=fs,
-        mc=mc,
-        preprocess=preprocess_fn,
-        key_file=key_file,
-    )
-    if dataset.apply_utt2category:
-        kwargs.update(batch_size=1)
-    else:
-        kwargs.update(batch_size=batch_size)
-
-    return DataLoader(
-        dataset=dataset,
-        pin_memory=ngpu > 0,
-        num_workers=num_workers,
-        **kwargs,
-    )
diff --git a/funasr/build_utils/build_sv_model.py b/funasr/build_utils/build_sv_model.py
deleted file mode 100644
index 55df75a..0000000
--- a/funasr/build_utils/build_sv_model.py
+++ /dev/null
@@ -1,256 +0,0 @@
-import logging
-
-import torch
-
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.layers.global_mvn import GlobalMVN
-from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.base_model import FunASRModel
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.sv_decoder import DenseDecoder
-from funasr.models.e2e_sv import ESPnetSVModel
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.resnet34_encoder import ResNet34, ResNet34_SP_L2Reg
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.pooling.statistic_pooling import StatisticPooling
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.postencoder.hugging_face_transformers_postencoder import (
-    HuggingFaceTransformersPostEncoder,  # noqa: H301
-)
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.preencoder.linear import LinearProjection
-from funasr.models.preencoder.sinc import LightweightSincConvs
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.specaug.specaug import SpecAug
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-frontend_choices = ClassChoices(
-    name="frontend",
-    classes=dict(
-        default=DefaultFrontend,
-        sliding_window=SlidingWindow,
-        s3prl=S3prlFrontend,
-        fused=FusedFrontends,
-        wav_frontend=WavFrontend,
-    ),
-    type_check=AbsFrontend,
-    default="default",
-)
-specaug_choices = ClassChoices(
-    name="specaug",
-    classes=dict(
-        specaug=SpecAug,
-    ),
-    type_check=AbsSpecAug,
-    default=None,
-    optional=True,
-)
-normalize_choices = ClassChoices(
-    "normalize",
-    classes=dict(
-        global_mvn=GlobalMVN,
-        utterance_mvn=UtteranceMVN,
-    ),
-    type_check=AbsNormalize,
-    default=None,
-    optional=True,
-)
-model_choices = ClassChoices(
-    "model",
-    classes=dict(
-        espnet=ESPnetSVModel,
-    ),
-    type_check=FunASRModel,
-    default="espnet",
-)
-preencoder_choices = ClassChoices(
-    name="preencoder",
-    classes=dict(
-        sinc=LightweightSincConvs,
-        linear=LinearProjection,
-    ),
-    type_check=AbsPreEncoder,
-    default=None,
-    optional=True,
-)
-encoder_choices = ClassChoices(
-    "encoder",
-    classes=dict(
-        resnet34=ResNet34,
-        resnet34_sp_l2reg=ResNet34_SP_L2Reg,
-        rnn=RNNEncoder,
-    ),
-    type_check=AbsEncoder,
-    default="resnet34",
-)
-postencoder_choices = ClassChoices(
-    name="postencoder",
-    classes=dict(
-        hugging_face_transformers=HuggingFaceTransformersPostEncoder,
-    ),
-    type_check=AbsPostEncoder,
-    default=None,
-    optional=True,
-)
-pooling_choices = ClassChoices(
-    name="pooling_type",
-    classes=dict(
-        statistic=StatisticPooling,
-    ),
-    type_check=torch.nn.Module,
-    default="statistic",
-)
-decoder_choices = ClassChoices(
-    "decoder",
-    classes=dict(
-        dense=DenseDecoder,
-    ),
-    type_check=AbsDecoder,
-    default="dense",
-)
-
-class_choices_list = [
-    # --frontend and --frontend_conf
-    frontend_choices,
-    # --specaug and --specaug_conf
-    specaug_choices,
-    # --normalize and --normalize_conf
-    normalize_choices,
-    # --model and --model_conf
-    model_choices,
-    # --preencoder and --preencoder_conf
-    preencoder_choices,
-    # --encoder and --encoder_conf
-    encoder_choices,
-    # --postencoder and --postencoder_conf
-    postencoder_choices,
-    # --pooling and --pooling_conf
-    pooling_choices,
-    # --decoder and --decoder_conf
-    decoder_choices,
-]
-
-
-def build_sv_model(args):
-    # token_list
-    if isinstance(args.token_list, str):
-        with open(args.token_list, encoding="utf-8") as f:
-            token_list = [line.rstrip() for line in f]
-
-        # Overwriting token_list to keep it as "portable".
-        args.token_list = list(token_list)
-    elif isinstance(args.token_list, (tuple, list)):
-        token_list = list(args.token_list)
-    else:
-        raise RuntimeError("token_list must be str or list")
-    vocab_size = len(token_list)
-    logging.info(f"Speaker number: {vocab_size}")
-
-    # 1. frontend
-    if args.input_size is None:
-        # Extract features in the model
-        frontend_class = frontend_choices.get_class(args.frontend)
-        frontend = frontend_class(**args.frontend_conf)
-        input_size = frontend.output_size()
-    else:
-        # Give features from data-loader
-        args.frontend = None
-        args.frontend_conf = {}
-        frontend = None
-        input_size = args.input_size
-
-    # 2. Data augmentation for spectrogram
-    if args.specaug is not None:
-        specaug_class = specaug_choices.get_class(args.specaug)
-        specaug = specaug_class(**args.specaug_conf)
-    else:
-        specaug = None
-
-    # 3. Normalization layer
-    if args.normalize is not None:
-        normalize_class = normalize_choices.get_class(args.normalize)
-        normalize = normalize_class(**args.normalize_conf)
-    else:
-        normalize = None
-
-    # 4. Pre-encoder input block
-    # NOTE(kan-bayashi): Use getattr to keep the compatibility
-    if getattr(args, "preencoder", None) is not None:
-        preencoder_class = preencoder_choices.get_class(args.preencoder)
-        preencoder = preencoder_class(**args.preencoder_conf)
-        input_size = preencoder.output_size()
-    else:
-        preencoder = None
-
-    # 5. Encoder
-    encoder_class = encoder_choices.get_class(args.encoder)
-    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
-
-    # 6. Post-encoder block
-    # NOTE(kan-bayashi): Use getattr to keep the compatibility
-    encoder_output_size = encoder.output_size()
-    if getattr(args, "postencoder", None) is not None:
-        postencoder_class = postencoder_choices.get_class(args.postencoder)
-        postencoder = postencoder_class(
-            input_size=encoder_output_size, **args.postencoder_conf
-        )
-        encoder_output_size = postencoder.output_size()
-    else:
-        postencoder = None
-
-    # 7. Pooling layer
-    pooling_class = pooling_choices.get_class(args.pooling_type)
-    pooling_dim = (2, 3)
-    eps = 1e-12
-    if hasattr(args, "pooling_type_conf"):
-        if "pooling_dim" in args.pooling_type_conf:
-            pooling_dim = args.pooling_type_conf["pooling_dim"]
-        if "eps" in args.pooling_type_conf:
-            eps = args.pooling_type_conf["eps"]
-    pooling_layer = pooling_class(
-        pooling_dim=pooling_dim,
-        eps=eps,
-    )
-    if args.pooling_type == "statistic":
-        encoder_output_size *= 2
-
-    # 8. Decoder
-    decoder_class = decoder_choices.get_class(args.decoder)
-    decoder = decoder_class(
-        vocab_size=vocab_size,
-        encoder_output_size=encoder_output_size,
-        **args.decoder_conf,
-    )
-
-    # 7. Build model
-    try:
-        model_class = model_choices.get_class(args.model)
-    except AttributeError:
-        model_class = model_choices.get_class("espnet")
-    model = model_class(
-        vocab_size=vocab_size,
-        token_list=token_list,
-        frontend=frontend,
-        specaug=specaug,
-        normalize=normalize,
-        preencoder=preencoder,
-        encoder=encoder,
-        postencoder=postencoder,
-        pooling_layer=pooling_layer,
-        decoder=decoder,
-        **args.model_conf,
-    )
-
-    # FIXME(kamo): Should be done in model?
-    # 8. Initialize
-    if args.init is not None:
-        initialize(model, args.init)
-
-    return model
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
deleted file mode 100644
index 498d05d..0000000
--- a/funasr/build_utils/build_trainer.py
+++ /dev/null
@@ -1,812 +0,0 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Trainer module."""
-import argparse
-import dataclasses
-import logging
-import os
-import time
-from contextlib import contextmanager
-from dataclasses import is_dataclass
-from distutils.version import LooseVersion
-from io import BytesIO
-from pathlib import Path
-from typing import Dict
-from typing import Iterable
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-
-import humanfriendly
-import oss2
-import torch
-import torch.nn
-import torch.optim
-
-from funasr.iterators.abs_iter_factory import AbsIterFactory
-from funasr.main_funcs.average_nbest_models import average_nbest_models
-from funasr.models.base_model import FunASRModel
-from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
-from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
-from funasr.schedulers.abs_scheduler import AbsScheduler
-from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
-from funasr.torch_utils.add_gradient_noise import add_gradient_noise
-from funasr.torch_utils.device_funcs import to_device
-from funasr.torch_utils.recursive_op import recursive_average
-from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.distributed_utils import DistributedOption
-from funasr.train.reporter import Reporter
-from funasr.train.reporter import SubReporter
-from funasr.utils.build_dataclass import build_dataclass
-
-if torch.distributed.is_available():
-    from torch.distributed import ReduceOp
-
-if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
-    from torch.cuda.amp import autocast
-    from torch.cuda.amp import GradScaler
-else:
-    # Nothing to do if torch<1.6.0
-    @contextmanager
-    def autocast(enabled=True):
-        yield
-
-
-    GradScaler = None
-
-try:
-    import fairscale
-except ImportError:
-    fairscale = None
-
-
-@dataclasses.dataclass
-class TrainerOptions:
-    ngpu: int
-    resume: bool
-    use_amp: bool
-    train_dtype: str
-    grad_noise: bool
-    accum_grad: int
-    grad_clip: float
-    grad_clip_type: float
-    log_interval: Optional[int]
-    # no_forward_run: bool
-    use_tensorboard: bool
-    # use_wandb: bool
-    output_dir: Union[Path, str]
-    max_epoch: int
-    max_update: int
-    seed: int
-    # sharded_ddp: bool
-    patience: Optional[int]
-    keep_nbest_models: Union[int, List[int]]
-    nbest_averaging_interval: int
-    early_stopping_criterion: Sequence[str]
-    best_model_criterion: Sequence[Sequence[str]]
-    val_scheduler_criterion: Sequence[str]
-    unused_parameters: bool
-    # wandb_model_log_interval: int
-    use_pai: bool
-    oss_bucket: Union[oss2.Bucket, None]
-
-
-class Trainer:
-    """Trainer
-
-    """
-
-    def __init__(self,
-                 args,
-                 model: FunASRModel,
-                 optimizers: Sequence[torch.optim.Optimizer],
-                 schedulers: Sequence[Optional[AbsScheduler]],
-                 train_dataloader: AbsIterFactory,
-                 valid_dataloader: AbsIterFactory,
-                 distributed_option: DistributedOption):
-        self.trainer_options = self.build_options(args)
-        self.model = model
-        self.optimizers = optimizers
-        self.schedulers = schedulers
-        self.train_dataloader = train_dataloader
-        self.valid_dataloader = valid_dataloader
-        self.distributed_option = distributed_option
-
-    def build_options(self, args: argparse.Namespace) -> TrainerOptions:
-        """Build options consumed by train(), eval()"""
-        return build_dataclass(TrainerOptions, args)
-
-    @classmethod
-    def add_arguments(cls, parser: argparse.ArgumentParser):
-        """Reserved for future development of another Trainer"""
-        pass
-
-    def resume(self,
-               checkpoint: Union[str, Path],
-               model: torch.nn.Module,
-               reporter: Reporter,
-               optimizers: Sequence[torch.optim.Optimizer],
-               schedulers: Sequence[Optional[AbsScheduler]],
-               scaler: Optional[GradScaler],
-               ngpu: int = 0,
-               ):
-        states = torch.load(
-            checkpoint,
-            map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
-        )
-        model.load_state_dict(states["model"])
-        reporter.load_state_dict(states["reporter"])
-        for optimizer, state in zip(optimizers, states["optimizers"]):
-            optimizer.load_state_dict(state)
-        for scheduler, state in zip(schedulers, states["schedulers"]):
-            if scheduler is not None:
-                scheduler.load_state_dict(state)
-        if scaler is not None:
-            if states["scaler"] is None:
-                logging.warning("scaler state is not found")
-            else:
-                scaler.load_state_dict(states["scaler"])
-
-        logging.info(f"The training was resumed using {checkpoint}")
-
-    def run(self) -> None:
-        """Perform training. This method performs the main process of training."""
-        # NOTE(kamo): Don't check the type more strictly as far trainer_options
-        model = self.model
-        optimizers = self.optimizers
-        schedulers = self.schedulers
-        train_dataloader = self.train_dataloader
-        valid_dataloader = self.valid_dataloader
-        trainer_options = self.trainer_options
-        distributed_option = self.distributed_option
-        assert is_dataclass(trainer_options), type(trainer_options)
-        assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
-
-        if isinstance(trainer_options.keep_nbest_models, int):
-            keep_nbest_models = [trainer_options.keep_nbest_models]
-        else:
-            if len(trainer_options.keep_nbest_models) == 0:
-                logging.warning("No keep_nbest_models is given. Change to [1]")
-                trainer_options.keep_nbest_models = [1]
-            keep_nbest_models = trainer_options.keep_nbest_models
-
-        output_dir = Path(trainer_options.output_dir)
-        reporter = Reporter()
-        if trainer_options.use_amp:
-            if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
-                raise RuntimeError(
-                    "Require torch>=1.6.0 for  Automatic Mixed Precision"
-                )
-            # if trainer_options.sharded_ddp:
-            #     if fairscale is None:
-            #         raise RuntimeError(
-            #             "Requiring fairscale. Do 'pip install fairscale'"
-            #         )
-            #     scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
-            # else:
-            scaler = GradScaler()
-        else:
-            scaler = None
-
-        if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
-            self.resume(
-                checkpoint=output_dir / "checkpoint.pb",
-                model=model,
-                optimizers=optimizers,
-                schedulers=schedulers,
-                reporter=reporter,
-                scaler=scaler,
-                ngpu=trainer_options.ngpu,
-            )
-
-        start_epoch = reporter.get_epoch() + 1
-        if start_epoch == trainer_options.max_epoch + 1:
-            logging.warning(
-                f"The training has already reached at max_epoch: {start_epoch}"
-            )
-
-        if distributed_option.distributed:
-            dp_model = torch.nn.parallel.DistributedDataParallel(
-                model, find_unused_parameters=trainer_options.unused_parameters)
-        elif distributed_option.ngpu > 1:
-            dp_model = torch.nn.parallel.DataParallel(
-                model,
-                device_ids=list(range(distributed_option.ngpu)),
-            )
-        else:
-            # NOTE(kamo): DataParallel also should work with ngpu=1,
-            # but for debuggability it's better to keep this block.
-            dp_model = model
-
-        if trainer_options.use_tensorboard and (
-                not distributed_option.distributed or distributed_option.dist_rank == 0
-        ):
-            from torch.utils.tensorboard import SummaryWriter
-            if trainer_options.use_pai:
-                train_summary_writer = SummaryWriter(
-                    os.path.join(trainer_options.output_dir, "tensorboard/train")
-                )
-                valid_summary_writer = SummaryWriter(
-                    os.path.join(trainer_options.output_dir, "tensorboard/valid")
-                )
-            else:
-                train_summary_writer = SummaryWriter(
-                    str(output_dir / "tensorboard" / "train")
-                )
-                valid_summary_writer = SummaryWriter(
-                    str(output_dir / "tensorboard" / "valid")
-                )
-        else:
-            train_summary_writer = None
-
-        start_time = time.perf_counter()
-        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
-            if iepoch != start_epoch:
-                logging.info(
-                    "{}/{}epoch started. Estimated time to finish: {} hours".format(
-                        iepoch,
-                        trainer_options.max_epoch,
-                        (time.perf_counter() - start_time) / 3600.0 / (iepoch - start_epoch) * (
-                                trainer_options.max_epoch - iepoch + 1),
-                    )
-                )
-            else:
-                logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
-            set_all_random_seed(trainer_options.seed + iepoch)
-
-            reporter.set_epoch(iepoch)
-            # 1. Train and validation for one-epoch
-            with reporter.observe("train") as sub_reporter:
-                all_steps_are_invalid, max_update_stop = self.train_one_epoch(
-                    model=dp_model,
-                    optimizers=optimizers,
-                    schedulers=schedulers,
-                    iterator=train_dataloader.build_iter(iepoch),
-                    reporter=sub_reporter,
-                    scaler=scaler,
-                    summary_writer=train_summary_writer,
-                    options=trainer_options,
-                    distributed_option=distributed_option,
-                )
-
-            with reporter.observe("valid") as sub_reporter:
-                self.validate_one_epoch(
-                    model=dp_model,
-                    iterator=valid_dataloader.build_iter(iepoch),
-                    reporter=sub_reporter,
-                    options=trainer_options,
-                    distributed_option=distributed_option,
-                )
-
-            # 2. LR Scheduler step
-            for scheduler in schedulers:
-                if isinstance(scheduler, AbsValEpochStepScheduler):
-                    scheduler.step(
-                        reporter.get_value(*trainer_options.val_scheduler_criterion)
-                    )
-                elif isinstance(scheduler, AbsEpochStepScheduler):
-                    scheduler.step()
-            # if trainer_options.sharded_ddp:
-            #     for optimizer in optimizers:
-            #         if isinstance(optimizer, fairscale.optim.oss.OSS):
-            #             optimizer.consolidate_state_dict()
-
-            if not distributed_option.distributed or distributed_option.dist_rank == 0:
-                # 3. Report the results
-                logging.info(reporter.log_message())
-                if train_summary_writer is not None:
-                    reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
-                    reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
-                # if trainer_options.use_wandb:
-                #     reporter.wandb_log()
-
-                # save tensorboard on oss
-                if trainer_options.use_pai and train_summary_writer is not None:
-                    def write_tensorboard_summary(summary_writer_path, oss_bucket):
-                        file_list = []
-                        for root, dirs, files in os.walk(summary_writer_path, topdown=False):
-                            for name in files:
-                                file_full_path = os.path.join(root, name)
-                                file_list.append(file_full_path)
-
-                        for file_full_path in file_list:
-                            with open(file_full_path, "rb") as f:
-                                oss_bucket.put_object(file_full_path, f)
-
-                    write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/train"),
-                                              trainer_options.oss_bucket)
-                    write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/valid"),
-                                              trainer_options.oss_bucket)
-
-                # 4. Save/Update the checkpoint
-                if trainer_options.use_pai:
-                    buffer = BytesIO()
-                    torch.save(
-                        {
-                            "model": model.state_dict(),
-                            "reporter": reporter.state_dict(),
-                            "optimizers": [o.state_dict() for o in optimizers],
-                            "schedulers": [
-                                s.state_dict() if s is not None else None
-                                for s in schedulers
-                            ],
-                            "scaler": scaler.state_dict() if scaler is not None else None,
-                            "ema_model": model.encoder.ema.model.state_dict()
-                            if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None,
-                        },
-                        buffer,
-                    )
-                    trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"),
-                                                          buffer.getvalue())
-                else:
-                    torch.save(
-                        {
-                            "model": model.state_dict(),
-                            "reporter": reporter.state_dict(),
-                            "optimizers": [o.state_dict() for o in optimizers],
-                            "schedulers": [
-                                s.state_dict() if s is not None else None
-                                for s in schedulers
-                            ],
-                            "scaler": scaler.state_dict() if scaler is not None else None,
-                        },
-                        output_dir / "checkpoint.pb",
-                    )
-
-                # 5. Save and log the model and update the link to the best model
-                if trainer_options.use_pai:
-                    buffer = BytesIO()
-                    torch.save(model.state_dict(), buffer)
-                    trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
-                                                                       f"{iepoch}epoch.pb"), buffer.getvalue())
-                else:
-                    torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
-
-                # Creates a sym link latest.pb -> {iepoch}epoch.pb
-                if trainer_options.use_pai:
-                    p = os.path.join(trainer_options.output_dir, "latest.pb")
-                    if trainer_options.oss_bucket.object_exists(p):
-                        trainer_options.oss_bucket.delete_object(p)
-                    trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
-                                                           os.path.join(trainer_options.output_dir,
-                                                                        f"{iepoch}epoch.pb"), p)
-                else:
-                    p = output_dir / "latest.pb"
-                    if p.is_symlink() or p.exists():
-                        p.unlink()
-                    p.symlink_to(f"{iepoch}epoch.pb")
-
-                _improved = []
-                for _phase, k, _mode in trainer_options.best_model_criterion:
-                    # e.g. _phase, k, _mode = "train", "loss", "min"
-                    if reporter.has(_phase, k):
-                        best_epoch = reporter.get_best_epoch(_phase, k, _mode)
-                        # Creates sym links if it's the best result
-                        if best_epoch == iepoch:
-                            if trainer_options.use_pai:
-                                p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
-                                if trainer_options.oss_bucket.object_exists(p):
-                                    trainer_options.oss_bucket.delete_object(p)
-                                trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
-                                                                       os.path.join(trainer_options.output_dir,
-                                                                                    f"{iepoch}epoch.pb"), p)
-                            else:
-                                p = output_dir / f"{_phase}.{k}.best.pb"
-                                if p.is_symlink() or p.exists():
-                                    p.unlink()
-                                p.symlink_to(f"{iepoch}epoch.pb")
-                            _improved.append(f"{_phase}.{k}")
-                if len(_improved) == 0:
-                    logging.info("There are no improvements in this epoch")
-                else:
-                    logging.info(
-                        "The best model has been updated: " + ", ".join(_improved)
-                    )
-
-                # log_model = (
-                #         trainer_options.wandb_model_log_interval > 0
-                #         and iepoch % trainer_options.wandb_model_log_interval == 0
-                # )
-                # if log_model and trainer_options.use_wandb:
-                #     import wandb
-                #
-                #     logging.info("Logging Model on this epoch :::::")
-                #     artifact = wandb.Artifact(
-                #         name=f"model_{wandb.run.id}",
-                #         type="model",
-                #         metadata={"improved": _improved},
-                #     )
-                #     artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
-                #     aliases = [
-                #         f"epoch-{iepoch}",
-                #         "best" if best_epoch == iepoch else "",
-                #     ]
-                #     wandb.log_artifact(artifact, aliases=aliases)
-
-                # 6. Remove the model files excluding n-best epoch and latest epoch
-                _removed = []
-                # Get the union set of the n-best among multiple criterion
-                nbests = set().union(
-                    *[
-                        set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
-                        for ph, k, m in trainer_options.best_model_criterion
-                        if reporter.has(ph, k)
-                    ]
-                )
-
-                # Generated n-best averaged model
-                if (
-                        trainer_options.nbest_averaging_interval > 0
-                        and iepoch % trainer_options.nbest_averaging_interval == 0
-                ):
-                    average_nbest_models(
-                        reporter=reporter,
-                        output_dir=output_dir,
-                        best_model_criterion=trainer_options.best_model_criterion,
-                        nbest=keep_nbest_models,
-                        suffix=f"till{iepoch}epoch",
-                        oss_bucket=trainer_options.oss_bucket,
-                        pai_output_dir=trainer_options.output_dir,
-                    )
-
-                for e in range(1, iepoch):
-                    if trainer_options.use_pai:
-                        p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
-                        if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
-                            trainer_options.oss_bucket.delete_object(p)
-                            _removed.append(str(p))
-                    else:
-                        p = output_dir / f"{e}epoch.pb"
-                        if p.exists() and e not in nbests:
-                            p.unlink()
-                            _removed.append(str(p))
-                if len(_removed) != 0:
-                    logging.info("The model files were removed: " + ", ".join(_removed))
-
-            # 7. If any updating haven't happened, stops the training
-            if all_steps_are_invalid:
-                logging.warning(
-                    f"The gradients at all steps are invalid in this epoch. "
-                    f"Something seems wrong. This training was stopped at {iepoch}epoch"
-                )
-                break
-
-            if max_update_stop:
-                logging.info(
-                    f"Stopping training due to "
-                    f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
-                )
-                break
-
-            # 8. Check early stopping
-            if trainer_options.patience is not None:
-                if reporter.check_early_stopping(
-                        trainer_options.patience, *trainer_options.early_stopping_criterion
-                ):
-                    break
-
-        else:
-            logging.info(
-                f"The training was finished at {trainer_options.max_epoch} epochs "
-            )
-
-        # Generated n-best averaged model
-        if not distributed_option.distributed or distributed_option.dist_rank == 0:
-            average_nbest_models(
-                reporter=reporter,
-                output_dir=output_dir,
-                best_model_criterion=trainer_options.best_model_criterion,
-                nbest=keep_nbest_models,
-                oss_bucket=trainer_options.oss_bucket,
-                pai_output_dir=trainer_options.output_dir,
-            )
-
-    def train_one_epoch(
-            self,
-            model: torch.nn.Module,
-            iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
-            optimizers: Sequence[torch.optim.Optimizer],
-            schedulers: Sequence[Optional[AbsScheduler]],
-            scaler: Optional[GradScaler],
-            reporter: SubReporter,
-            summary_writer,
-            options: TrainerOptions,
-            distributed_option: DistributedOption,
-    ) -> Tuple[bool, bool]:
-
-        grad_noise = options.grad_noise
-        accum_grad = options.accum_grad
-        grad_clip = options.grad_clip
-        grad_clip_type = options.grad_clip_type
-        log_interval = options.log_interval
-        # no_forward_run = options.no_forward_run
-        ngpu = options.ngpu
-        # use_wandb = options.use_wandb
-        distributed = distributed_option.distributed
-
-        if log_interval is None:
-            try:
-                log_interval = max(len(iterator) // 20, 10)
-            except TypeError:
-                log_interval = 100
-
-        model.train()
-        all_steps_are_invalid = True
-        max_update_stop = False
-        # [For distributed] Because iteration counts are not always equals between
-        # processes, send stop-flag to the other processes if iterator is finished
-        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
-
-        start_time = time.perf_counter()
-        for iiter, (_, batch) in enumerate(
-                reporter.measure_iter_time(iterator, "iter_time"), 1
-        ):
-            assert isinstance(batch, dict), type(batch)
-
-            if distributed:
-                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
-                if iterator_stop > 0:
-                    break
-
-            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-            # if no_forward_run:
-            #     all_steps_are_invalid = False
-            #     continue
-
-            with autocast(scaler is not None):
-                with reporter.measure_time("forward_time"):
-                    retval = model(**batch)
-
-                    # Note(kamo):
-                    # Supporting two patterns for the returned value from the model
-                    #   a. dict type
-                    if isinstance(retval, dict):
-                        loss = retval["loss"]
-                        stats = retval["stats"]
-                        weight = retval["weight"]
-                        optim_idx = retval.get("optim_idx")
-                        if optim_idx is not None and not isinstance(optim_idx, int):
-                            if not isinstance(optim_idx, torch.Tensor):
-                                raise RuntimeError(
-                                    "optim_idx must be int or 1dim torch.Tensor, "
-                                    f"but got {type(optim_idx)}"
-                                )
-                            if optim_idx.dim() >= 2:
-                                raise RuntimeError(
-                                    "optim_idx must be int or 1dim torch.Tensor, "
-                                    f"but got {optim_idx.dim()}dim tensor"
-                                )
-                            if optim_idx.dim() == 1:
-                                for v in optim_idx:
-                                    if v != optim_idx[0]:
-                                        raise RuntimeError(
-                                            "optim_idx must be 1dim tensor "
-                                            "having same values for all entries"
-                                        )
-                                optim_idx = optim_idx[0].item()
-                            else:
-                                optim_idx = optim_idx.item()
-
-                    #   b. tuple or list type
-                    else:
-                        loss, stats, weight = retval
-                        optim_idx = None
-
-                stats = {k: v for k, v in stats.items() if v is not None}
-                if ngpu > 1 or distributed:
-                    # Apply weighted averaging for loss and stats
-                    loss = (loss * weight.type(loss.dtype)).sum()
-
-                    # if distributed, this method can also apply all_reduce()
-                    stats, weight = recursive_average(stats, weight, distributed)
-
-                    # Now weight is summation over all workers
-                    loss /= weight
-                if distributed:
-                    # NOTE(kamo): Multiply world_size because DistributedDataParallel
-                    # automatically normalizes the gradient by world_size.
-                    loss *= torch.distributed.get_world_size()
-
-                loss /= accum_grad
-
-            reporter.register(stats, weight)
-
-            with reporter.measure_time("backward_time"):
-                if scaler is not None:
-                    # Scales loss.  Calls backward() on scaled loss
-                    # to create scaled gradients.
-                    # Backward passes under autocast are not recommended.
-                    # Backward ops run in the same dtype autocast chose
-                    # for corresponding forward ops.
-                    scaler.scale(loss).backward()
-                else:
-                    loss.backward()
-
-            if iiter % accum_grad == 0:
-                if scaler is not None:
-                    # Unscales the gradients of optimizer's assigned params in-place
-                    for iopt, optimizer in enumerate(optimizers):
-                        if optim_idx is not None and iopt != optim_idx:
-                            continue
-                        scaler.unscale_(optimizer)
-
-                # gradient noise injection
-                if grad_noise:
-                    add_gradient_noise(
-                        model,
-                        reporter.get_total_count(),
-                        duration=100,
-                        eta=1.0,
-                        scale_factor=0.55,
-                    )
-
-                # compute the gradient norm to check if it is normal or not
-                grad_norm = torch.nn.utils.clip_grad_norm_(
-                    model.parameters(),
-                    max_norm=grad_clip,
-                    norm_type=grad_clip_type,
-                )
-                # PyTorch<=1.4, clip_grad_norm_ returns float value
-                if not isinstance(grad_norm, torch.Tensor):
-                    grad_norm = torch.tensor(grad_norm)
-
-                if not torch.isfinite(grad_norm):
-                    logging.warning(
-                        f"The grad norm is {grad_norm}. Skipping updating the model."
-                    )
-
-                    # Must invoke scaler.update() if unscale_() is used in the iteration
-                    # to avoid the following error:
-                    #   RuntimeError: unscale_() has already been called
-                    #   on this optimizer since the last update().
-                    # Note that if the gradient has inf/nan values,
-                    # scaler.step skips optimizer.step().
-                    if scaler is not None:
-                        for iopt, optimizer in enumerate(optimizers):
-                            if optim_idx is not None and iopt != optim_idx:
-                                continue
-                            scaler.step(optimizer)
-                            scaler.update()
-
-                else:
-                    all_steps_are_invalid = False
-                    with reporter.measure_time("optim_step_time"):
-                        for iopt, (optimizer, scheduler) in enumerate(
-                                zip(optimizers, schedulers)
-                        ):
-                            if optim_idx is not None and iopt != optim_idx:
-                                continue
-                            if scaler is not None:
-                                # scaler.step() first unscales the gradients of
-                                # the optimizer's assigned params.
-                                scaler.step(optimizer)
-                                # Updates the scale for next iteration.
-                                scaler.update()
-                            else:
-                                optimizer.step()
-                            if isinstance(scheduler, AbsBatchStepScheduler):
-                                scheduler.step()
-                for iopt, optimizer in enumerate(optimizers):
-                    if optim_idx is not None and iopt != optim_idx:
-                        continue
-                    optimizer.zero_grad()
-
-                # Register lr and train/load time[sec/step],
-                # where step refers to accum_grad * mini-batch
-                reporter.register(
-                    dict(
-                        {
-                            f"optim{i}_lr{j}": pg["lr"]
-                            for i, optimizer in enumerate(optimizers)
-                            for j, pg in enumerate(optimizer.param_groups)
-                            if "lr" in pg
-                        },
-                        train_time=time.perf_counter() - start_time,
-                    ),
-                )
-                start_time = time.perf_counter()
-
-                # update num_updates
-                if distributed:
-                    if hasattr(model.module, "num_updates"):
-                        model.module.set_num_updates(model.module.get_num_updates() + 1)
-                        options.num_updates = model.module.get_num_updates()
-                        if model.module.get_num_updates() >= options.max_update:
-                            max_update_stop = True
-                else:
-                    if hasattr(model, "num_updates"):
-                        model.set_num_updates(model.get_num_updates() + 1)
-                        options.num_updates = model.get_num_updates()
-                        if model.get_num_updates() >= options.max_update:
-                            max_update_stop = True
-
-            # NOTE(kamo): Call log_message() after next()
-            reporter.next()
-            if iiter % log_interval == 0:
-                num_updates = options.num_updates if hasattr(options, "num_updates") else None
-                logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
-                if summary_writer is not None:
-                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
-                # if use_wandb:
-                #     reporter.wandb_log()
-
-            if max_update_stop:
-                break
-
-        else:
-            if distributed:
-                iterator_stop.fill_(1)
-                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
-        return all_steps_are_invalid, max_update_stop
-
-    @torch.no_grad()
-    def validate_one_epoch(
-            self,
-            model: torch.nn.Module,
-            iterator: Iterable[Dict[str, torch.Tensor]],
-            reporter: SubReporter,
-            options: TrainerOptions,
-            distributed_option: DistributedOption,
-    ) -> None:
-        ngpu = options.ngpu
-        # no_forward_run = options.no_forward_run
-        distributed = distributed_option.distributed
-
-        model.eval()
-
-        # [For distributed] Because iteration counts are not always equals between
-        # processes, send stop-flag to the other processes if iterator is finished
-        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
-        for (_, batch) in iterator:
-            assert isinstance(batch, dict), type(batch)
-            if distributed:
-                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
-                if iterator_stop > 0:
-                    break
-
-            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
-            # if no_forward_run:
-            #     continue
-
-            retval = model(**batch)
-            if isinstance(retval, dict):
-                stats = retval["stats"]
-                weight = retval["weight"]
-            else:
-                _, stats, weight = retval
-            if ngpu > 1 or distributed:
-                # Apply weighted averaging for stats.
-                # if distributed, this method can also apply all_reduce()
-                stats, weight = recursive_average(stats, weight, distributed)
-
-            reporter.register(stats, weight)
-            reporter.next()
-
-        else:
-            if distributed:
-                iterator_stop.fill_(1)
-                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
-
-
-def build_trainer(
-        args,
-        model: FunASRModel,
-        optimizers: Sequence[torch.optim.Optimizer],
-        schedulers: Sequence[Optional[AbsScheduler]],
-        train_dataloader: AbsIterFactory,
-        valid_dataloader: AbsIterFactory,
-        distributed_option: DistributedOption
-):
-    trainer = Trainer(
-        args=args,
-        model=model,
-        optimizers=optimizers,
-        schedulers=schedulers,
-        train_dataloader=train_dataloader,
-        valid_dataloader=valid_dataloader,
-        distributed_option=distributed_option
-    )
-    return trainer
diff --git a/funasr/build_utils/build_vad_model.py b/funasr/build_utils/build_vad_model.py
deleted file mode 100644
index 6a840cf..0000000
--- a/funasr/build_utils/build_vad_model.py
+++ /dev/null
@@ -1,81 +0,0 @@
-import torch
-
-from funasr.models.e2e_vad import E2EVadModel
-from funasr.models.encoder.fsmn_encoder import FSMN
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.torch_utils.initialize import initialize
-from funasr.train.class_choices import ClassChoices
-
-frontend_choices = ClassChoices(
-    name="frontend",
-    classes=dict(
-        default=DefaultFrontend,
-        sliding_window=SlidingWindow,
-        s3prl=S3prlFrontend,
-        fused=FusedFrontends,
-        wav_frontend=WavFrontend,
-        wav_frontend_online=WavFrontendOnline,
-    ),
-    default="default",
-)
-encoder_choices = ClassChoices(
-    "encoder",
-    classes=dict(
-        fsmn=FSMN,
-    ),
-    type_check=torch.nn.Module,
-    default="fsmn",
-)
-model_choices = ClassChoices(
-    "model",
-    classes=dict(
-        e2evad=E2EVadModel,
-    ),
-    default="e2evad",
-)
-
-class_choices_list = [
-    # --frontend and --frontend_conf
-    frontend_choices,
-    # --encoder and --encoder_conf
-    encoder_choices,
-    # --model and --model_conf
-    model_choices,
-]
-
-
-def build_vad_model(args):
-    # frontend
-    if not hasattr(args, "cmvn_file"):
-        args.cmvn_file = None
-    if not hasattr(args, "init"):
-        args.init = None
-    if args.input_size is None:
-        frontend_class = frontend_choices.get_class(args.frontend)
-        if args.frontend == 'wav_frontend':
-            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
-        else:
-            frontend = frontend_class(**args.frontend_conf)
-        input_size = frontend.output_size()
-    else:
-        args.frontend = None
-        args.frontend_conf = {}
-        frontend = None
-        input_size = args.input_size
-
-    # encoder
-    encoder_class = encoder_choices.get_class(args.encoder)
-    encoder = encoder_class(**args.encoder_conf)
-
-    model_class = model_choices.get_class(args.model)
-    model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
-
-    # initialize
-    if args.init is not None:
-        initialize(model, args.init)
-
-    return model
diff --git a/funasr/cli/train_cli.py b/funasr/cli/train_cli.py
index c62153e..a22d5d4 100644
--- a/funasr/cli/train_cli.py
+++ b/funasr/cli/train_cli.py
@@ -35,8 +35,9 @@
 @hydra.main(config_name=None, version_base=None)
 def main_hydra(kwargs: DictConfig):
 	import pdb; pdb.set_trace()
-	if kwargs.get("model_pretrain"):
-		kwargs = download_model(**kwargs)
+	if ":" in kwargs["model"]:
+		logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
+		kwargs = download_model(is_training=kwargs.get("is_training", True), **kwargs)
 	
 	import pdb;
 	pdb.set_trace()
@@ -84,8 +85,7 @@
 	# init_param
 	init_param = kwargs.get("init_param", None)
 	if init_param is not None:
-		init_param = init_param
-		if isinstance(init_param, Sequence):
+		if not isinstance(init_param, Sequence):
 			init_param = (init_param,)
 		logging.info("init_param is not None: %s", init_param)
 		for p in init_param:
diff --git a/funasr/datasets/dataset_jsonl.py b/funasr/datasets/dataset_jsonl.py
index 7f2cd83..21df89e 100644
--- a/funasr/datasets/dataset_jsonl.py
+++ b/funasr/datasets/dataset_jsonl.py
@@ -8,33 +8,7 @@
 import time
 import logging
 
-def load_audio(audio_path: str, fs: int=16000):
-	audio = None
-	if audio_path.startswith("oss:"):
-		pass
-	elif audio_path.startswith("odps:"):
-		pass
-	else:
-		if ".ark:" in audio_path:
-			audio = kaldiio.load_mat(audio_path)
-		else:
-			# audio, fs = librosa.load(audio_path, sr=fs)
-			audio, fs = torchaudio.load(audio_path)
-			audio = audio[0, :]
-	return audio
-
-def extract_features(data, date_type: str="sound", frontend=None):
-	if date_type == "sound":
-
-		if isinstance(data, np.ndarray):
-			data = torch.from_numpy(data).to(torch.float32)
-		data_len = torch.tensor([data.shape[0]]).to(torch.int32)
-		feat, feats_lens = frontend(data[None, :], data_len)
-
-		feat = feat[0, :, :]
-	else:
-		feat, feats_lens = torch.from_numpy(data).to(torch.float32), torch.tensor([data.shape[0]]).to(torch.int32)
-	return feat, feats_lens
+from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
 	
 	
 
@@ -115,17 +89,16 @@
 	
 	def __getitem__(self, index):
 		item = self.indexed_dataset[index]
-		# return item
 
 		source = item["source"]
 		data_src = load_audio(source, fs=self.fs)
-		speech, speech_lengths = extract_features(data_src, self.data_type, self.frontend)
+		speech, speech_lengths = extract_fbank(data_src, date_type=self.data_type, frontend=self.frontend) # speech: [b, T, d]
 		target = item["target"]
 		ids = self.tokenizer.encode(target)
 		ids_lengths = len(ids)
 		text, text_lengths = torch.tensor(ids, dtype=torch.int64), torch.tensor([ids_lengths], dtype=torch.int32)
 
-		return {"speech": speech,
+		return {"speech": speech[0, :, :],
 		        "speech_lengths": speech_lengths,
 		        "text": text,
 		        "text_lengths": text_lengths,
diff --git a/funasr/build_utils/__init__.py b/funasr/datasets/fun_datasets/__init__.py
similarity index 100%
copy from funasr/build_utils/__init__.py
copy to funasr/datasets/fun_datasets/__init__.py
diff --git a/funasr/datasets/fun_datasets/load_audio_extract_fbank.py b/funasr/datasets/fun_datasets/load_audio_extract_fbank.py
new file mode 100644
index 0000000..c76f346
--- /dev/null
+++ b/funasr/datasets/fun_datasets/load_audio_extract_fbank.py
@@ -0,0 +1,75 @@
+import os
+import torch
+import json
+import torch.distributed as dist
+import numpy as np
+import kaldiio
+import librosa
+import torchaudio
+import time
+import logging
+from torch.nn.utils.rnn import pad_sequence
+
+def load_audio(audio_or_path_or_list, fs: int=16000, audio_fs: int=16000):
+
+	if isinstance(audio_or_path_or_list, (list, tuple)):
+		return [load_audio(audio, fs=fs, audio_fs=audio_fs) for audio in audio_or_path_or_list]
+	
+	if isinstance(audio_or_path_or_list, str) and os.path.exists(audio_or_path_or_list):
+		audio_or_path_or_list, audio_fs = torchaudio.load(audio_or_path_or_list)
+		audio_or_path_or_list = audio_or_path_or_list[0, :]
+	elif isinstance(audio_or_path_or_list, np.ndarray): # audio sample point
+		audio_or_path_or_list = np.squeeze(audio_or_path_or_list) #[n_samples,]
+		
+	if audio_fs != fs:
+		resampler = torchaudio.transforms.Resample(audio_fs, fs)
+		audio_or_path_or_list = resampler(audio_or_path_or_list[None, :])[0, :]
+	return audio_or_path_or_list
+#
+# def load_audio_from_list(audio_list, fs: int=16000, audio_fs: int=16000):
+# 	if isinstance(audio_list, (list, tuple)):
+# 		return [load_audio(audio_or_path, fs=fs, audio_fs=audio_fs) for audio_or_path in audio_list]
+
+
+def load_bytes(input):
+	middle_data = np.frombuffer(input, dtype=np.int16)
+	middle_data = np.asarray(middle_data)
+	if middle_data.dtype.kind not in 'iu':
+		raise TypeError("'middle_data' must be an array of integers")
+	dtype = np.dtype('float32')
+	if dtype.kind != 'f':
+		raise TypeError("'dtype' must be a floating point type")
+	
+	i = np.iinfo(middle_data.dtype)
+	abs_max = 2 ** (i.bits - 1)
+	offset = i.min + abs_max
+	array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
+	return array
+
+def extract_fbank(data, data_len = None, date_type: str="sound", frontend=None):
+	
+	if isinstance(data, np.ndarray):
+		data = torch.from_numpy(data)
+		if data.dim() < 2:
+			data = data[None, :] # data: [batch, N]
+		data_len = [data.shape[1]] if data_len is None else data_len
+	elif isinstance(data, torch.Tensor):
+		if data.dim() < 2:
+			data = data[None, :] # data: [batch, N]
+		data_len = [data.shape[1]] if data_len is None else data_len
+	elif isinstance(data, (list, tuple)):
+		data_list, data_len = [], []
+		for data_i in data:
+			if isinstance(data_i, np.ndarray):
+				data_i = torch.from_numpy(data_i)
+			data_list.append(data_i)
+			data_len.append(data_i.shape[0])
+		data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
+	# import pdb;
+	# pdb.set_trace()
+	if date_type == "sound":
+		data, data_len = frontend(data, data_len)
+	
+	if isinstance(data_len, (list, tuple)):
+		data_len = torch.tensor(data_len)
+	return data.to(torch.float32), data_len.to(torch.int32)
\ No newline at end of file
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index f92f322..ac16065 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -116,7 +116,7 @@
     def forward(
             self,
             input: torch.Tensor,
-            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+            input_lengths) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size = input.size(0)
         feats = []
         feats_lens = []
diff --git a/funasr/build_utils/__init__.py b/funasr/models/paraformer/__init__.py
similarity index 100%
rename from funasr/build_utils/__init__.py
rename to funasr/models/paraformer/__init__.py
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
new file mode 100644
index 0000000..75b36a9
--- /dev/null
+++ b/funasr/models/paraformer/model.py
@@ -0,0 +1,1760 @@
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+import tempfile
+import codecs
+import requests
+import re
+import copy
+import torch
+import torch.nn as nn
+import random
+import numpy as np
+import time
+# from funasr.layers.abs_normalize import AbsNormalize
+from funasr.losses.label_smoothing_loss import (
+	LabelSmoothingLoss,  # noqa: H301
+)
+# from funasr.models.ctc import CTC
+# from funasr.models.decoder.abs_decoder import AbsDecoder
+# from funasr.models.e2e_asr_common import ErrorCalculator
+# from funasr.models.encoder.abs_encoder import AbsEncoder
+# from funasr.models.frontend.abs_frontend import AbsFrontend
+# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
+from funasr.models.predictor.cif import mae_loss
+# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+# from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.modules.nets_utils import make_pad_mask, pad_list
+from funasr.modules.nets_utils import th_accuracy
+from funasr.torch_utils.device_funcs import force_gatherable
+# from funasr.models.base_model import FunASRModel
+# from funasr.models.predictor.cif import CifPredictorV3
+from funasr.models.paraformer.search import Hypothesis
+
+from funasr.cli.model_class_factory import *
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+	from torch.cuda.amp import autocast
+else:
+	# Nothing to do if torch<1.6.0
+	@contextmanager
+	def autocast(enabled=True):
+		yield
+from funasr.datasets.fun_datasets.load_audio_extract_fbank import load_audio, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+
+class Paraformer(nn.Module):
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+	https://arxiv.org/abs/2206.08317
+	"""
+	
+	def __init__(
+		self,
+		# token_list: Union[Tuple[str, ...], List[str]],
+		frontend: Optional[str] = None,
+		frontend_conf: Optional[Dict] = None,
+		specaug: Optional[str] = None,
+		specaug_conf: Optional[Dict] = None,
+		normalize: str = None,
+		normalize_conf: Optional[Dict] = None,
+		encoder: str = None,
+		encoder_conf: Optional[Dict] = None,
+		decoder: str = None,
+		decoder_conf: Optional[Dict] = None,
+		ctc: str = None,
+		ctc_conf: Optional[Dict] = None,
+		predictor: str = None,
+		predictor_conf: Optional[Dict] = None,
+		ctc_weight: float = 0.5,
+		input_size: int = 80,
+		vocab_size: int = -1,
+		ignore_id: int = -1,
+		blank_id: int = 0,
+		sos: int = 1,
+		eos: int = 2,
+		lsm_weight: float = 0.0,
+		length_normalized_loss: bool = False,
+		# report_cer: bool = True,
+		# report_wer: bool = True,
+		# sym_space: str = "<space>",
+		# sym_blank: str = "<blank>",
+		# extract_feats_in_collect_stats: bool = True,
+		# predictor=None,
+		predictor_weight: float = 0.0,
+		predictor_bias: int = 0,
+		sampling_ratio: float = 0.2,
+		share_embedding: bool = False,
+		# preencoder: Optional[AbsPreEncoder] = None,
+		# postencoder: Optional[AbsPostEncoder] = None,
+		use_1st_decoder_loss: bool = False,
+		**kwargs,
+	):
+
+		super().__init__()
+		
+		# import pdb;
+		# pdb.set_trace()
+		
+		if frontend is not None:
+			frontend_class = frontend_choices.get_class(frontend)
+			frontend = frontend_class(**frontend_conf)
+		if specaug is not None:
+			specaug_class = specaug_choices.get_class(specaug)
+			specaug = specaug_class(**specaug_conf)
+		if normalize is not None:
+			normalize_class = normalize_choices.get_class(normalize)
+			normalize = normalize_class(**normalize_conf)
+		encoder_class = encoder_choices.get_class(encoder)
+		encoder = encoder_class(input_size=input_size, **encoder_conf)
+		encoder_output_size = encoder.output_size()
+		if decoder is not None:
+			decoder_class = decoder_choices.get_class(decoder)
+			decoder = decoder_class(
+				vocab_size=vocab_size,
+				encoder_output_size=encoder_output_size,
+				**decoder_conf,
+			)
+		if ctc_weight > 0.0:
+			
+			if ctc_conf is None:
+				ctc_conf = {}
+			
+			ctc = CTC(
+				odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+			)
+		if predictor is not None:
+			predictor_class = predictor_choices.get_class(predictor)
+			predictor = predictor_class(**predictor_conf)
+		
+		# note that eos is the same as sos (equivalent ID)
+		self.blank_id = blank_id
+		self.sos = sos if sos is not None else vocab_size - 1
+		self.eos = eos if eos is not None else vocab_size - 1
+		self.vocab_size = vocab_size
+		self.ignore_id = ignore_id
+		self.ctc_weight = ctc_weight
+		# self.token_list = token_list.copy()
+		#
+		self.frontend = frontend
+		self.specaug = specaug
+		self.normalize = normalize
+		# self.preencoder = preencoder
+		# self.postencoder = postencoder
+		self.encoder = encoder
+		#
+		# if not hasattr(self.encoder, "interctc_use_conditioning"):
+		# 	self.encoder.interctc_use_conditioning = False
+		# if self.encoder.interctc_use_conditioning:
+		# 	self.encoder.conditioning_layer = torch.nn.Linear(
+		# 		vocab_size, self.encoder.output_size()
+		# 	)
+		#
+		# self.error_calculator = None
+		#
+		if ctc_weight == 1.0:
+			self.decoder = None
+		else:
+			self.decoder = decoder
+		
+		self.criterion_att = LabelSmoothingLoss(
+			size=vocab_size,
+			padding_idx=ignore_id,
+			smoothing=lsm_weight,
+			normalize_length=length_normalized_loss,
+		)
+		#
+		# if report_cer or report_wer:
+		# 	self.error_calculator = ErrorCalculator(
+		# 		token_list, sym_space, sym_blank, report_cer, report_wer
+		# 	)
+		#
+		if ctc_weight == 0.0:
+			self.ctc = None
+		else:
+			self.ctc = ctc
+		#
+		# self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+		self.predictor = predictor
+		self.predictor_weight = predictor_weight
+		self.predictor_bias = predictor_bias
+		self.sampling_ratio = sampling_ratio
+		self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+		# self.step_cur = 0
+		#
+		self.share_embedding = share_embedding
+		if self.share_embedding:
+			self.decoder.embed = None
+		
+		self.use_1st_decoder_loss = use_1st_decoder_loss
+		self.length_normalized_loss = length_normalized_loss
+		self.beam_search = None
+	
+	def forward(
+		self,
+		speech: torch.Tensor,
+		speech_lengths: torch.Tensor,
+		text: torch.Tensor,
+		text_lengths: torch.Tensor,
+		**kwargs,
+	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+		"""Encoder + Decoder + Calc loss
+		Args:
+				speech: (Batch, Length, ...)
+				speech_lengths: (Batch, )
+				text: (Batch, Length)
+				text_lengths: (Batch,)
+		"""
+		# import pdb;
+		# pdb.set_trace()
+		if len(text_lengths.size()) > 1:
+			text_lengths = text_lengths[:, 0]
+		if len(speech_lengths.size()) > 1:
+			speech_lengths = speech_lengths[:, 0]
+		
+		batch_size = speech.shape[0]
+		
+		
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+		
+		loss_ctc, cer_ctc = None, None
+		loss_pre = None
+		stats = dict()
+		
+		# decoder: CTC branch
+		if self.ctc_weight != 0.0:
+			loss_ctc, cer_ctc = self._calc_ctc_loss(
+				encoder_out, encoder_out_lens, text, text_lengths
+			)
+			
+			# Collect CTC branch stats
+			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+			stats["cer_ctc"] = cer_ctc
+		
+
+		# decoder: Attention decoder branch
+		loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
+			encoder_out, encoder_out_lens, text, text_lengths
+		)
+		
+		# 3. CTC-Att loss definition
+		if self.ctc_weight == 0.0:
+			loss = loss_att + loss_pre * self.predictor_weight
+		else:
+			loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+		
+		
+		# Collect Attn branch stats
+		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+		stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
+		stats["acc"] = acc_att
+		stats["cer"] = cer_att
+		stats["wer"] = wer_att
+		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+		
+		stats["loss"] = torch.clone(loss.detach())
+		
+		# force_gatherable: to-device and to-tensor if scalar for DataParallel
+		if self.length_normalized_loss:
+			batch_size = (text_lengths + self.predictor_bias).sum()
+		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+		return loss, stats, weight
+	
+
+	def encode(
+		self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+	) -> Tuple[torch.Tensor, torch.Tensor]:
+		"""Frontend + Encoder. Note that this method is used by asr_inference.py
+		Args:
+				speech: (Batch, Length, ...)
+				speech_lengths: (Batch, )
+				ind: int
+		"""
+		with autocast(False):
+
+			# Data augmentation
+			if self.specaug is not None and self.training:
+				speech, speech_lengths = self.specaug(speech, speech_lengths)
+			
+			# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+			if self.normalize is not None:
+				speech, speech_lengths = self.normalize(speech, speech_lengths)
+		
+
+		# Forward encoder
+		encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+
+		return encoder_out, encoder_out_lens
+	
+	def calc_predictor(self, encoder_out, encoder_out_lens):
+		
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(encoder_out, None,
+		                                                                               encoder_out_mask,
+		                                                                               ignore_id=self.ignore_id)
+		return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+	
+	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+		
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+		)
+		decoder_out = decoder_outs[0]
+		decoder_out = torch.log_softmax(decoder_out, dim=-1)
+		return decoder_out, ys_pad_lens
+
+	def _calc_att_loss(
+		self,
+		encoder_out: torch.Tensor,
+		encoder_out_lens: torch.Tensor,
+		ys_pad: torch.Tensor,
+		ys_pad_lens: torch.Tensor,
+	):
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		if self.predictor_bias == 1:
+			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+			ys_pad_lens = ys_pad_lens + self.predictor_bias
+		pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad, encoder_out_mask,
+		                                                                          ignore_id=self.ignore_id)
+		
+		# 0. sampler
+		decoder_out_1st = None
+		pre_loss_att = None
+		if self.sampling_ratio > 0.0:
+
+			sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+			                                               pre_acoustic_embeds)
+		else:
+			sematic_embeds = pre_acoustic_embeds
+		
+		# 1. Forward decoder
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+		)
+		decoder_out, _ = decoder_outs[0], decoder_outs[1]
+		
+		if decoder_out_1st is None:
+			decoder_out_1st = decoder_out
+		# 2. Compute attention loss
+		loss_att = self.criterion_att(decoder_out, ys_pad)
+		acc_att = th_accuracy(
+			decoder_out_1st.view(-1, self.vocab_size),
+			ys_pad,
+			ignore_label=self.ignore_id,
+		)
+		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+		
+		# Compute cer/wer using attention-decoder
+		if self.training or self.error_calculator is None:
+			cer_att, wer_att = None, None
+		else:
+			ys_hat = decoder_out_1st.argmax(dim=-1)
+			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+		
+		return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+	
	def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
		"""Glancing sampler: replace a fraction of the predictor's acoustic
		embeddings with ground-truth token embeddings; the fraction is
		proportional to the first decoding pass's errors, scaled by
		self.sampling_ratio.

		Returns:
			(mixed_embeddings, first_pass_decoder_out), both zeroed on padding.
		"""
		# (B, Lmax, 1) mask, True on valid target positions.
		tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
		# Zero the ignore_id padding so the embedding lookup index is valid.
		ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
		if self.share_embedding:
			ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
		else:
			ys_pad_embed = self.decoder.embed(ys_pad_masked)
		# First decoding pass runs without gradients: it only counts errors.
		with torch.no_grad():
			decoder_outs = self.decoder(
				encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
			)
			decoder_out, _ = decoder_outs[0], decoder_outs[1]
			pred_tokens = decoder_out.argmax(-1)
			nonpad_positions = ys_pad.ne(self.ignore_id)
			seq_lens = (nonpad_positions).sum(1)
			same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
			input_mask = torch.ones_like(nonpad_positions)
			bsz, seq_len = ys_pad.size()
			# For each utterance, randomly choose #errors * sampling_ratio
			# positions whose embedding will come from the ground truth
			# (input_mask==0 marks those positions).
			for li in range(bsz):
				target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
				if target_num > 0:
					input_mask[li].scatter_(dim=0,
					                        index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
					                        value=0)
			input_mask = input_mask.eq(1)
			# Padding positions never take ground-truth embeddings.
			input_mask = input_mask.masked_fill(~nonpad_positions, False)
			input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
		
		# Mix: acoustic embedding where input_mask is True, ground truth elsewhere.
		sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
			input_mask_expand_dim, 0)
		return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+		
+	def _calc_ctc_loss(
+		self,
+		encoder_out: torch.Tensor,
+		encoder_out_lens: torch.Tensor,
+		ys_pad: torch.Tensor,
+		ys_pad_lens: torch.Tensor,
+	):
+		# Calc CTC loss
+		loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+		
+		# Calc CER using CTC
+		cer_ctc = None
+		if not self.training and self.error_calculator is not None:
+			ys_hat = self.ctc.argmax(encoder_out).data
+			cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+		return loss_ctc, cer_ctc
+
+	
+	def init_beam_search(self,
+	                     **kwargs,
+	                     ):
+		from funasr.models.paraformer.search import BeamSearchPara
+		from funasr.modules.scorers.ctc import CTCPrefixScorer
+		from funasr.modules.scorers.length_bonus import LengthBonus
+	
+		# 1. Build ASR model
+		scorers = {}
+		
+		if self.ctc != None:
+			ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+			scorers.update(
+				ctc=ctc
+			)
+		token_list = kwargs.get("token_list")
+		scorers.update(
+			length_bonus=LengthBonus(len(token_list)),
+		)
+
+		
+		# 3. Build ngram model
+		# ngram is not supported now
+		ngram = None
+		scorers["ngram"] = ngram
+		
+		weights = dict(
+			decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+			ctc=kwargs.get("decoding_ctc_weight", 0.0),
+			lm=kwargs.get("lm_weight", 0.0),
+			ngram=kwargs.get("ngram_weight", 0.0),
+			length_bonus=kwargs.get("penalty", 0.0),
+		)
+		beam_search = BeamSearchPara(
+			beam_size=kwargs.get("beam_size", 2),
+			weights=weights,
+			scorers=scorers,
+			sos=self.sos,
+			eos=self.eos,
+			vocab_size=len(token_list),
+			token_list=token_list,
+			pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+		)
+		# beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+		# for scorer in scorers.values():
+		# 	if isinstance(scorer, torch.nn.Module):
+		# 		scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+		self.beam_search = beam_search
+		
+	def generate(self,
+             data_in: list,
+             data_lengths: list=None,
+             key: list=None,
+             tokenizer=None,
+             **kwargs,
+             ):
+		
+		# init beamsearch
+		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		if self.beam_search is None and (is_use_lm or is_use_ctc):
+			logging.info("enable beam_search")
+			self.init_beam_search(**kwargs)
+			self.nbest = kwargs.get("nbest", 1)
+		
+		meta_data = {}
+		# extract fbank feats
+		time1 = time.perf_counter()
+		audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+		time2 = time.perf_counter()
+		meta_data["load_data"] = f"{time2 - time1:0.3f}"
+		speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"), frontend=self.frontend)
+		time3 = time.perf_counter()
+		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+		meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+		
+		speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+		
+		# predictor
+		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+		                                                                predictor_outs[2], predictor_outs[3]
+		pre_token_length = pre_token_length.round().long()
+		if torch.max(pre_token_length) < 1:
+			return []
+		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+		                                                         pre_token_length)
+		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+
+		results = []
+		b, n, d = decoder_out.size()
+		for i in range(b):
+			x = encoder_out[i, :encoder_out_lens[i], :]
+			am_scores = decoder_out[i, :pre_token_length[i], :]
+			if self.beam_search is not None:
+				nbest_hyps = self.beam_search(
+					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+				)
+				
+				nbest_hyps = nbest_hyps[: self.nbest]
+			else:
+
+				yseq = am_scores.argmax(dim=-1)
+				score = am_scores.max(dim=-1)[0]
+				score = torch.sum(score, dim=-1)
+				# pad with mask tokens to ensure compatibility with sos/eos tokens
+				yseq = torch.tensor(
+					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+				)
+				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+			for nbest_idx, hyp in enumerate(nbest_hyps):
+				ibest_writer = None
+				if ibest_writer is None and kwargs.get("output_dir") is not None:
+					writer = DatadirWriter(kwargs.get("output_dir"))
+					ibest_writer = writer[f"{nbest_idx+1}best_recog"]
+				# remove sos/eos and get results
+				last_pos = -1
+				if isinstance(hyp.yseq, list):
+					token_int = hyp.yseq[1:last_pos]
+				else:
+					token_int = hyp.yseq[1:last_pos].tolist()
+					
+				# remove blank symbol id, which is assumed to be 0
+				token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+				
+				# Change integer-ids to tokens
+				token = tokenizer.ids2tokens(token_int)
+				text = tokenizer.tokens2text(token)
+				
+				text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+				result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+				results.append(result_i)
+				
+				if ibest_writer is not None:
+					ibest_writer["token"][key[i]] = " ".join(token)
+					ibest_writer["text"][key[i]] = text
+					ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+		
+		return results, meta_data
+
+
+
class BiCifParaformer(Paraformer):
	"""
	Author: Speech Lab of DAMO Academy, Alibaba Group
	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
	https://arxiv.org/abs/2206.08317

	Paraformer variant built on CifPredictorV3: the predictor emits a second
	token-length estimate (trained via _calc_pre2_loss) and supports upsampled
	CIF peaks, enabling frame-level timestamp prediction at inference time.
	"""
	
	def __init__(
		self,
		*args,
		**kwargs,
	):
		super().__init__(*args, **kwargs)
		# The second length estimate and timestamp upsampling need CifPredictorV3.
		assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"


	def _calc_pre2_loss(
		self,
		encoder_out: torch.Tensor,
		encoder_out_lens: torch.Tensor,
		ys_pad: torch.Tensor,
		ys_pad_lens: torch.Tensor,
	):
		"""MAE loss between the predictor's *second* token-length estimate and the target length."""
		# (B, 1, T) mask, True on valid encoder frames.
		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
			encoder_out.device)
		# Account for the appended EOS when predictor_bias == 1.
		if self.predictor_bias == 1:
			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
			ys_pad_lens = ys_pad_lens + self.predictor_bias
		# Only the fifth predictor output (second length estimate) is used here.
		_, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
		
		# loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
		loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
		
		return loss_pre2
	
	
	def _calc_att_loss(
		self,
		encoder_out: torch.Tensor,
		encoder_out_lens: torch.Tensor,
		ys_pad: torch.Tensor,
		ys_pad_lens: torch.Tensor,
	):
		"""Attention-branch loss; overridden because CifPredictorV3 returns five values.

		Returns:
			(loss_att, acc_att, cer_att, wer_att, loss_pre) — note there is no
			pre_loss_att element, unlike the base-class version.
		"""
		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
			encoder_out.device)
		if self.predictor_bias == 1:
			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
			ys_pad_lens = ys_pad_lens + self.predictor_bias
		pre_acoustic_embeds, pre_token_length, _, pre_peak_index, _ = self.predictor(encoder_out, ys_pad,
		                                                                             encoder_out_mask,
		                                                                             ignore_id=self.ignore_id)
		
		# 0. sampler: optionally mix ground-truth embeddings in (glancing training).
		decoder_out_1st = None
		if self.sampling_ratio > 0.0:
			sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
			                                               pre_acoustic_embeds)
		else:
			sematic_embeds = pre_acoustic_embeds
		
		# 1. Forward decoder
		decoder_outs = self.decoder(
			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
		)
		decoder_out, _ = decoder_outs[0], decoder_outs[1]
		
		if decoder_out_1st is None:
			decoder_out_1st = decoder_out
		# 2. Compute attention loss (CE over the second pass; accuracy over the first).
		loss_att = self.criterion_att(decoder_out, ys_pad)
		acc_att = th_accuracy(
			decoder_out_1st.view(-1, self.vocab_size),
			ys_pad,
			ignore_label=self.ignore_id,
		)
		# MAE between the first predicted token count and the true length.
		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
		
		# Compute cer/wer using attention-decoder (evaluation only).
		if self.training or self.error_calculator is None:
			cer_att, wer_att = None, None
		else:
			ys_hat = decoder_out_1st.argmax(dim=-1)
			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
		
		return loss_att, acc_att, cer_att, wer_att, loss_pre


	def calc_predictor(self, encoder_out, encoder_out_lens):
		"""Inference-time predictor pass; the second length estimate is discarded here
		(timestamps are obtained separately via calc_predictor_timestamp)."""
		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
			encoder_out.device)
		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
		                                                                                                  None,
		                                                                                                  encoder_out_mask,
		                                                                                                  ignore_id=self.ignore_id)
		return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index


	def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
		"""Upsampled CIF pass used for timestamp prediction.

		Returns:
			(ds_alphas, ds_cif_peak, us_alphas, us_peaks); the us_* outputs are
			at the upsampled (finer) frame rate.
		"""
		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
			encoder_out.device)
		ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
		                                                                                    encoder_out_mask,
		                                                                                    token_num)
		return ds_alphas, ds_cif_peak, us_alphas, us_peaks
	
	
	def forward(
		self,
		speech: torch.Tensor,
		speech_lengths: torch.Tensor,
		text: torch.Tensor,
		text_lengths: torch.Tensor,
		**kwargs,
	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
		"""Frontend + Encoder + Decoder + Calc loss.

		Adds the second predictor loss (loss_pre2, weighted at half the
		predictor weight) on top of the standard Paraformer objective.

		Args:
				speech: (Batch, Length, ...)
				speech_lengths: (Batch, )
				text: (Batch, Length)
				text_lengths: (Batch,)

		Returns:
			(loss, stats, weight) as expected by the trainer.
		"""
		# Squeeze accidental 2-D length tensors down to (Batch,).
		if len(text_lengths.size()) > 1:
			text_lengths = text_lengths[:, 0]
		if len(speech_lengths.size()) > 1:
			speech_lengths = speech_lengths[:, 0]
		
		batch_size = speech.shape[0]
		
		# Encoder
		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)


		loss_ctc, cer_ctc = None, None
		loss_pre = None
		stats = dict()
		
		# decoder: CTC branch
		if self.ctc_weight != 0.0:
			loss_ctc, cer_ctc = self._calc_ctc_loss(
				encoder_out, encoder_out_lens, text, text_lengths
			)
			
			# Collect CTC branch stats
			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
			stats["cer_ctc"] = cer_ctc


		# decoder: Attention decoder branch
		loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
			encoder_out, encoder_out_lens, text, text_lengths
		)
		
		# Second predictor-length loss, specific to BiCifParaformer.
		loss_pre2 = self._calc_pre2_loss(
			encoder_out, encoder_out_lens, text, text_lengths
		)
		
		# 3. CTC-Att loss definition (loss_pre2 weighted at half predictor_weight)
		if self.ctc_weight == 0.0:
			loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
		else:
			loss = self.ctc_weight * loss_ctc + (
				1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
		
		# Collect Attn branch stats
		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
		stats["acc"] = acc_att
		stats["cer"] = cer_att
		stats["wer"] = wer_att
		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
		stats["loss_pre2"] = loss_pre2.detach().cpu()
		
		stats["loss"] = torch.clone(loss.detach())
		
		# force_gatherable: to-device and to-tensor if scalar for DataParallel
		if self.length_normalized_loss:
			batch_size = int((text_lengths + self.predictor_bias).sum())
		
		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
		return loss, stats, weight
	
	def generate(self,
	             data_in: list,
	             data_lengths: list = None,
	             key: list = None,
	             tokenizer=None,
	             **kwargs,
	             ):
		"""Offline inference with timestamps: same pipeline as Paraformer.generate
		plus an upsampled-CIF pass that attaches per-token timestamps."""
		
		# init beamsearch only when CTC/LM rescoring is requested
		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
		if self.beam_search is None and (is_use_lm or is_use_ctc):
			logging.info("enable beam_search")
			self.init_beam_search(**kwargs)
			self.nbest = kwargs.get("nbest", 1)
		
		meta_data = {}
		# extract fbank feats, timing each stage
		time1 = time.perf_counter()
		audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
		time2 = time.perf_counter()
		meta_data["load_data"] = f"{time2 - time1:0.3f}"
		speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"),
		                                       frontend=self.frontend)
		time3 = time.perf_counter()
		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
		meta_data[
			"batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
		
		# NOTE(review): Tensor.to() is not in-place and its results are discarded
		# here, so the tensors never actually move to kwargs["device"].
		speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
		
		# Encoder
		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
		if isinstance(encoder_out, tuple):
			encoder_out = encoder_out[0]
		
		# predictor
		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
		                                                                predictor_outs[2], predictor_outs[3]
		pre_token_length = pre_token_length.round().long()
		# NOTE(review): bare-list early return is inconsistent with the normal
		# (results, meta_data) return below.
		if torch.max(pre_token_length) < 1:
			return []
		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
		                                               pre_token_length)
		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
		
		# BiCifParaformer, test no bias cif2

		_, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
			                                                                    pre_token_length)
		
		results = []
		b, n, d = decoder_out.size()
		for i in range(b):
			x = encoder_out[i, :encoder_out_lens[i], :]
			am_scores = decoder_out[i, :pre_token_length[i], :]
			if self.beam_search is not None:
				nbest_hyps = self.beam_search(
					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
					minlenratio=kwargs.get("minlenratio", 0.0)
				)
				
				nbest_hyps = nbest_hyps[: self.nbest]
			else:
				# Greedy path: frame-wise argmax, wrapped with sos/eos.
				yseq = am_scores.argmax(dim=-1)
				score = am_scores.max(dim=-1)[0]
				score = torch.sum(score, dim=-1)
				# pad with mask tokens to ensure compatibility with sos/eos tokens
				yseq = torch.tensor(
					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
				)
				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
			for nbest_idx, hyp in enumerate(nbest_hyps):
				ibest_writer = None
				if ibest_writer is None and kwargs.get("output_dir") is not None:
					writer = DatadirWriter(kwargs.get("output_dir"))
					ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
				# remove sos/eos and get results
				last_pos = -1
				if isinstance(hyp.yseq, list):
					token_int = hyp.yseq[1:last_pos]
				else:
					token_int = hyp.yseq[1:last_pos].tolist()
				
				# remove blank symbol id, which is assumed to be 0
				token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
				
				# Change integer-ids to tokens
				token = tokenizer.ids2tokens(token_int)
				text = tokenizer.tokens2text(token)
				
				# Align tokens against the upsampled CIF peaks to get timestamps
				# (the upsampled streams run at 3x the encoder frame rate).
				_, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
				                                           us_peaks[i][:encoder_out_lens[i] * 3],
				                                           copy.copy(token),
				                                           vad_offset=kwargs.get("begin_time", 0))
				
				text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token, timestamp)
				
				result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
				            "time_stamp_postprocessed": time_stamp_postprocessed,
				            "word_lists": word_lists
				            }
				results.append(result_i)
				
				if ibest_writer is not None:
					ibest_writer["token"][key[i]] = " ".join(token)
					ibest_writer["text"][key[i]] = text
					ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
					
		
		return results, meta_data
+
+
+class NeatContextualParaformer(Paraformer):
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+	https://arxiv.org/abs/2206.08317
+	"""
+	
	def __init__(
		self,
		*args,
		**kwargs,
	):
		"""Contextual (hotword-biased) Paraformer constructor.

		Builds a bias encoder over hotword embeddings on top of the base
		Paraformer; configuration is read from **kwargs.
		"""
		super().__init__(*args, **kwargs)
		
		# Hotword / bias-encoder configuration.
		self.target_buffer_length = kwargs.get("target_buffer_length", -1)
		inner_dim = kwargs.get("inner_dim", 256)
		bias_encoder_type = kwargs.get("bias_encoder_type", "lstm")
		use_decoder_embedding = kwargs.get("use_decoder_embedding", False)
		crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
		crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
		bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)


		if bias_encoder_type == 'lstm':
			logging.warning("enable bias encoder sampling and contextual training")
			# NOTE(review): dropout on a single-layer LSTM has no effect and
			# triggers a PyTorch warning when the rate is non-zero.
			self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
			self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
		elif bias_encoder_type == 'mean':
			logging.warning("enable bias encoder sampling and contextual training")
			self.bias_embed = torch.nn.Embedding(self.vocab_size, inner_dim)
		else:
			# Unsupported types are only logged; no bias encoder is created.
			logging.error("Unsupport bias encoder type: {}".format(bias_encoder_type))
		
		if self.target_buffer_length > 0:
			# Streaming-style hotword buffer state.
			self.hotword_buffer = None
			self.length_record = []
			self.current_buffer_length = 0
		self.use_decoder_embedding = use_decoder_embedding
		self.crit_attn_weight = crit_attn_weight
		if self.crit_attn_weight > 0:
			self.attn_loss = torch.nn.L1Loss()
		self.crit_attn_smooth = crit_attn_smooth
+
+
	def forward(
		self,
		speech: torch.Tensor,
		speech_lengths: torch.Tensor,
		text: torch.Tensor,
		text_lengths: torch.Tensor,
		hotword_pad: torch.Tensor,
		hotword_lengths: torch.Tensor,
		dha_pad: torch.Tensor,
		**kwargs,
	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
		"""Frontend + Encoder + Decoder + Calc loss (contextual variant).
	
		Args:
				speech: (Batch, Length, ...)
				speech_lengths: (Batch, )
				text: (Batch, Length)
				text_lengths: (Batch,)
				hotword_pad: padded hotword token-id sequences fed to the bias encoder.
				hotword_lengths: valid length of each hotword sequence.
				dha_pad: unused in this method body — presumably decoder-hotword
					attention targets; TODO confirm against the dataloader.
		"""
		# Squeeze accidental 2-D length tensors down to (Batch,).
		if len(text_lengths.size()) > 1:
			text_lengths = text_lengths[:, 0]
		if len(speech_lengths.size()) > 1:
			speech_lengths = speech_lengths[:, 0]
		
		batch_size = speech.shape[0]

		# 1. Encoder
		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

		
		loss_ctc, cer_ctc = None, None
		
		stats = dict()
		
		# 1. CTC branch
		if self.ctc_weight != 0.0:
			loss_ctc, cer_ctc = self._calc_ctc_loss(
				encoder_out, encoder_out_lens, text, text_lengths
			)
			
			# Collect CTC branch stats
			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
			stats["cer_ctc"] = cer_ctc
		

		# 2b. Attention decoder branch (with hotword/contextual bias)
		loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
			encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
		)
		
		# 3. CTC-Att loss definition
		if self.ctc_weight == 0.0:
			loss = loss_att + loss_pre * self.predictor_weight
		else:
			loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
		
		# Optional attention-supervision term (currently always None upstream).
		if loss_ideal is not None:
			loss = loss + loss_ideal * self.crit_attn_weight
			stats["loss_ideal"] = loss_ideal.detach().cpu()
		
		# Collect Attn branch stats
		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
		stats["acc"] = acc_att
		stats["cer"] = cer_att
		stats["wer"] = wer_att
		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
		
		stats["loss"] = torch.clone(loss.detach())
		# force_gatherable: to-device and to-tensor if scalar for DataParallel
		if self.length_normalized_loss:
			batch_size = int((text_lengths + self.predictor_bias).sum())
		
		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
		return loss, stats, weight
+	
+	
	def _calc_att_clas_loss(
		self,
		encoder_out: torch.Tensor,
		encoder_out_lens: torch.Tensor,
		ys_pad: torch.Tensor,
		ys_pad_lens: torch.Tensor,
		hotword_pad: torch.Tensor,
		hotword_lengths: torch.Tensor,
	):
		"""Attention loss with contextual (hotword) bias information.

		Returns:
			(loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal);
			loss_ideal is currently always None (the computation is commented out).
		"""
		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
			encoder_out.device)
		# Append EOS so the predictor target matches the decoder target length.
		if self.predictor_bias == 1:
			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
			ys_pad_lens = ys_pad_lens + self.predictor_bias
		pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
		                                                             ignore_id=self.ignore_id)
		
		# -1. bias encoder: embed hotwords, run the LSTM, and take each
		# hotword's last valid timestep as its fixed-size representation.
		if self.use_decoder_embedding:
			hw_embed = self.decoder.embed(hotword_pad)
		else:
			hw_embed = self.bias_embed(hotword_pad)
		hw_embed, (_, _) = self.bias_encoder(hw_embed)
		_ind = np.arange(0, hotword_pad.shape[0]).tolist()
		selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
		# Broadcast the same hotword set to every utterance in the batch.
		contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
		
		# 0. sampler (glancing training)
		# NOTE(review): self.step_cur is assumed to be maintained by the parent
		# class / trainer — confirm it exists before the first forward.
		decoder_out_1st = None
		if self.sampling_ratio > 0.0:
			if self.step_cur < 2:
				logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
			sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
			                                               pre_acoustic_embeds, contextual_info)
		else:
			if self.step_cur < 2:
				logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
			sematic_embeds = pre_acoustic_embeds
		
		# 1. Forward decoder (bias info injected via contextual_info)
		decoder_outs = self.decoder(
			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
		)
		decoder_out, _ = decoder_outs[0], decoder_outs[1]
		'''
		if self.crit_attn_weight > 0 and attn.shape[-1] > 1:
			ideal_attn = ideal_attn + self.crit_attn_smooth / (self.crit_attn_smooth + 1.0)
			attn_non_blank = attn[:,:,:,:-1]
			ideal_attn_non_blank = ideal_attn[:,:,:-1]
			loss_ideal = self.attn_loss(attn_non_blank.max(1)[0], ideal_attn_non_blank.to(attn.device))
		else:
			loss_ideal = None
		'''
		loss_ideal = None
		
		if decoder_out_1st is None:
			decoder_out_1st = decoder_out
		# 2. Compute attention loss (CE over second pass, accuracy over first pass)
		loss_att = self.criterion_att(decoder_out, ys_pad)
		acc_att = th_accuracy(
			decoder_out_1st.view(-1, self.vocab_size),
			ys_pad,
			ignore_label=self.ignore_id,
		)
		# MAE between predicted token count and true target length.
		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
		
		# Compute cer/wer using attention-decoder (evaluation only)
		if self.training or self.error_calculator is None:
			cer_att, wer_att = None, None
		else:
			ys_hat = decoder_out_1st.argmax(dim=-1)
			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
		
		return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal
+	
+	
	def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
		"""Glancing sampler with contextual bias (see base-class sampler).

		Returns:
			(mixed_embeddings, first_pass_decoder_out), both zeroed on padding.
		"""
		tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
		# NOTE(review): unlike the base-class sampler, ys_pad itself is
		# overwritten with the masked copy here, so the subsequent
		# ne(self.ignore_id) test sees 0 (not ignore_id) at padding positions —
		# confirm this difference is intentional.
		ys_pad = ys_pad * tgt_mask[:, :, 0]
		if self.share_embedding:
			ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
		else:
			ys_pad_embed = self.decoder.embed(ys_pad)
		# First decoding pass runs without gradients: it only counts errors.
		with torch.no_grad():
			decoder_outs = self.decoder(
				encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
			)
			decoder_out, _ = decoder_outs[0], decoder_outs[1]
			pred_tokens = decoder_out.argmax(-1)
			nonpad_positions = ys_pad.ne(self.ignore_id)
			seq_lens = (nonpad_positions).sum(1)
			same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
			input_mask = torch.ones_like(nonpad_positions)
			bsz, seq_len = ys_pad.size()
			# Randomly pick #errors * sampling_ratio positions per utterance to
			# take the ground-truth embedding (input_mask==0 marks them).
			for li in range(bsz):
				target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
				if target_num > 0:
					input_mask[li].scatter_(dim=0,
					                        index=torch.randperm(seq_lens[li])[:target_num].to(pre_acoustic_embeds.device),
					                        value=0)
			input_mask = input_mask.eq(1)
			input_mask = input_mask.masked_fill(~nonpad_positions, False)
			input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
		
		# Mix: acoustic embedding where input_mask is True, ground truth elsewhere.
		sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
			input_mask_expand_dim, 0)
		return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+	
+	
	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None,
	                               clas_scale=1.0):
		"""Decode with hotword bias: encode the hotword list (or an empty
		placeholder) through the bias encoder and pass its final hidden state
		to the decoder as contextual_info.

		Args:
			hw_list: list of hotword token-id sequences, or None for no hotwords.
			clas_scale: scaling factor applied to the contextual attention.

		Returns:
			(decoder_out, ys_pad_lens) with log-softmax scores.
		"""
		if hw_list is None:
			# No hotwords: use a single placeholder token so the bias encoder
			# still produces a (neutral) context vector.
			hw_list = [torch.Tensor([1]).long().to(encoder_out.device)]  # empty hotword list
			hw_list_pad = pad_list(hw_list, 0)
			if self.use_decoder_embedding:
				hw_embed = self.decoder.embed(hw_list_pad)
			else:
				hw_embed = self.bias_embed(hw_list_pad)
			# h_n: the LSTM's final hidden state summarizes the hotword.
			hw_embed, (h_n, _) = self.bias_encoder(hw_embed)
			hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
		else:
			hw_lengths = [len(i) for i in hw_list]
			hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
			if self.use_decoder_embedding:
				hw_embed = self.decoder.embed(hw_list_pad)
			else:
				hw_embed = self.bias_embed(hw_list_pad)
			# Pack variable-length hotwords so the LSTM ignores padding.
			hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
			                                                   enforce_sorted=False)
			_, (h_n, _) = self.bias_encoder(hw_embed)
			hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
		
		decoder_outs = self.decoder(
			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
		)
		decoder_out = decoder_outs[0]
		decoder_out = torch.log_softmax(decoder_out, dim=-1)
		return decoder_out, ys_pad_lens
+		
+	def generate(self,
+	             data_in: list,
+	             data_lengths: list = None,
+	             key: list = None,
+	             tokenizer=None,
+	             **kwargs,
+	             ):
+		
+		# init beamsearch
+		is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		if self.beam_search is None and (is_use_lm or is_use_ctc):
+			logging.info("enable beam_search")
+			self.init_beam_search(**kwargs)
+			self.nbest = kwargs.get("nbest", 1)
+		
+		meta_data = {}
+		
+		# extract fbank feats
+		time1 = time.perf_counter()
+		audio_sample_list = load_audio(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000))
+		time2 = time.perf_counter()
+		meta_data["load_data"] = f"{time2 - time1:0.3f}"
+		speech, speech_lengths = extract_fbank(audio_sample_list, date_type=kwargs.get("date_type", "sound"),
+		                                       frontend=self.frontend)
+		time3 = time.perf_counter()
+		meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+		meta_data[
+			"batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000
+		
+		speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
+		# hotword
+		self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer)
+		
+		# Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+		
+		# predictor
+		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+		                                                                predictor_outs[2], predictor_outs[3]
+		pre_token_length = pre_token_length.round().long()
+		if torch.max(pre_token_length) < 1:
+			return []
+
+
+		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
+		                                                         pre_acoustic_embeds,
+		                                                         pre_token_length,
+		                                                         hw_list=self.hotword_list,
+		                                                         clas_scale=kwargs.get("clas_scale", 1.0))
+		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+		
+		results = []
+		b, n, d = decoder_out.size()
+		for i in range(b):
+			x = encoder_out[i, :encoder_out_lens[i], :]
+			am_scores = decoder_out[i, :pre_token_length[i], :]
+			if self.beam_search is not None:
+				nbest_hyps = self.beam_search(
+					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+					minlenratio=kwargs.get("minlenratio", 0.0)
+				)
+				
+				nbest_hyps = nbest_hyps[: self.nbest]
+			else:
+				
+				yseq = am_scores.argmax(dim=-1)
+				score = am_scores.max(dim=-1)[0]
+				score = torch.sum(score, dim=-1)
+				# pad with mask tokens to ensure compatibility with sos/eos tokens
+				yseq = torch.tensor(
+					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+				)
+				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+			for nbest_idx, hyp in enumerate(nbest_hyps):
+				ibest_writer = None
+				if ibest_writer is None and kwargs.get("output_dir") is not None:
+					writer = DatadirWriter(kwargs.get("output_dir"))
+					ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
+				# remove sos/eos and get results
+				last_pos = -1
+				if isinstance(hyp.yseq, list):
+					token_int = hyp.yseq[1:last_pos]
+				else:
+					token_int = hyp.yseq[1:last_pos].tolist()
+				
+				# remove blank symbol id, which is assumed to be 0
+				token_int = list(
+					filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+				
+				# Change integer-ids to tokens
+				token = tokenizer.ids2tokens(token_int)
+				text = tokenizer.tokens2text(token)
+				
+				text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+				result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+				results.append(result_i)
+				
+				if ibest_writer is not None:
+					ibest_writer["token"][key[i]] = " ".join(token)
+					ibest_writer["text"][key[i]] = text
+					ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+		
+		return results, meta_data
+
+
+	def generate_hotwords_list(self, hotword_list_or_file, tokenizer=None):
+		def load_seg_dict(seg_dict_file):
+			seg_dict = {}
+			assert isinstance(seg_dict_file, str)
+			with open(seg_dict_file, "r", encoding="utf8") as f:
+				lines = f.readlines()
+				for line in lines:
+					s = line.strip().split()
+					key = s[0]
+					value = s[1:]
+					seg_dict[key] = " ".join(value)
+			return seg_dict
+		
+		def seg_tokenize(txt, seg_dict):
+			pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+			out_txt = ""
+			for word in txt:
+				word = word.lower()
+				if word in seg_dict:
+					out_txt += seg_dict[word] + " "
+				else:
+					if pattern.match(word):
+						for char in word:
+							if char in seg_dict:
+								out_txt += seg_dict[char] + " "
+							else:
+								out_txt += "<unk>" + " "
+					else:
+						out_txt += "<unk>" + " "
+			return out_txt.strip().split()
+		
+		seg_dict = None
+		if self.frontend.cmvn_file is not None:
+			model_dir = os.path.dirname(self.frontend.cmvn_file)
+			seg_dict_file = os.path.join(model_dir, 'seg_dict')
+			if os.path.exists(seg_dict_file):
+				seg_dict = load_seg_dict(seg_dict_file)
+			else:
+				seg_dict = None
+		# for None
+		if hotword_list_or_file is None:
+			hotword_list = None
+		# for local txt inputs
+		elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
+			logging.info("Attempting to parse hotwords from local txt...")
+			hotword_list = []
+			hotword_str_list = []
+			with codecs.open(hotword_list_or_file, 'r') as fin:
+				for line in fin.readlines():
+					hw = line.strip()
+					hw_list = hw.split()
+					if seg_dict is not None:
+						hw_list = seg_tokenize(hw_list, seg_dict)
+					hotword_str_list.append(hw)
+					hotword_list.append(tokenizer.tokens2ids(hw_list))
+				hotword_list.append([self.sos])
+				hotword_str_list.append('<s>')
+			logging.info("Initialized hotword list from file: {}, hotword list: {}."
+			             .format(hotword_list_or_file, hotword_str_list))
+		# for url, download and generate txt
+		elif hotword_list_or_file.startswith('http'):
+			logging.info("Attempting to parse hotwords from url...")
+			work_dir = tempfile.TemporaryDirectory().name
+			if not os.path.exists(work_dir):
+				os.makedirs(work_dir)
+			text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
+			local_file = requests.get(hotword_list_or_file)
+			open(text_file_path, "wb").write(local_file.content)
+			hotword_list_or_file = text_file_path
+			hotword_list = []
+			hotword_str_list = []
+			with codecs.open(hotword_list_or_file, 'r') as fin:
+				for line in fin.readlines():
+					hw = line.strip()
+					hw_list = hw.split()
+					if seg_dict is not None:
+						hw_list = seg_tokenize(hw_list, seg_dict)
+					hotword_str_list.append(hw)
+					hotword_list.append(tokenizer.tokens2ids(hw_list))
+				hotword_list.append([self.sos])
+				hotword_str_list.append('<s>')
+			logging.info("Initialized hotword list from file: {}, hotword list: {}."
+			             .format(hotword_list_or_file, hotword_str_list))
+		# for text str input
+		elif not hotword_list_or_file.endswith('.txt'):
+			logging.info("Attempting to parse hotwords as str...")
+			hotword_list = []
+			hotword_str_list = []
+			for hw in hotword_list_or_file.strip().split():
+				hotword_str_list.append(hw)
+				hw_list = hw.strip().split()
+				if seg_dict is not None:
+					hw_list = seg_tokenize(hw_list, seg_dict)
+				hotword_list.append(tokenizer.tokens2ids(hw_list))
+			hotword_list.append([self.sos])
+			hotword_str_list.append('<s>')
+			logging.info("Hotword list: {}.".format(hotword_str_list))
+		else:
+			hotword_list = None
+		return hotword_list
+
+
+class ParaformerOnline(Paraformer):
+	"""
+	Author: Speech Lab of DAMO Academy, Alibaba Group
+	Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+	https://arxiv.org/abs/2206.08317
+
+	Streaming (chunk-based) variant: the encoder works on overlapping chunks
+	and the decoder's cross-attention is restricted by a SCAMA mask built from
+	the predictor's frame alignments.
+	"""
+	
+	def __init__(
+		self,
+		*args,
+		**kwargs,
+	):
+		
+		super().__init__(*args, **kwargs)
+		
+		# Fraction of target positions whose acoustic embedding may be replaced
+		# by the ground-truth token embedding during training (see sampler()).
+		self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
+
+
+		# Cross-attention mask cached by calc_predictor() and consumed by
+		# cal_decoder_with_predictor() during streaming inference.
+		self.scama_mask = None
+		if hasattr(self.encoder, "overlap_chunk_cls") and self.encoder.overlap_chunk_cls is not None:
+			from funasr.modules.streaming_utils.chunk_utilis import build_scama_mask_for_cross_attention_decoder
+			self.build_scama_mask_for_cross_attention_decoder_fn = build_scama_mask_for_cross_attention_decoder
+			self.decoder_attention_chunk_type = kwargs.get("decoder_attention_chunk_type", "chunk")
+
+
+	
+	def forward(
+		self,
+		speech: torch.Tensor,
+		speech_lengths: torch.Tensor,
+		text: torch.Tensor,
+		text_lengths: torch.Tensor,
+		**kwargs,
+	) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+		"""Encoder + Decoder + Calc loss
+		Args:
+				speech: (Batch, Length, ...)
+				speech_lengths: (Batch, )
+				text: (Batch, Length)
+				text_lengths: (Batch,)
+		"""
+		decoding_ind = kwargs.get("decoding_ind")
+		# Lengths may arrive as (Batch, 1); reduce to (Batch,).
+		if len(text_lengths.size()) > 1:
+			text_lengths = text_lengths[:, 0]
+		if len(speech_lengths.size()) > 1:
+			speech_lengths = speech_lengths[:, 0]
+		
+		batch_size = speech.shape[0]
+		
+		# Encoder
+		if hasattr(self.encoder, "overlap_chunk_cls"):
+			# Randomly pick a chunk configuration while training; use
+			# decoding_ind when decoding.
+			ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+			encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+		else:
+			encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		
+		loss_ctc, cer_ctc = None, None
+		loss_pre = None
+		stats = dict()
+		
+		# decoder: CTC branch
+
+		if self.ctc_weight > 0.0:
+			if hasattr(self.encoder, "overlap_chunk_cls"):
+				# CTC is computed on the de-chunked (contiguous) encoder output.
+				encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+				                                                                                    encoder_out_lens,
+				                                                                                    chunk_outs=None)
+			else:
+				encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
+				
+			loss_ctc, cer_ctc = self._calc_ctc_loss(
+				encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
+			)
+			# Collect CTC branch stats
+			stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+			stats["cer_ctc"] = cer_ctc
+		
+		# decoder: Attention decoder branch
+		loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_predictor_loss(
+			encoder_out, encoder_out_lens, text, text_lengths
+		)
+		
+		# 3. CTC-Att loss definition
+		if self.ctc_weight == 0.0:
+			loss = loss_att + loss_pre * self.predictor_weight
+		else:
+			loss = self.ctc_weight * loss_ctc + (
+					1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
+		
+		# Collect Attn branch stats
+		stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+		stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
+		stats["acc"] = acc_att
+		stats["cer"] = cer_att
+		stats["wer"] = wer_att
+		stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+		
+		stats["loss"] = torch.clone(loss.detach())
+		
+		# force_gatherable: to-device and to-tensor if scalar for DataParallel
+		if self.length_normalized_loss:
+			batch_size = (text_lengths + self.predictor_bias).sum()
+		loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+		return loss, stats, weight
+	
+	def encode_chunk(
+		self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None, **kwargs,
+	) -> Tuple[torch.Tensor, torch.Tensor]:
+		"""Frontend + Encoder. Note that this method is used by asr_inference.py
+		Args:
+				speech: (Batch, Length, ...)
+				speech_lengths: (Batch, )
+				ind: int
+
+		Streaming variant of encode(): forwards one chunk through the encoder,
+		carrying state in cache["encoder"].
+		"""
+		with autocast(False):
+			
+			# Data augmentation
+			if self.specaug is not None and self.training:
+				speech, speech_lengths = self.specaug(speech, speech_lengths)
+			
+			# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+			if self.normalize is not None:
+				speech, speech_lengths = self.normalize(speech, speech_lengths)
+		
+		# Forward encoder
+		encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(speech, speech_lengths, cache=cache["encoder"])
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+		
+		# Length is simply the chunk's output frame count (batch size 1 path).
+		return encoder_out, torch.tensor([encoder_out.size(1)])
+	
+	def _calc_att_predictor_loss(
+		self,
+		encoder_out: torch.Tensor,
+		encoder_out_lens: torch.Tensor,
+		ys_pad: torch.Tensor,
+		ys_pad_lens: torch.Tensor,
+	):
+		# Attention-decoder + predictor loss for the streaming model: run the
+		# predictor, build the SCAMA cross-attention mask from its alignments,
+		# optionally mix ground-truth embeddings (sampler), then decode.
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		if self.predictor_bias == 1:
+			# Append eos to the targets so the predictor learns the +1 length.
+			_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+			ys_pad_lens = ys_pad_lens + self.predictor_bias
+		mask_chunk_predictor = None
+		if self.encoder.overlap_chunk_cls is not None:
+			mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+			                                                                               device=encoder_out.device,
+			                                                                               batch_size=encoder_out.size(
+				                                                                               0))
+			mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+			                                                                       batch_size=encoder_out.size(0))
+			encoder_out = encoder_out * mask_shfit_chunk
+		pre_acoustic_embeds, pre_token_length, pre_alphas, _ = self.predictor(encoder_out,
+		                                                                      ys_pad,
+		                                                                      encoder_out_mask,
+		                                                                      ignore_id=self.ignore_id,
+		                                                                      mask_chunk_predictor=mask_chunk_predictor,
+		                                                                      target_label_length=ys_pad_lens,
+		                                                                      )
+		predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+		                                                                                     encoder_out_lens)
+		
+		scama_mask = None
+		if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+			# Restrict decoder cross-attention to the chunk containing each
+			# predicted token (plus a limited look-back).
+			encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+			attention_chunk_center_bias = 0
+			attention_chunk_size = encoder_chunk_size
+			decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+			mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
+				get_mask_shift_att_chunk_decoder(None,
+			                                     device=encoder_out.device,
+			                                     batch_size=encoder_out.size(0)
+			                                     )
+			scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+				predictor_alignments=predictor_alignments,
+				encoder_sequence_length=encoder_out_lens,
+				chunk_size=1,
+				encoder_chunk_size=encoder_chunk_size,
+				attention_chunk_center_bias=attention_chunk_center_bias,
+				attention_chunk_size=attention_chunk_size,
+				attention_chunk_type=self.decoder_attention_chunk_type,
+				step=None,
+				predictor_mask_chunk_hopping=mask_chunk_predictor,
+				decoder_att_look_back_factor=decoder_att_look_back_factor,
+				mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+				target_length=ys_pad_lens,
+				is_training=self.training,
+			)
+		elif self.encoder.overlap_chunk_cls is not None:
+			# Non-"chunk" attention type: decode over the de-chunked output.
+			encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out,
+			                                                                            encoder_out_lens,
+			                                                                            chunk_outs=None)
+		# 0. sampler
+		decoder_out_1st = None
+		pre_loss_att = None
+		if self.sampling_ratio > 0.0:
+			if self.step_cur < 2:
+				logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+			if self.use_1st_decoder_loss:
+				sematic_embeds, decoder_out_1st, pre_loss_att = \
+					self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
+					                       ys_pad_lens, pre_acoustic_embeds, scama_mask)
+			else:
+				sematic_embeds, decoder_out_1st = \
+					self.sampler(encoder_out, encoder_out_lens, ys_pad,
+					             ys_pad_lens, pre_acoustic_embeds, scama_mask)
+		else:
+			if self.step_cur < 2:
+				logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+			sematic_embeds = pre_acoustic_embeds
+		
+		# 1. Forward decoder
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, scama_mask
+		)
+		decoder_out, _ = decoder_outs[0], decoder_outs[1]
+		
+		if decoder_out_1st is None:
+			decoder_out_1st = decoder_out
+		# 2. Compute attention loss
+		loss_att = self.criterion_att(decoder_out, ys_pad)
+		acc_att = th_accuracy(
+			decoder_out_1st.view(-1, self.vocab_size),
+			ys_pad,
+			ignore_label=self.ignore_id,
+		)
+		loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+		
+		# Compute cer/wer using attention-decoder
+		if self.training or self.error_calculator is None:
+			cer_att, wer_att = None, None
+		else:
+			ys_hat = decoder_out_1st.argmax(dim=-1)
+			cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+		
+		return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+	
+	def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, chunk_mask=None):
+		# Paraformer glancing-style sampler: run a no-grad first-pass decode,
+		# then replace a fraction (sampling_ratio) of the wrongly-predicted
+		# positions' acoustic embeddings with ground-truth token embeddings.
+		
+		tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+		ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+		if self.share_embedding:
+			ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+		else:
+			ys_pad_embed = self.decoder.embed(ys_pad_masked)
+		with torch.no_grad():
+			decoder_outs = self.decoder(
+				encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, chunk_mask
+			)
+			decoder_out, _ = decoder_outs[0], decoder_outs[1]
+			pred_tokens = decoder_out.argmax(-1)
+			nonpad_positions = ys_pad.ne(self.ignore_id)
+			seq_lens = (nonpad_positions).sum(1)
+			same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+			input_mask = torch.ones_like(nonpad_positions)
+			bsz, seq_len = ys_pad.size()
+			for li in range(bsz):
+				target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+				if target_num > 0:
+					# NOTE(review): hard-coded .cuda() breaks CPU runs; the
+					# offline variant uses .to(pre_acoustic_embeds.device) here.
+					input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+			input_mask = input_mask.eq(1)
+			input_mask = input_mask.masked_fill(~nonpad_positions, False)
+			input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+		
+		# Keep acoustic embeddings where input_mask is True, otherwise take the
+		# ground-truth token embeddings.
+		sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+			input_mask_expand_dim, 0)
+		return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+	
+
+	def calc_predictor(self, encoder_out, encoder_out_lens):
+		# Inference-time predictor pass; also builds and caches self.scama_mask
+		# for the subsequent cal_decoder_with_predictor() call.
+		
+		encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+			encoder_out.device)
+		mask_chunk_predictor = None
+		if self.encoder.overlap_chunk_cls is not None:
+			mask_chunk_predictor = self.encoder.overlap_chunk_cls.get_mask_chunk_predictor(None,
+			                                                                               device=encoder_out.device,
+			                                                                               batch_size=encoder_out.size(
+				                                                                               0))
+			mask_shfit_chunk = self.encoder.overlap_chunk_cls.get_mask_shfit_chunk(None, device=encoder_out.device,
+			                                                                       batch_size=encoder_out.size(0))
+			encoder_out = encoder_out * mask_shfit_chunk
+		pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index = self.predictor(encoder_out,
+		                                                                                   None,
+		                                                                                   encoder_out_mask,
+		                                                                                   ignore_id=self.ignore_id,
+		                                                                                   mask_chunk_predictor=mask_chunk_predictor,
+		                                                                                   target_label_length=None,
+		                                                                                   )
+		# +1 frame accounts for the predictor's tail-threshold extension.
+		predictor_alignments, predictor_alignments_len = self.predictor.gen_frame_alignments(pre_alphas,
+		                                                                                     encoder_out_lens + 1 if self.predictor.tail_threshold > 0.0 else encoder_out_lens)
+		
+		scama_mask = None
+		if self.encoder.overlap_chunk_cls is not None and self.decoder_attention_chunk_type == 'chunk':
+			encoder_chunk_size = self.encoder.overlap_chunk_cls.chunk_size_pad_shift_cur
+			attention_chunk_center_bias = 0
+			attention_chunk_size = encoder_chunk_size
+			decoder_att_look_back_factor = self.encoder.overlap_chunk_cls.decoder_att_look_back_factor_cur
+			mask_shift_att_chunk_decoder = self.encoder.overlap_chunk_cls. \
+				get_mask_shift_att_chunk_decoder(None,
+			                                     device=encoder_out.device,
+			                                     batch_size=encoder_out.size(0)
+			                                     )
+			scama_mask = self.build_scama_mask_for_cross_attention_decoder_fn(
+				predictor_alignments=predictor_alignments,
+				encoder_sequence_length=encoder_out_lens,
+				chunk_size=1,
+				encoder_chunk_size=encoder_chunk_size,
+				attention_chunk_center_bias=attention_chunk_center_bias,
+				attention_chunk_size=attention_chunk_size,
+				attention_chunk_type=self.decoder_attention_chunk_type,
+				step=None,
+				predictor_mask_chunk_hopping=mask_chunk_predictor,
+				decoder_att_look_back_factor=decoder_att_look_back_factor,
+				mask_shift_att_chunk_decoder=mask_shift_att_chunk_decoder,
+				target_length=None,
+				is_training=self.training,
+			)
+		# Cache for cal_decoder_with_predictor(); stateful between the two calls.
+		self.scama_mask = scama_mask
+		
+		return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
+	
+	def calc_predictor_chunk(self, encoder_out, cache=None):
+		# Streaming predictor step; state carried in cache["encoder"].
+		
+		pre_acoustic_embeds, pre_token_length = \
+			self.predictor.forward_chunk(encoder_out, cache["encoder"])
+		return pre_acoustic_embeds, pre_token_length
+	
+	def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
+		# Decode using the SCAMA mask cached by the preceding calc_predictor().
+		decoder_outs = self.decoder(
+			encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, self.scama_mask
+		)
+		decoder_out = decoder_outs[0]
+		decoder_out = torch.log_softmax(decoder_out, dim=-1)
+		return decoder_out, ys_pad_lens
+	
+	def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+		# Streaming decoder step; state carried in cache["decoder"].
+		decoder_outs = self.decoder.forward_chunk(
+			encoder_out, sematic_embeds, cache["decoder"]
+		)
+		decoder_out = decoder_outs
+		decoder_out = torch.log_softmax(decoder_out, dim=-1)
+		return decoder_out
+
+	def generate(self,
+	             speech: torch.Tensor,
+	             speech_lengths: torch.Tensor,
+	             tokenizer=None,
+	             **kwargs,
+	             ):
+		# Offline-style inference for the online model: encode the whole
+		# utterance, run the predictor, decode, and greedily (or via beam
+		# search) convert scores to (text, token, timestamp) tuples.
+		
+		is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc != None
+		# NOTE(review): leftover debug print; also `!= None` should be
+		# `is not None` (cf. the offline generate()).
+		print(is_use_ctc)
+		is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+		
+		if self.beam_search is None and (is_use_lm or is_use_ctc):
+			logging.info("enable beam_search")
+			self.init_beam_search(speech, speech_lengths, **kwargs)
+			self.nbest = kwargs.get("nbest", 1)
+		
+		# Forward Encoder
+		encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+		if isinstance(encoder_out, tuple):
+			encoder_out = encoder_out[0]
+		
+		# predictor
+		predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+		pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+		                                                                predictor_outs[2], predictor_outs[3]
+		pre_token_length = pre_token_length.round().long()
+		if torch.max(pre_token_length) < 1:
+			return []
+		decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
+		                                               pre_token_length)
+		decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+		
+		results = []
+		b, n, d = decoder_out.size()
+		for i in range(b):
+			x = encoder_out[i, :encoder_out_lens[i], :]
+			am_scores = decoder_out[i, :pre_token_length[i], :]
+			if self.beam_search is not None:
+				nbest_hyps = self.beam_search(
+					x=x, am_scores=am_scores, maxlenratio=kwargs.get("maxlenratio", 0.0),
+					minlenratio=kwargs.get("minlenratio", 0.0)
+				)
+				
+				nbest_hyps = nbest_hyps[: self.nbest]
+			else:
+				# Greedy path: per-frame argmax over the log-softmax scores.
+				yseq = am_scores.argmax(dim=-1)
+				score = am_scores.max(dim=-1)[0]
+				score = torch.sum(score, dim=-1)
+				# pad with mask tokens to ensure compatibility with sos/eos tokens
+				yseq = torch.tensor(
+					[self.sos] + yseq.tolist() + [self.eos], device=yseq.device
+				)
+				nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+			for hyp in nbest_hyps:
+				assert isinstance(hyp, (Hypothesis)), type(hyp)
+				
+				# remove sos/eos and get results
+				last_pos = -1
+				if isinstance(hyp.yseq, list):
+					token_int = hyp.yseq[1:last_pos]
+				else:
+					token_int = hyp.yseq[1:last_pos].tolist()
+				
+				# remove blank symbol id, which is assumed to be 0
+				# NOTE(review): ids 0 and 2 are hard-coded here (presumably
+				# blank and eos); the offline generate() filters via
+				# self.blank_id/self.sos/self.eos instead -- confirm.
+				token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+				
+				# Change integer-ids to tokens
+				token = tokenizer.ids2tokens(token_int)
+				text = tokenizer.tokens2text(token)
+				
+				timestamp = []
+				
+				results.append((text, token, timestamp))
+		
+		return results
+
+
diff --git a/funasr/models/paraformer/search.py b/funasr/models/paraformer/search.py
new file mode 100644
index 0000000..440b48e
--- /dev/null
+++ b/funasr/models/paraformer/search.py
@@ -0,0 +1,453 @@
+from itertools import chain
+import logging
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import NamedTuple
+from typing import Tuple
+from typing import Union
+
+import torch
+
+from funasr.modules.e2e_asr_common import end_detect
+from funasr.modules.scorers.scorer_interface import PartialScorerInterface
+from funasr.modules.scorers.scorer_interface import ScorerInterface
+
class Hypothesis(NamedTuple):
    """Single beam-search hypothesis.

    Attributes:
        yseq: Token id sequence beginning with ``sos``; normally a 1-D
            ``torch.Tensor``, but callers elsewhere also slice it as a plain
            list, so ``asdict`` tolerates both.
        score: Accumulated weighted score of the whole sequence.
        scores: Accumulated score per scorer name.
        states: Scorer state per scorer name.

    NOTE(review): the ``dict()`` defaults are evaluated once at class
    creation, so hypotheses built without explicit ``scores``/``states``
    share the same dict objects — always pass fresh dicts when a caller
    may mutate them (``post_process`` does ``hyp.scores[k] += s``).
    """

    yseq: torch.Tensor
    score: Union[float, torch.Tensor] = 0
    scores: Dict[str, Union[float, torch.Tensor]] = dict()
    states: Dict[str, Any] = dict()

    def asdict(self) -> dict:
        """Convert data to JSON-friendly dict."""
        # ``yseq`` may already be a plain list; only tensors need tolist().
        yseq = self.yseq if isinstance(self.yseq, list) else self.yseq.tolist()
        return self._replace(
            yseq=yseq,
            score=float(self.score),
            scores={k: float(v) for k, v in self.scores.items()},
        )._asdict()


class BeamSearchPara(torch.nn.Module):
    """Beam search over precomputed per-position acoustic scores.

    Variant of the standard attention beam search used for Paraformer's
    parallel decoding: at step ``i`` every running hypothesis is scored with
    the fixed acoustic score row ``am_scores[i]`` plus the weighted scores of
    the configured full/partial scorers (e.g. LM, CTC prefix scorer).
    """

    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos: int,
        eos: int,
        token_list: List[str] = None,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
    ):
        """Initialize beam search.

        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM.
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorers
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The number of vocabulary
            sos (int): Start of sequence id
            eos (int): End of sequence id
            token_list (list[str]): List of tokens for debug log
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`

        """
        super().__init__()
        # Split scorers into "full" (score the whole vocabulary) and
        # "partial" (score only candidate ids) groups.
        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()
        # this module dict is required for recursive cast
        # `self.to(device, dtype)` in `recog.py`
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            w = weights.get(k, 0)
            if w == 0 or v is None:
                continue
            assert isinstance(
                v, ScorerInterface
            ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v

        # set configurations
        self.sos = sos
        self.eos = eos
        self.token_list = token_list
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
        self.pre_beam_score_key = pre_beam_score_key
        # pre-beam pruning only pays off with partial scorers and a pre-beam
        # that is actually smaller than the vocabulary
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )

    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        """Get an initial hypothesis data.

        Args:
            x (torch.Tensor): The encoder output feature

        Returns:
            Hypothesis: The initial hypothesis.

        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0
        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
            )
        ]

    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Append new token to prefix tokens.

        Args:
            xs (torch.Tensor): The prefix token
            x (int): The new token to append

        Returns:
            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device

        """
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))

    def score_full(
        self, hyp: Hypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
        return scores, states

    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.

        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature

        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`

        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
        return scores, states

    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.

        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
                Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`

        """
        # no pre beam performed
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(self.beam_size)[1]
            return top_ids, top_ids

        # mask pruned in pre-beam not to select in topk
        # (mutates `weighted_scores` in place; callers pass a fresh tensor)
        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(self.beam_size)[1]
        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
        return top_ids, local_ids

    @staticmethod
    def merge_scores(
        prev_scores: Dict[str, float],
        next_full_scores: Dict[str, torch.Tensor],
        full_idx: int,
        next_part_scores: Dict[str, torch.Tensor],
        part_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Merge scores for new hypothesis.

        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are scalar tensors by the scorers.

        """
        new_scores = dict()
        for k, v in next_full_scores.items():
            new_scores[k] = prev_scores[k] + v[full_idx]
        for k, v in next_part_scores.items():
            new_scores[k] = prev_scores[k] + v[part_idx]
        return new_scores

    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.

        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`

        Returns:
            Dict[str, torch.Tensor]: The new state dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.

        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, d in self.part_scorers.items():
            new_states[k] = d.select_state(part_states[k], part_idx)
        return new_states

    def search(
        self, running_hyps: List[Hypothesis], x: torch.Tensor, am_score: torch.Tensor
    ) -> List[Hypothesis]:
        """Search new tokens for running hypotheses and encoded speech x.

        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)
            am_score (torch.Tensor): Acoustic scores of the current position
                over the vocabulary, shape `(self.n_vocab,)`

        Returns:
            List[Hypothesis]: Best sorted hypotheses

        """
        best_hyps = []
        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
        for hyp in running_hyps:
            # scoring: start from the fixed acoustic score of this position
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            weighted_scores += am_score
            scores, states = self.score_full(hyp, x)
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
            if self.do_pre_beam:
                pre_beam_scores = (
                    weighted_scores
                    if self.pre_beam_score_key == "full"
                    else scores[self.pre_beam_score_key]
                )
                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
            part_scores, part_states = self.score_partial(hyp, part_ids, x)
            for k in self.part_scorers:
                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
            # add previous hyp score
            weighted_scores += hyp.score

            # update hyps
            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
                # will be (2 x beam at most)
                best_hyps.append(
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(
                            hyp.scores, scores, j, part_scores, part_j
                        ),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )

            # sort and prune 2 x beam -> beam
            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
                : min(len(best_hyps), self.beam_size)
            ]
        return best_hyps

    def forward(
        self, x: torch.Tensor, am_scores: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
    ) -> List[Hypothesis]:
        """Perform beam search.

        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            am_scores (torch.Tensor): Precomputed acoustic scores (L, n_vocab);
                its first dimension fixes the output length.
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
                If maxlenratio<0.0, its absolute value is interpreted
                as a constant max output length.
            minlenratio (float): Input length ratio to obtain min output length.

        Returns:
            list[Hypothesis]: N-best decoding results

        """
        # set length bounds
        # NOTE(review): maxlen is taken from am_scores; maxlenratio/minlenratio
        # do not influence it here — confirm intended behavior.
        maxlen = am_scores.shape[0]
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))

        # main loop of prefix search
        running_hyps = self.init_hyp(x)
        ended_hyps = []
        for i in range(maxlen):
            logging.debug("position " + str(i))
            best = self.search(running_hyps, x, am_scores[i])
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                logging.info(f"end detected at {i}")
                break
            if len(running_hyps) == 0:
                logging.info("no hypothesis. Finish decoding.")
                break
            else:
                logging.debug(f"remained hypotheses: {len(running_hyps)}")

        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
            )
            # BUG FIX: the retry previously called
            # `self.forward(x, maxlenratio, ...)`, passing `maxlenratio` in
            # the `am_scores` slot (the signature gained `am_scores` relative
            # to the espnet original). Keep this signature's argument order.
            return (
                []
                if minlenratio < 0.1
                else self.forward(
                    x, am_scores, maxlenratio, max(0.0, minlenratio - 0.1)
                )
            )

        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]])
                + "\n"
            )
        return nbest_hyps

    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: List[Hypothesis],
        ended_hyps: List[Hypothesis],
    ) -> List[Hypothesis]:
        """Perform post-processing of beam search iterations.

        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (float): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search;
                extended in place.

        Returns:
            List[Hypothesis]: The new running hypotheses.

        """
        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([self.token_list[x.item()] for x in running_hyps[0].yseq[1:]])
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
            ]

        # add ended hypotheses to a final list, and removed them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in running_hyps:
            if hyp.yseq[-1] == self.eos:
                # e.g., Word LM needs to add final <eos> score
                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                    s = d.final_score(hyp.states[k])
                    hyp.scores[k] += s
                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        return remained_hyps
diff --git a/funasr/tokenizer/abs_tokenizer.py b/funasr/tokenizer/abs_tokenizer.py
index d2fc3f0..bbfd173 100644
--- a/funasr/tokenizer/abs_tokenizer.py
+++ b/funasr/tokenizer/abs_tokenizer.py
@@ -66,7 +66,9 @@
         return text_ints
     
     def decode(self, text_ints):
-        return self.ids2tokens(text_ints)
+        token = self.ids2tokens(text_ints)
+        text = self.tokens2text(token)
+        return text
     
     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
diff --git a/funasr/utils/download_from_hub.py b/funasr/utils/download_from_hub.py
index d6e4ab4..d6b79d3 100644
--- a/funasr/utils/download_from_hub.py
+++ b/funasr/utils/download_from_hub.py
@@ -11,10 +11,10 @@
 	return kwargs
 
 def download_fr_ms(**kwargs):
-	model_or_path = kwargs.get("model_pretrain")
-	model_revision = kwargs.get("model_pretrain_revision")
+	model_or_path = kwargs.get("model")
+	model_revision = kwargs.get("model_revision")
 	if not os.path.exists(model_or_path):
-		model_or_path = get_or_download_model_dir(model_or_path, model_revision, third_party="funasr")
+		model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"))
 	
 	config = os.path.join(model_or_path, "config.yaml")
 	assert os.path.exists(config), "{} is not exist!".format(config)
@@ -23,25 +23,29 @@
 	init_param = os.path.join(model_or_path, "model.pb")
 	kwargs["init_param"] = init_param
 	kwargs["token_list"] = os.path.join(model_or_path, "tokens.txt")
+	kwargs["model"] = cfg["model"]
+	kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
 	
 	return kwargs
 
 def get_or_download_model_dir(
                               model,
                               model_revision=None,
-                              third_party=None):
+							  is_training=False,
+	):
 	""" Get local model directory or download model if necessary.
 
 	Args:
 		model (str): model id or path to local model directory.
 		model_revision  (str, optional): model version number.
-		third_party (str, optional): in which third party library
-			this function is called.
+		is_training (bool, optional): whether the model is being fetched
+			for training (selects the LOCAL_TRAINER invoke key) rather
+			than for inference (PIPELINE).
 	"""
 	from modelscope.hub.check_model import check_local_model_is_latest
 	from modelscope.hub.snapshot_download import snapshot_download
 
 	from modelscope.utils.constant import Invoke, ThirdParty
+	
+	key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
 	
 	if os.path.exists(model):
 		model_cache_dir = model if os.path.isdir(
@@ -49,15 +53,15 @@
 		check_local_model_is_latest(
 			model_cache_dir,
 			user_agent={
-				Invoke.KEY: Invoke.LOCAL_TRAINER,
-				ThirdParty.KEY: third_party
+				Invoke.KEY: key,
+				ThirdParty.KEY: "funasr"
 			})
 	else:
 		model_cache_dir = snapshot_download(
 			model,
 			revision=model_revision,
 			user_agent={
-				Invoke.KEY: Invoke.TRAINER,
-				ThirdParty.KEY: third_party
+				Invoke.KEY: key,
+				ThirdParty.KEY: "funasr"
 			})
 	return model_cache_dir
\ No newline at end of file

--
Gitblit v1.9.1