From 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 16 May 2023 15:07:20 +0800
Subject: [PATCH] Merge pull request #518 from alibaba-damo-academy/dev_wjm2

---
 egs/aishell/data2vec_transformer_finetune/run.sh                                                    |  110 
 egs/librispeech_100h/conformer/local/spm_train.py                                                   |   12 
 funasr/datasets/small_datasets/length_batch_sampler.py                                              |  147 
 funasr/models/e2e_asr.py                                                                            |   13 
 egs/aishell2/transformer/utils/fix_data.sh                                                          |    4 
 egs/librispeech_100h/conformer/conf/decode_asr_transformer.yaml                                     |    6 
 funasr/datasets/small_datasets/preprocessor.py                                                      |  875 +++++
 egs/librispeech_100h/conformer/utils                                                                |    1 
 funasr/build_utils/build_trainer.py                                                                 |  820 +++++
 egs/aishell2/conformer/local/prepare_data.sh                                                        |    6 
 egs/aishell2/transformer/utils/compute_cmvn.py                                                      |   88 
 funasr/layers/abs_normalize.py                                                                      |    2 
 funasr/tasks/vad.py                                                                                 |   79 
 egs/aishell2/paraformerbert/local/extract_embeds.sh                                                 |   31 
 egs/aishell/conformer/run.sh                                                                        |  103 
 funasr/tasks/punctuation.py                                                                         |    2 
 egs/aishell/transformer/utils/compute_cmvn.sh                                                       |   21 
 egs/librispeech/conformer/local/spm_train.py                                                        |   12 
 egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh                               |  106 
 egs/aishell2/transformer/utils/fix_data_feat.sh                                                     |    8 
 funasr/models/encoder/abs_encoder.py                                                                |    2 
 funasr/build_utils/__init__.py                                                                      |    0 
 funasr/main_funcs/collect_stats.py                                                                  |    4 
 funasr/models/encoder/resnet34_encoder.py                                                           |    2 
 funasr/models/encoder/conformer_encoder.py                                                          |    2 
 egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml                                 |   68 
 funasr/build_utils/build_scheduler.py                                                               |   44 
 egs/aishell/paraformerbert/local/extract_embeds.sh                                                  |   12 
 funasr/models/encoder/sanm_encoder.py                                                               |    3 
 funasr/tasks/diar.py                                                                                |   12 
 egs/aishell/paraformer/local/download_and_untar.sh                                                  |  105 
 funasr/tasks/lm.py                                                                                  |    3 
 egs/aishell/conformer/conf/train_asr_conformer.yaml                                                 |   28 
 funasr/build_utils/build_distributed.py                                                             |   38 
 funasr/tasks/abs_task.py                                                                            |   55 
 egs/aishell2/paraformerbert/local/prepare_data.sh                                                   |    7 
 egs/aishell/data2vec_paraformer_finetune/run.sh                                                     |  112 
 funasr/models/encoder/data2vec_encoder.py                                                           |    2 
 funasr/models/e2e_vad.py                                                                            |    8 
 egs/librispeech/conformer/conf/train_asr_conformer.yaml                                             |   14 
 egs/aishell2/paraformer/run.sh                                                                      |  105 
 funasr/main_funcs/calculate_all_attentions.py                                                       |    4 
 funasr/models/e2e_diar_eend_ola.py                                                                  |    5 
 funasr/models/e2e_diar_sond.py                                                                      |    8 
 egs/aishell2/transformer/conf/train_asr_transformer.yaml                                            |   22 
 funasr/models/e2e_tp.py                                                                             |   15 
 egs/aishell2/transformer/utils/compute_cmvn.sh                                                      |   21 
 funasr/train/abs_model.py                                                                           |   28 
 egs/aishell/transformer/path.sh                                                                     |    2 
 funasr/models/frontend/windowing.py                                                                 |    7 
 egs/aishell/transformer/utils/cmvn_converter.py                                                     |    6 
 funasr/build_utils/build_diar_model.py                                                              |  296 +
 egs/aishell2/paraformer/local/prepare_data.sh                                                       |    6 
 egs/aishell2/paraformerbert/run.sh                                                                  |  136 
 funasr/models/base_model.py                                                                         |   17 
 egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml                     |   29 
 funasr/datasets/small_datasets/sequence_iter_factory.py                                             |  189 +
 funasr/bin/train.py                                                                                 |  572 +++
 docs/reference/build_task.md                                                                        |    2 
 egs/aishell/conformer/local/download_and_untar.sh                                                   |  106 
 funasr/build_utils/build_vad_model.py                                                               |   77 
 egs/aishell2/data2vec_pretrain/run.sh                                                               |  131 
 funasr/build_utils/build_optimizer.py                                                               |   28 
 funasr/models/encoder/rnn_encoder.py                                                                |    3 
 funasr/models/frontend/wav_frontend_kaldifeat.py                                                    |  119 
 funasr/build_utils/build_pretrain_model.py                                                          |  107 
 funasr/models/e2e_asr_paraformer.py                                                                 |  162 
 egs/aishell2/transformer/utils/combine_cmvn_file.py                                                 |   27 
 egs/librispeech_100h/conformer/path.sh                                                              |    0 
 egs/aishell2/transformer/utils/compute_fbank.py                                                     |   24 
 egs/librispeech/conformer/run.sh                                                                    |  143 
 egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh                                |  106 
 egs/librispeech_100h/conformer/run.sh                                                               |  212 +
 funasr/utils/prepare_data.py                                                                        |  209 +
 egs/aishell2/transformer/run.sh                                                                     |  101 
 egs/librispeech_100h/conformer/local/spm_encode.py                                                  |   98 
 egs/aishell/paraformer/run.sh                                                                       |  100 
 funasr/build_utils/build_dataloader.py                                                              |   15 
 egs/aishell/paraformerbert/run.sh                                                                   |  114 
 egs/librispeech_100h/conformer/local/data_prep.sh                                                   |    0 
 egs/aishell/paraformerbert/local/download_and_untar.sh                                              |  105 
 funasr/layers/global_mvn.py                                                                         |    5 
 funasr/build_utils/build_args.py                                                                    |   93 
 funasr/models/e2e_sv.py                                                                             |    8 
 funasr/datasets/large_datasets/build_dataloader.py                                                  |   36 
 funasr/train/trainer.py                                                                             |    4 
 funasr/datasets/small_datasets/dataset.py                                                           |  258 +
 egs/librispeech/conformer/local/spm_encode.py                                                       |   98 
 funasr/models/frontend/abs_frontend.py                                                              |    2 
 funasr/build_utils/build_punc_model.py                                                              |   68 
 egs/aishell/transformer/run.sh                                                                      |  112 
 egs/aishell/paraformerbert/local/aishell_data_prep.sh                                               |   23 
 egs/librispeech/conformer/local/download_and_untar.sh                                               |   97 
 funasr/models/data2vec.py                                                                           |    6 
 egs/aishell2/transformer/utils/download_model.py                                                    |   20 
 funasr/build_utils/build_lm_model.py                                                                |   57 
 funasr/build_utils/build_asr_model.py                                                               |  352 ++
 egs/aishell/transformer/utils/compute_cmvn.py                                                       |   88 
 egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml                                        |   32 
 egs/aishell2/transformer/utils/cmvn_converter.py                                                    |   51 
 egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml |   29 
 egs/aishell/transformer/conf/train_asr_transformer.yaml                                             |   30 
 funasr/models/vad_realtime_transformer.py                                                           |    3 
 funasr/datasets/large_datasets/dataset.py                                                           |   24 
 egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml           |   39 
 funasr/bin/asr_test.py                                                                              |  226 +
 funasr/datasets/small_datasets/collate_fn.py                                                        |   93 
 funasr/datasets/large_datasets/utils/tokenize.py                                                    |    2 
 egs/aishell/transformer/utils/combine_cmvn_file.py                                                  |   27 
 funasr/fileio/sound_scp.py                                                                          |   15 
 funasr/models/e2e_asr_mfcca.py                                                                      |  138 
 egs/librispeech_100h/conformer/local/download_and_untar.sh                                          |   97 
 funasr/layers/inversible_interface.py                                                               |    2 
 funasr/models/encoder/mfcca_encoder.py                                                              |  120 
 egs/aishell2/transformer/local/prepare_data.sh                                                      |    6 
 funasr/models/frontend/fused.py                                                                     |    2 
 funasr/models/target_delay_transformer.py                                                           |    5 
 egs/aishell2/transformer/utils/compute_fbank.sh                                                     |    5 
 funasr/models/frontend/s3prl.py                                                                     |    1 
 egs/aishell2/transformer/utils/compute_wer.py                                                       |    4 
 funasr/models/e2e_uni_asr.py                                                                        |   11 
 funasr/lm/abs_model.py                                                                              |    4 
 egs/aishell2/conformer/conf/train_asr_conformer.yaml                                                |   22 
 egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml           |   22 
 funasr/models/specaug/abs_specaug.py                                                                |    2 
 funasr/tasks/sv.py                                                                                  |   10 
 egs/aishell/transformer/local/download_and_untar.sh                                                 |  105 
 egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml            |   29 
 funasr/tasks/asr.py                                                                                 |   18 
 egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml             |   39 
 egs/aishell2/conformer/run.sh                                                                       |   95 
 funasr/models/frontend/wav_frontend.py                                                              |    2 
 egs/librispeech/conformer/local/data_prep.sh                                                        |    0 
 /dev/null                                                                                           |   55 
 funasr/build_utils/build_model.py                                                                   |   25 
 funasr/models/frontend/default.py                                                                   |    4 
 136 files changed, 7295 insertions(+), 1850 deletions(-)

diff --git a/docs/reference/build_task.md b/docs/reference/build_task.md
index be2d1af..2020860 100644
--- a/docs/reference/build_task.md
+++ b/docs/reference/build_task.md
@@ -103,7 +103,7 @@
         )
     return model
 ```
-This function defines the detail of the model. For different speech recognition models, the same speech recognition `Task` can usually be shared and the remaining thing needed to be done is to define a specific model in this function. For example, a speech recognition model with a standard encoder-decoder structure has been shown above. Specifically, it first defines each module of the model, including encoder, decoder, etc. and then combine these modules together to generate a complete model. In FunASR, the model needs to inherit `AbsESPnetModel` and the corresponding code can be seen in `funasr/train/abs_espnet_model.py`. The main function needed to be implemented is the `forward` function.
+This function defines the details of the model. Different speech recognition models can usually share the same speech recognition `Task`; all that remains is to define the specific model in this function. The example above shows a speech recognition model with a standard encoder-decoder structure: it first defines each module of the model (encoder, decoder, etc.) and then combines these modules into a complete model. In FunASR, the model needs to inherit `FunASRModel`, and the corresponding code can be seen in `funasr/train/abs_espnet_model.py`. The main function to implement is the `forward` function.
 
 Next, we take `SANMEncoder` as an example to introduce how to use a custom encoder as a part of the model when defining the specified model and the corresponding code can be seen in `funasr/models/encoder/sanm_encoder.py`. For a custom encoder, in addition to inheriting the common encoder class `AbsEncoder`, it is also necessary to define the `forward` function to achieve the forward computation of the `encoder`. After defining the `encoder`, it should also be registered in the `Task`. The corresponding code example can be seen as below:
 ```python
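
To make the registration text above concrete, the following is a minimal sketch of such a custom encoder. It is illustrative only, not the actual `SANMEncoder`, and it assumes the `AbsEncoder` interface exposes `output_size()` and a `forward(xs_pad, ilens, prev_states)` method returning `(output, lengths, states)`, as ESPnet-style encoders do:

```python
# A minimal custom encoder in the spirit of funasr/models/encoder/sanm_encoder.py.
# Sketch only: it assumes AbsEncoder exposes output_size() and forward(xs_pad, ilens, prev_states).
from typing import Optional, Tuple

import torch

from funasr.models.encoder.abs_encoder import AbsEncoder


class TinyLinearEncoder(AbsEncoder):
    def __init__(self, input_size: int, output_size: int = 256):
        super().__init__()
        self._output_size = output_size
        self.proj = torch.nn.Linear(input_size, output_size)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # (batch, time, feat) -> (batch, time, output_size); lengths are unchanged.
        return self.proj(xs_pad), ilens, None
```
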
diff --git a/egs/aishell/conformer/conf/train_asr_conformer.yaml b/egs/aishell/conformer/conf/train_asr_conformer.yaml
index ddf217e..ea2353e 100644
--- a/egs/aishell/conformer/conf/train_asr_conformer.yaml
+++ b/egs/aishell/conformer/conf/train_asr_conformer.yaml
@@ -29,21 +29,27 @@
     self_attention_dropout_rate: 0.0
     src_attention_dropout_rate: 0.0
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
 # hybrid CTC/attention
 model_conf:
     ctc_weight: 0.3
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
 
-# minibatch related
-batch_type: length
-batch_bins: 25000
-num_workers: 16
-
 # optimization related
 accum_grad: 1
 grad_clip: 5
-max_epoch: 50
+max_epoch: 150
 val_scheduler_criterion:
     - valid
     - acc
@@ -76,5 +82,15 @@
     - 40
     num_time_mask: 2
 
+dataset_conf:
+    shuffle: True
+    shuffle_conf:
+        shuffle_size: 2048
+        sort_size: 500
+    batch_conf:
+        batch_type: token
+        batch_size: 25000
+    num_workers: 8
+
 log_interval: 50
 normalize: None
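
The `dataset_conf` block added above replaces the old `batch_type: length` / `batch_bins` minibatching: utterances are presumably shuffled in windows of `shuffle_size`, locally sorted in windows of `sort_size`, and packed into minibatches until the padded token/frame count reaches `batch_size`. A minimal sketch of token-capped packing under those assumptions (the actual sampler added by this patch is `funasr/datasets/small_datasets/length_batch_sampler.py`):

```python
# Sketch of token-capped batching as configured by dataset_conf.batch_conf
# (batch_type: token, batch_size: 25000). Illustrative only; FunASR's real
# logic lives in funasr/datasets/small_datasets/length_batch_sampler.py.
from typing import Iterable, List, Tuple


def token_batches(
    utts: Iterable[Tuple[str, int]],  # (utt_id, length in frames/tokens)
    batch_size: int = 25000,
    sort_size: int = 500,
) -> List[List[str]]:
    utts = list(utts)
    batches: List[List[str]] = []
    # Sort within windows of sort_size so nearby utterances have similar length.
    for start in range(0, len(utts), sort_size):
        window = sorted(utts[start:start + sort_size], key=lambda u: u[1])
        batch: List[str] = []
        max_len = 0
        for utt_id, length in window:
            # Padded cost of the batch if this utterance were added.
            if batch and max(max_len, length) * (len(batch) + 1) > batch_size:
                batches.append(batch)
                batch, max_len = [], 0
            batch.append(utt_id)
            max_len = max(max_len, length)
        if batch:
            batches.append(batch)
    return batches


if __name__ == "__main__":
    demo = [(f"utt{i}", 100 + 7 * i) for i in range(20)]
    for b in token_batches(demo, batch_size=2000, sort_size=10):
        print(b)
```
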
diff --git a/egs/aishell/conformer/local/download_and_untar.sh b/egs/aishell/conformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/conformer/local/download_and_untar.sh
@@ -0,0 +1,106 @@
+#!/usr/bin/env bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+#             2017  Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1;
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tgz
+  else
+    echo "$data/$part.tgz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+  if ! command -v wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1;
+  fi
+  full_url=$url/$part.tgz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  cd $data || exit 1
+  if ! wget --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1;
+  fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+  echo "$0: error un-tarring archive $data/$part.tgz"
+  exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+  cd $data/$part/wav || exit 1
+  for wav in ./*.tar.gz; do
+    echo "Extracting wav from $wav"
+    tar -zxf $wav && rm $wav
+  done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+  rm $data/$part.tgz
+fi
+
+exit 0;
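
As its usage message states, this helper takes an existing data directory, a URL base, and a corpus part, e.g. `local/download_and_untar.sh --remove-archive /export/a05/xna/data www.openslr.org/resources/33 data_aishell`. It verifies a previously downloaded archive against the expected byte sizes before re-downloading, and an identical copy is added to each recipe's local/ directory touched by this patch.
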
diff --git a/egs/aishell/conformer/local/prepare_data.sh b/egs/aishell/conformer/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/conformer/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-#           2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
-  echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
-  echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
-  exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
-  echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
-  exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
-  cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/conformer/run.sh b/egs/aishell/conformer/run.sh
index 227b3f2..8db0d97 100755
--- a/egs/aishell/conformer/run.sh
+++ b/egs/aishell/conformer/run.sh
@@ -16,22 +16,20 @@
 feats_dir="../DATA" #feature output dictionary
 exp_dir="."
 lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
 stop_stage=4
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
 
 # exp tag
 tag="exp1"
@@ -49,7 +47,7 @@
 test_sets="dev test"
 
 asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
 
 inference_config=conf/decode_asr_transformer.yaml
 inference_asr_model=valid.acc.ave_10best.pb
@@ -66,10 +64,16 @@
     _ngpu=0
 fi
 
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    echo "stage -1: Data Download"
+    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     echo "stage 0: Data preparation"
     # Data preparation
-    local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
     for x in train dev test; do
         cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
         paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -79,46 +83,9 @@
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
-        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
-    utils/fix_data_feat.sh ${fbankdir}/train
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
-    utils/fix_data_feat.sh ${fbankdir}/dev
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
-    utils/fix_data_feat.sh ${fbankdir}/test
-     
-    # compute global cmvn
-    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
-        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
-    # apply cmvn 
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-    
-    cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
-    cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
-    cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
-    utils/fix_data_feat.sh ${feat_train_dir}
-    utils/fix_data_feat.sh ${feat_dev_dir}
-    utils/fix_data_feat.sh ${feat_test_dir}
-
-    #generate ark list 
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
 fi
 
 token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -126,22 +93,14 @@
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     echo "stage 2: Dictionary Preparation"
     mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-   
+
     echo "make a dictionary"
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
-    num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
-    vocab_size=$(cat ${token_list} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train 
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
 fi
 
 # Training Stage
@@ -161,26 +120,22 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train.py \
+            train.py \
+                --task_name asr \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
                 --token_type char \
                 --token_list $token_list \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char  \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
-                --input_size $feats_dim \
                 --ngpu $gpu_num \
                 --num_worker_count $count \
-                --multiprocessing_distributed true \
                 --dist_init_method $init_method \
                 --dist_world_size $world_size \
                 --dist_rank $rank \
@@ -203,7 +158,7 @@
             exit 0
         fi
         mkdir -p "${_logdir}"
-        _data="${feats_dir}/${dumpdir}/${dset}"
+        _data="${feats_dir}/data/${dset}"
         key_file=${_data}/${scp}
         num_scp_file="$(<${key_file} wc -l)"
         _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -224,6 +179,7 @@
                 --njob ${njob} \
                 --gpuid_list ${gpuid_list} \
                 --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                 --key_file "${_logdir}"/keys.JOB.scp \
                 --asr_train_config "${asr_exp}"/config.yaml \
                 --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -244,5 +200,4 @@
         tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
         cat ${_dir}/text.cer.txt
     done
-fi
-
+fi
\ No newline at end of file
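
Note that stage 1 above no longer dumps fbank features: it only computes global CMVN statistics with `utils/compute_cmvn.sh`, and both `train.py` and the inference stage now read `${train_set}/cmvn/cmvn.mvn`, with fbank extracted on the fly by the `wav_frontend` configured in the YAML. A minimal sketch of what global CMVN does (the on-disk `cmvn.mvn` format is not shown in this patch and is not reproduced here):

```python
# Sketch of global CMVN: accumulate mean/variance over all training features,
# then normalize each utterance. Illustrative only.
import numpy as np


def compute_cmvn(feats_list):
    # feats_list: iterable of (num_frames, feat_dim) arrays, e.g. 80-dim fbank.
    n, s, ss = 0, 0.0, 0.0
    for feats in feats_list:
        n += feats.shape[0]
        s = s + feats.sum(axis=0)
        ss = ss + (feats ** 2).sum(axis=0)
    mean = s / n
    var = ss / n - mean ** 2
    return mean, np.sqrt(np.maximum(var, 1e-20))


def apply_cmvn(feats, mean, std):
    return (feats - mean) / std


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    utts = [rng.normal(3.0, 2.0, size=(rng.integers(50, 200), 80)) for _ in range(10)]
    mean, std = compute_cmvn(utts)
    norm = apply_cmvn(utts[0], mean, std)
    print(norm.mean(), norm.std())  # roughly 0 and 1
```
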
diff --git a/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml b/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
index f9a2cdb..287b088 100644
--- a/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
+++ b/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
@@ -30,7 +30,6 @@
     require_same_masks: true
     mask_dropout: 0
 
-
 # decoder related
 decoder: paraformer_decoder_san
 decoder_conf:
@@ -42,6 +41,18 @@
     self_attention_dropout_rate: 0.0
     src_attention_dropout_rate: 0.0
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
+# hybrid CTC/attention
 model: paraformer
 model_conf:
     ctc_weight: 0.3
@@ -50,15 +61,10 @@
     predictor_weight: 1.0
     sampling_ratio: 0.4
 
-# minibatch related
-batch_type: length
-batch_bins: 25000
-num_workers: 16
-
 # optimization related
 accum_grad: 1
 grad_clip: 5
-max_epoch: 50
+max_epoch: 150
 val_scheduler_criterion:
     - valid
     - acc
@@ -98,6 +104,15 @@
   l_order: 1
   r_order: 1
 
+dataset_conf:
+    shuffle: True
+    shuffle_conf:
+        shuffle_size: 2048
+        sort_size: 500
+    batch_conf:
+        batch_type: token
+        batch_size: 25000
+    num_workers: 8
 
 log_interval: 50
 unused_parameters: true
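
The `frontend: wav_frontend` block added above computes fbank on the fly; `lfr_m` / `lfr_n` control low frame rate (LFR) stacking, and the 1/1 setting used here leaves frames untouched. A simplified sketch of LFR (illustrative only; the real implementation, including its edge padding, is `funasr/models/frontend/wav_frontend.py`):

```python
# Sketch of the low frame rate (LFR) stacking controlled by lfr_m / lfr_n in
# frontend_conf: stack lfr_m consecutive frames into one vector, advancing by
# lfr_n frames. With lfr_m = lfr_n = 1 (as configured here) it is the identity.
import numpy as np


def apply_lfr(feats: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
    # feats: (num_frames, feat_dim) -> (ceil(num_frames / lfr_n), feat_dim * lfr_m)
    t, _ = feats.shape
    out = []
    for i in range(0, t, lfr_n):
        chunk = feats[i:i + lfr_m]
        if chunk.shape[0] < lfr_m:  # pad the tail by repeating the last frame
            pad = np.repeat(chunk[-1:], lfr_m - chunk.shape[0], axis=0)
            chunk = np.concatenate([chunk, pad], axis=0)
        out.append(chunk.reshape(-1))
    return np.stack(out)


if __name__ == "__main__":
    x = np.arange(12, dtype=np.float32).reshape(6, 2)
    assert np.array_equal(apply_lfr(x, 1, 1), x)  # identity at 1/1
    print(apply_lfr(x, 3, 2).shape)               # (3, 6)
```
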
diff --git a/egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh b/egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh
@@ -0,0 +1,106 @@
+#!/usr/bin/env bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+#             2017  Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1;
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tgz
+  else
+    echo "$data/$part.tgz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+  if ! command -v wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1;
+  fi
+  full_url=$url/$part.tgz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  cd $data || exit 1
+  if ! wget --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1;
+  fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+  echo "$0: error un-tarring archive $data/$part.tgz"
+  exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+  cd $data/$part/wav || exit 1
+  for wav in ./*.tar.gz; do
+    echo "Extracting wav from $wav"
+    tar -zxf $wav && rm $wav
+  done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+  rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh b/egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-#           2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
-  echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
-  echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
-  exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
-  echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
-  exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
-  cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/data2vec_paraformer_finetune/run.sh b/egs/aishell/data2vec_paraformer_finetune/run.sh
index d033ce2..b9d1668 100755
--- a/egs/aishell/data2vec_paraformer_finetune/run.sh
+++ b/egs/aishell/data2vec_paraformer_finetune/run.sh
@@ -8,33 +8,31 @@
 count=1
 gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
 # for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
+njob=1
 train_cmd=utils/run.pl
 infer_cmd=utils/run.pl
 
 # general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" #feature output dictionary
 exp_dir="."
 lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
 stop_stage=4
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
 
 # exp tag
-tag=""
+tag="exp1"
 
 model_name=damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch
 init_param="$HOME/.cache/modelscope/hub/$model_name/basemodel.pb"
@@ -52,7 +50,7 @@
 test_sets="dev test"
 
 asr_config=conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
 
 inference_config=conf/decode_asr_transformer_noctc_1best.yaml
 inference_asr_model=valid.acc.ave_10best.pb
@@ -69,10 +67,16 @@
     _ngpu=0
 fi
 
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    echo "stage -1: Data Download"
+    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     echo "stage 0: Data preparation"
     # Data preparation
-    local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
     for x in train dev test; do
         cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
         paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -82,46 +86,9 @@
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
-        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
-    utils/fix_data_feat.sh ${fbankdir}/train
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
-    utils/fix_data_feat.sh ${fbankdir}/dev
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
-    utils/fix_data_feat.sh ${fbankdir}/test
-     
-    # compute global cmvn
-    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
-        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
-    # apply cmvn 
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-    
-    cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
-    cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
-    cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
-    utils/fix_data_feat.sh ${feat_train_dir}
-    utils/fix_data_feat.sh ${feat_dev_dir}
-    utils/fix_data_feat.sh ${feat_test_dir}
-
-    #generate ark list 
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
 fi
 
 token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -129,35 +96,27 @@
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     echo "stage 2: Dictionary Preparation"
     mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-   
+
     echo "make a dictionary"
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
-    num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
-    vocab_size=$(cat ${token_list} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train 
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
 fi
 
 # Training Stage
 world_size=$gpu_num  # run on one machine
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     echo "stage 3: Training"
-    python utils/download_model.py  --model_name ${model_name}  # download pretrained model on ModelScope
+    python utils/download_model.py --model_name ${model_name}  # download the pretrained model from ModelScope
     mkdir -p ${exp_dir}/exp/${model_dir}
     mkdir -p ${exp_dir}/exp/${model_dir}/log
     INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
     if [ -f $INIT_FILE ];then
         rm -f $INIT_FILE
-    fi 
+    fi
     init_method=file://$(readlink -f $INIT_FILE)
     echo "$0: init method is $init_method"
     for ((i = 0; i < $gpu_num; ++i)); do
@@ -165,27 +124,22 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train_paraformer.py \
+            train.py \
+                --task_name asr \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
                 --token_type char \
                 --token_list $token_list \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char  \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
                 --init_param ${init_param} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
-                --input_size $feats_dim \
                 --ngpu $gpu_num \
                 --num_worker_count $count \
-                --multiprocessing_distributed true \
                 --dist_init_method $init_method \
                 --dist_world_size $world_size \
                 --dist_rank $rank \
@@ -208,7 +162,7 @@
             exit 0
         fi
         mkdir -p "${_logdir}"
-        _data="${feats_dir}/${dumpdir}/${dset}"
+        _data="${feats_dir}/data/${dset}"
         key_file=${_data}/${scp}
         num_scp_file="$(<${key_file} wc -l)"
         _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -229,6 +183,8 @@
                 --njob ${njob} \
                 --gpuid_list ${gpuid_list} \
                 --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
                 --key_file "${_logdir}"/keys.JOB.scp \
                 --asr_train_config "${asr_exp}"/config.yaml \
                 --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -249,4 +205,4 @@
         tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
         cat ${_dir}/text.cer.txt
     done
-fi
+fi
\ No newline at end of file
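
With raw waveforms as input, speed perturbation (`--speed_perturb "0.9 1.0 1.1"`) is now applied on the fly rather than via pre-generated perturbed fbank copies. A sketch of the idea (how FunASR resamples internally is not shown in this patch; this uses torchaudio's sox effects):

```python
# Sketch of on-the-fly speed perturbation as enabled by --speed_perturb
# "0.9 1.0 1.1": resample each utterance by a randomly chosen factor
# before feature extraction. Illustrative only.
import random

import torch
import torchaudio


def speed_perturb(waveform: torch.Tensor, sample_rate: int, factors=(0.9, 1.0, 1.1)) -> torch.Tensor:
    factor = random.choice(factors)
    if factor == 1.0:
        return waveform
    perturbed, _ = torchaudio.sox_effects.apply_effects_tensor(
        waveform, sample_rate,
        [["speed", str(factor)], ["rate", str(sample_rate)]],  # change speed, keep fs
    )
    return perturbed


if __name__ == "__main__":
    wav = torch.randn(1, 16000)  # one second at 16 kHz
    out = speed_perturb(wav, 16000, factors=(0.9,))
    print(out.shape)  # 0.9x speed yields a longer waveform (~16000/0.9 samples)
```
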
diff --git a/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml b/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
index 5bc5236..ad3ad2e 100644
--- a/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
+++ b/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
@@ -30,25 +30,28 @@
   require_same_masks: true
   mask_dropout: 0
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
 # hybrid CTC/attention
 model_conf:
     ctc_weight: 1.0
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
 
-# for logger
-log_interval: 50
-
-# minibatch related
-batch_type: length
-batch_bins: 16000
-num_workers: 16
-
 # optimization related
 accum_grad: 1
 grad_clip: 5
 patience: none
-max_epoch: 50
+max_epoch: 150
 val_scheduler_criterion:
     - valid
     - acc
@@ -57,8 +60,6 @@
     - cer_ctc
     - min
 keep_nbest_models: 10
-unused_parameters: true
-normalize: None
 
 # NoamLR is deprecated. Use WarmupLR.
 # The following is equivalent setting for NoamLR:
@@ -92,4 +93,18 @@
     time_mask_width_range:
     - 0
     - 40
-    num_time_mask: 2
\ No newline at end of file
+    num_time_mask: 2
+
+dataset_conf:
+    shuffle: True
+    shuffle_conf:
+        shuffle_size: 2048
+        sort_size: 500
+    batch_conf:
+        batch_type: token
+        batch_size: 25000
+    num_workers: 8
+
+log_interval: 50
+unused_parameters: true
+normalize: None
\ No newline at end of file
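
For reference, the `specaug` settings kept in this config (`time_mask_width_range: [0, 40]`, `num_time_mask: 2`) zero out random spans of frames during training. A sketch of that time masking (illustrative only; FunASR's implementation is under `funasr/models/specaug/`):

```python
# Sketch of SpecAugment time masking: zero out num_masks random spans of up
# to 40 frames each, per the specaug block of this config.
import numpy as np


def time_mask(feats: np.ndarray, width_range=(0, 40), num_masks=2, rng=None) -> np.ndarray:
    rng = rng or np.random.default_rng()
    out = feats.copy()
    t = out.shape[0]
    for _ in range(num_masks):
        width = int(rng.integers(width_range[0], width_range[1] + 1))
        if width == 0 or width >= t:
            continue
        start = int(rng.integers(0, t - width))
        out[start:start + width, :] = 0.0  # mask a contiguous block of frames
    return out


if __name__ == "__main__":
    x = np.ones((300, 80), dtype=np.float32)
    y = time_mask(x, rng=np.random.default_rng(0))
    print((y == 0).all(axis=1).sum(), "frames masked")
```
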
diff --git a/egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh b/egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh
@@ -0,0 +1,106 @@
+#!/usr/bin/env bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+#             2017  Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1;
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tgz
+  else
+    echo "$data/$part.tgz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+  if ! command -v wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1;
+  fi
+  full_url=$url/$part.tgz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  cd $data || exit 1
+  if ! wget --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1;
+  fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+  echo "$0: error un-tarring archive $data/$part.tgz"
+  exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+  cd $data/$part/wav || exit 1
+  for wav in ./*.tar.gz; do
+    echo "Extracting wav from $wav"
+    tar -zxf $wav && rm $wav
+  done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+  rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh b/egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-#           2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
-  echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
-  echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
-  exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
-  echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
-  exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
-  cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/data2vec_transformer_finetune/run.sh b/egs/aishell/data2vec_transformer_finetune/run.sh
index 26222e6..7bbcf90 100755
--- a/egs/aishell/data2vec_transformer_finetune/run.sh
+++ b/egs/aishell/data2vec_transformer_finetune/run.sh
@@ -13,28 +13,26 @@
 infer_cmd=utils/run.pl
 
 # general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" #feature output dictionary
 exp_dir="."
 lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
 stop_stage=4
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
 
 # exp tag
-tag=""
+tag="exp1"
 
 model_name=damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch
 init_param="$HOME/.cache/modelscope/hub/$model_name/basemodel.pb"
@@ -52,7 +50,7 @@
 test_sets="dev test"
 
 asr_config=conf/train_asr_transformer_12e_6d_3072_768.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
 
 inference_config=conf/decode_asr_transformer.yaml
 inference_asr_model=valid.cer_ctc.ave_10best.pb
@@ -69,10 +67,16 @@
     _ngpu=0
 fi
 
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    echo "stage -1: Data Download"
+    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     echo "stage 0: Data preparation"
     # Data preparation
-    local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
     for x in train dev test; do
         cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
         paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -82,46 +86,9 @@
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
-        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
-    utils/fix_data_feat.sh ${fbankdir}/train
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
-    utils/fix_data_feat.sh ${fbankdir}/dev
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
-    utils/fix_data_feat.sh ${fbankdir}/test
-     
-    # compute global cmvn
-    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
-        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
-    # apply cmvn 
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-    
-    cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
-    cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
-    cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
-    utils/fix_data_feat.sh ${feat_train_dir}
-    utils/fix_data_feat.sh ${feat_dev_dir}
-    utils/fix_data_feat.sh ${feat_test_dir}
-
-    #generate ark list 
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
 fi
 
 token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -129,35 +96,27 @@
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     echo "stage 2: Dictionary Preparation"
     mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-   
+
     echo "make a dictionary"
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
-    num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
-    vocab_size=$(cat ${token_list} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train 
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
 fi
 
 # Training Stage
 world_size=$gpu_num  # run on one machine
 if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
     echo "stage 3: Training"
-    python utils/download_model.py  --model_name ${model_name}  # download pretrained model on ModelScope
+    python utils/download_model.py  --model_name ${model_name}  # download pretrained model on ModelScope
     mkdir -p ${exp_dir}/exp/${model_dir}
     mkdir -p ${exp_dir}/exp/${model_dir}/log
     INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
     if [ -f $INIT_FILE ];then
         rm -f $INIT_FILE
-    fi 
+    fi
     init_method=file://$(readlink -f $INIT_FILE)
     echo "$0: init method is $init_method"
     for ((i = 0; i < $gpu_num; ++i)); do
@@ -165,27 +124,23 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train.py \
+            train.py \
+                --task_name asr \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
                 --token_type char \
                 --token_list $token_list \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char  \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
                 --init_param ${init_param} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
-                --input_size $feats_dim \
                 --ngpu $gpu_num \
                 --num_worker_count $count \
-                --multiprocessing_distributed true \
                 --dist_init_method $init_method \
                 --dist_world_size $world_size \
                 --dist_rank $rank \
@@ -208,7 +163,7 @@
             exit 0
         fi
         mkdir -p "${_logdir}"
-        _data="${feats_dir}/${dumpdir}/${dset}"
+        _data="${feats_dir}/data/${dset}"
         key_file=${_data}/${scp}
         num_scp_file="$(<${key_file} wc -l)"
         _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -229,6 +184,7 @@
                 --njob ${njob} \
                 --gpuid_list ${gpuid_list} \
                 --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                 --key_file "${_logdir}"/keys.JOB.scp \
                 --asr_train_config "${asr_exp}"/config.yaml \
                 --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -249,4 +205,4 @@
         tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
         cat ${_dir}/text.cer.txt
     done
-fi
+fi
\ No newline at end of file
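
For reference, a minimal sketch of driving the revised recipe end to end. It assumes the options defined above (--raw_data, --stage, --stop_stage, --tag) are exposed through utils/parse_options.sh; the corpus path is a placeholder:

    # full pipeline: download (-1), data prep (0), CMVN (1), dictionary (2), train (3), decode (4)
    mkdir -p /nfs/corpora/aishell
    ./run.sh --raw_data /nfs/corpora/aishell --stage -1 --stop_stage 4 --tag exp1
    # later reruns can start from training, reusing the prepared data under ../DATA
    ./run.sh --raw_data /nfs/corpora/aishell --stage 3 --stop_stage 4 --tag exp1
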
diff --git a/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml b/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
index 24b2620..bac8d04 100644
--- a/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
+++ b/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
@@ -29,6 +29,17 @@
     self_attention_dropout_rate: 0.0
     src_attention_dropout_rate: 0.0
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
 model: paraformer
 model_conf:
     ctc_weight: 0.3
@@ -37,15 +48,10 @@
     predictor_weight: 1.0
     sampling_ratio: 0.4
 
-# minibatch related
-batch_type: length
-batch_bins: 25000
-num_workers: 16
-
 # optimization related
 accum_grad: 1
 grad_clip: 5
-max_epoch: 50
+max_epoch: 150
 val_scheduler_criterion:
     - valid
     - acc
@@ -78,7 +84,7 @@
     - 40
     num_time_mask: 2
 
-predictor: cif_predictor_v2
+predictor: cif_predictor
 predictor_conf:
     idim: 256
     threshold: 1.0
@@ -86,6 +92,15 @@
     r_order: 1
     tail_threshold: 0.45
 
+dataset_conf:
+    shuffle: True
+    shuffle_conf:
+        shuffle_size: 2048
+        sort_size: 500
+    batch_conf:
+        batch_type: token
+        batch_size: 25000
+    num_workers: 8
 
 log_interval: 50
 normalize: None
\ No newline at end of file
diff --git a/egs/aishell/paraformer/local/download_and_untar.sh b/egs/aishell/paraformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/paraformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+#             2017  Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1;
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tgz
+  else
+    echo "$data/$part.tgz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+  if ! command -v wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1;
+  fi
+  full_url=$url/$part.tgz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  cd $data || exit 1
+  if ! wget --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1;
+  fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+  echo "$0: error un-tarring archive $data/$part.tgz"
+  exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+  cd $data/$part/wav || exit 1
+  for wav in ./*.tar.gz; do
+    echo "Extracting wav from $wav"
+    tar -zxf $wav && rm $wav
+  done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+  rm $data/$part.tgz
+fi
+
+exit 0;
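
A usage sketch for this helper, following its own usage string; the data directory must already exist and the path is illustrative:

    mkdir -p /nfs/corpora/aishell
    local/download_and_untar.sh /nfs/corpora/aishell www.openslr.org/resources/33 data_aishell
    # with --remove-archive the .tgz is deleted after successful extraction
    local/download_and_untar.sh --remove-archive /nfs/corpora/aishell www.openslr.org/resources/33 resource_aishell
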
diff --git a/egs/aishell/paraformer/local/prepare_data.sh b/egs/aishell/paraformer/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/paraformer/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-#           2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
-  echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
-  echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
-  exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
-  echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
-  exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
-  cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/paraformer/run.sh b/egs/aishell/paraformer/run.sh
index 53b5f90..430da6e 100755
--- a/egs/aishell/paraformer/run.sh
+++ b/egs/aishell/paraformer/run.sh
@@ -16,25 +16,23 @@
 feats_dir="../DATA" #feature output dictionary
 exp_dir="."
 lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
 stop_stage=4
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
 
 # exp tag
-tag=""
+tag="exp1"
 
 . utils/parse_options.sh || exit 1;
 
@@ -49,7 +47,7 @@
 test_sets="dev test"
 
 asr_config=conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
 
 inference_config=conf/decode_asr_transformer_noctc_1best.yaml
 inference_asr_model=valid.acc.ave_10best.pb
@@ -66,10 +64,16 @@
     _ngpu=0
 fi
 
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    echo "stage -1: Data Download"
+    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     echo "stage 0: Data preparation"
     # Data preparation
-    local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
     for x in train dev test; do
         cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
         paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -79,46 +83,9 @@
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
-        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
-    utils/fix_data_feat.sh ${fbankdir}/train
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
-    utils/fix_data_feat.sh ${fbankdir}/dev
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
-    utils/fix_data_feat.sh ${fbankdir}/test
-     
-    # compute global cmvn
-    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
-        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
-    # apply cmvn 
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-    
-    cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
-    cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
-    cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
-    utils/fix_data_feat.sh ${feat_train_dir}
-    utils/fix_data_feat.sh ${feat_dev_dir}
-    utils/fix_data_feat.sh ${feat_test_dir}
-
-    #generate ark list 
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
 fi
 
 token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -131,17 +98,9 @@
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
-    num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
-    vocab_size=$(cat ${token_list} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train 
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
 fi
 
 # Training Stage
@@ -161,26 +120,22 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train_paraformer.py \
+            train.py \
+                --task_name asr \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
                 --token_type char \
                 --token_list $token_list \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char  \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
-                --input_size $feats_dim \
                 --ngpu $gpu_num \
                 --num_worker_count $count \
-                --multiprocessing_distributed true \
                 --dist_init_method $init_method \
                 --dist_world_size $world_size \
                 --dist_rank $rank \
@@ -203,7 +158,7 @@
             exit 0
         fi
         mkdir -p "${_logdir}"
-        _data="${feats_dir}/${dumpdir}/${dset}"
+        _data="${feats_dir}/data/${dset}"
         key_file=${_data}/${scp}
         num_scp_file="$(<${key_file} wc -l)"
         _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -224,6 +179,7 @@
                 --njob ${njob} \
                 --gpuid_list ${gpuid_list} \
                 --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                 --key_file "${_logdir}"/keys.JOB.scp \
                 --asr_train_config "${asr_exp}"/config.yaml \
                 --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
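
The decoding stage sizes its parallel job list from gpu_inference and njob, per the comment at the top of these recipes; a hedged sketch of a CPU-only decode of both test sets, assuming decoding is stage 4:

    # inference_nj = njob when gpu_inference=false
    ./run.sh --stage 4 --stop_stage 4 --gpu_inference false --njob 8
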
diff --git a/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml b/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
index f51a2ea..f2652e8 100644
--- a/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
+++ b/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
@@ -29,6 +29,17 @@
     self_attention_dropout_rate: 0.0
     src_attention_dropout_rate: 0.0
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
 # hybrid CTC/attention
 model: paraformer_bert
 model_conf:
@@ -41,19 +52,10 @@
     embed_dims: 768
     embeds_loss_weight: 2.0
 
-
-
-# minibatch related
-#batch_type: length
-#batch_bins: 40000
-batch_type: numel
-batch_bins: 2000000
-num_workers: 16
-
 # optimization related
-accum_grad: 4
+accum_grad: 1
 grad_clip: 5
-max_epoch: 50
+max_epoch: 150
 val_scheduler_criterion:
     - valid
     - acc
@@ -92,8 +94,19 @@
     threshold: 1.0
     l_order: 1
     r_order: 1
+    tail_threshold: 0.45
 
+dataset_conf:
+    data_names: speech,text,embed
+    data_types: sound,text,kaldi_ark
+    shuffle: True
+    shuffle_conf:
+        shuffle_size: 2048
+        sort_size: 500
+    batch_conf:
+        batch_type: token
+        batch_size: 25000
+    num_workers: 8
 
 log_interval: 50
-normalize: None
-allow_variable_data_keys: true
\ No newline at end of file
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell/paraformerbert/local/aishell_data_prep.sh b/egs/aishell/paraformerbert/local/aishell_data_prep.sh
index b6ea36b..83f489b 100755
--- a/egs/aishell/paraformerbert/local/aishell_data_prep.sh
+++ b/egs/aishell/paraformerbert/local/aishell_data_prep.sh
@@ -5,19 +5,20 @@
 
 #. ./path.sh || exit 1;
 
-if [ $# != 2 ]; then
-  echo "Usage: $0 <audio-path> <text-path>"
-  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript"
+if [ $# != 3 ]; then
+  echo "Usage: $0 <audio-path> <text-path> <output-path>"
+  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
   exit 1;
 fi
 
 aishell_audio_dir=$1
 aishell_text=$2/aishell_transcript_v0.8.txt
+output_dir=$3
 
-train_dir=data/local/train
-dev_dir=data/local/dev
-test_dir=data/local/test
-tmp_dir=data/local/tmp
+train_dir=$output_dir/data/local/train
+dev_dir=$output_dir/data/local/dev
+test_dir=$output_dir/data/local/test
+tmp_dir=$output_dir/data/local/tmp
 
 mkdir -p $train_dir
 mkdir -p $dev_dir
@@ -53,12 +54,12 @@
   sort -u $dir/transcripts.txt > $dir/text
 done
 
-mkdir -p data/train data/dev data/test
+mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
 
 for f in wav.scp text; do
-  cp $train_dir/$f data/train/$f || exit 1;
-  cp $dev_dir/$f data/dev/$f || exit 1;
-  cp $test_dir/$f data/test/$f || exit 1;
+  cp $train_dir/$f $output_dir/data/train/$f || exit 1;
+  cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
+  cp $test_dir/$f $output_dir/data/test/$f || exit 1;
 done
 
 echo "$0: AISHELL data preparation succeeded"
diff --git a/egs/aishell/paraformerbert/local/download_and_untar.sh b/egs/aishell/paraformerbert/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/paraformerbert/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+#             2017  Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1;
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tgz
+  else
+    echo "$data/$part.tgz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+  if ! command -v wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1;
+  fi
+  full_url=$url/$part.tgz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  cd $data || exit 1
+  if ! wget --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1;
+  fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+  echo "$0: error un-tarring archive $data/$part.tgz"
+  exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+  cd $data/$part/wav || exit 1
+  for wav in ./*.tar.gz; do
+    echo "Extracting wav from $wav"
+    tar -zxf $wav && rm $wav
+  done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+  rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/paraformerbert/local/extract_embeds.sh b/egs/aishell/paraformerbert/local/extract_embeds.sh
index 9cf5940..c07e528 100755
--- a/egs/aishell/paraformerbert/local/extract_embeds.sh
+++ b/egs/aishell/paraformerbert/local/extract_embeds.sh
@@ -3,20 +3,16 @@
 stage=1
 stop_stage=3
 
-bert_model_root="../../huggingface_models"
 bert_model_name="bert-base-chinese"
-#bert_model_name="chinese-roberta-wwm-ext"
-#bert_model_name="mengzi-bert-base"
 raw_dataset_path="../DATA"
-model_path=${bert_model_root}/${bert_model_name}
+nj=64
+model_path=${bert_model_name}
 
 . utils/parse_options.sh || exit 1;
 
-nj=32
-
 for data_set in train dev test;do
-    scp=$raw_dataset_path/dump/fbank/${data_set}/text
-    local_scp_dir_raw=$raw_dataset_path/embeds/$bert_model_name/${data_set}
+    scp=$raw_dataset_path/data/${data_set}/text
+    local_scp_dir_raw=${raw_dataset_path}/data/embeds/${data_set}
     local_scp_dir=$local_scp_dir_raw/split$nj
     local_records_dir=$local_scp_dir_raw/ark
 
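
Since the model root was dropped, the script now resolves the BERT model purely by name (e.g. a local or cached copy of bert-base-chinese); a hedged standalone invocation:

    # writes per-set embedding scp/ark files under ../DATA/data/embeds/{train,dev,test}
    local/extract_embeds.sh --bert_model_name bert-base-chinese --raw_dataset_path ../DATA --nj 64
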
diff --git a/egs/aishell/paraformerbert/run.sh b/egs/aishell/paraformerbert/run.sh
index 2487eac..a4ad195 100755
--- a/egs/aishell/paraformerbert/run.sh
+++ b/egs/aishell/paraformerbert/run.sh
@@ -8,7 +8,7 @@
 count=1
 gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
 # for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
+njob=1
 train_cmd=utils/run.pl
 infer_cmd=utils/run.pl
 
@@ -16,29 +16,26 @@
 feats_dir="../DATA" #feature output dictionary, for large data
 exp_dir="."
 lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
 stop_stage=4
 
 skip_extract_embed=false
-bert_model_root="../../huggingface_models"
 bert_model_name="bert-base-chinese"
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
 
 # exp tag
-tag=""
+tag="exp1"
 
 . utils/parse_options.sh || exit 1;
 
@@ -53,7 +50,7 @@
 test_sets="dev test"
 
 asr_config=conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
 
 inference_config=conf/decode_asr_transformer_noctc_1best.yaml
 inference_asr_model=valid.acc.ave_10best.pb
@@ -70,10 +67,17 @@
     _ngpu=0
 fi
 
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    echo "stage -1: Data Download"
+    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
+
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     echo "stage 0: Data preparation"
     # Data preparation
-    local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
     for x in train dev test; do
         cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
         paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -83,46 +87,9 @@
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
-        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
-    utils/fix_data_feat.sh ${fbankdir}/train
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
-    utils/fix_data_feat.sh ${fbankdir}/dev
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
-    utils/fix_data_feat.sh ${fbankdir}/test
-     
-    # compute global cmvn
-    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
-        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
-    # apply cmvn 
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-    
-    cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
-    cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
-    cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
-    utils/fix_data_feat.sh ${feat_train_dir}
-    utils/fix_data_feat.sh ${feat_dev_dir}
-    utils/fix_data_feat.sh ${feat_test_dir}
-
-    #generate ark list 
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
 fi
 
 token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -135,17 +102,9 @@
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
-    num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
-    vocab_size=$(cat ${token_list} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train 
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
 fi
 
 # Training Stage
@@ -155,9 +114,9 @@
     if ! "${skip_extract_embed}"; then
         echo "extract embeddings..."
         local/extract_embeds.sh \
-            --bert_model_root ${bert_model_root} \
             --bert_model_name ${bert_model_name} \
-            --raw_dataset_path ${feats_dir}
+            --raw_dataset_path ${feats_dir} \
+            --nj $nj
     fi
     mkdir -p ${exp_dir}/exp/${model_dir}
     mkdir -p ${exp_dir}/exp/${model_dir}/log
@@ -172,31 +131,23 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train_paraformer.py \
+            train.py \
+                --task_name asr \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
                 --token_type char \
                 --token_list $token_list \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
-                --train_data_path_and_name_and_type ${feats_dir}/embeds/${bert_model_name}/${train_set}/embeds.scp,embed,${type} \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
-                --train_shape_file ${feats_dir}/embeds/${bert_model_name}/${train_set}/embeds.shape \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
-                --valid_data_path_and_name_and_type ${feats_dir}/embeds/${bert_model_name}/${valid_set}/embeds.scp,embed,${type} \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char  \
-                --valid_shape_file ${feats_dir}/embeds/${bert_model_name}/${valid_set}/embeds.shape \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --embed_path ${feats_dir}/data \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
-                --allow_variable_data_keys true \
-                --input_size $feats_dim \
                 --ngpu $gpu_num \
                 --num_worker_count $count \
-                --multiprocessing_distributed true \
                 --dist_init_method $init_method \
                 --dist_world_size $world_size \
                 --dist_rank $rank \
@@ -260,5 +211,4 @@
         tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
         cat ${_dir}/text.cer.txt
     done
-fi
-
+fi
\ No newline at end of file
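
On reruns, the BERT embedding pass can be skipped once the ark files exist, via the skip_extract_embed flag defined above; a sketch:

    # train only, reusing previously extracted embeddings
    ./run.sh --stage 3 --stop_stage 3 --skip_extract_embed true
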
diff --git a/egs/aishell/transformer/conf/train_asr_transformer.yaml b/egs/aishell/transformer/conf/train_asr_transformer.yaml
index ce987e7..22e651b 100644
--- a/egs/aishell/transformer/conf/train_asr_transformer.yaml
+++ b/egs/aishell/transformer/conf/train_asr_transformer.yaml
@@ -23,22 +23,28 @@
     self_attention_dropout_rate: 0.0
     src_attention_dropout_rate: 0.0
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
 # hybrid CTC/attention
 model_conf:
     ctc_weight: 0.3
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
 
-# minibatch related
-batch_type: length
-batch_bins: 32000
-num_workers: 8
-
 # optimization related
 accum_grad: 1
 grad_clip: 5
-patience: 3
-max_epoch: 20
+patience: none
+max_epoch: 60
 val_scheduler_criterion:
     - valid
     - acc
@@ -66,5 +72,15 @@
 scheduler_conf:
     warmup_steps: 25000
 
+dataset_conf:
+    shuffle: True
+    shuffle_conf:
+        shuffle_size: 2048
+        sort_size: 500
+    batch_conf:
+        batch_type: token
+        batch_size: 25000
+    num_workers: 8
+
 log_interval: 50
 normalize: None
diff --git a/egs/aishell/transformer/local/download_and_untar.sh b/egs/aishell/transformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/transformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+#             2017  Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1;
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tgz
+  else
+    echo "$data/$part.tgz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+  if ! command -v wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1;
+  fi
+  full_url=$url/$part.tgz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  cd $data || exit 1
+  if ! wget --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1;
+  fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+  echo "$0: error un-tarring archive $data/$part.tgz"
+  exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+  cd $data/$part/wav || exit 1
+  for wav in ./*.tar.gz; do
+    echo "Extracting wav from $wav"
+    tar -zxf $wav && rm $wav
+  done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+  rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/transformer/local/prepare_data.sh b/egs/aishell/transformer/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/transformer/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-#           2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
-  echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
-  echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
-  exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
-  echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
-  exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
-  cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/transformer/path.sh b/egs/aishell/transformer/path.sh
index 7972642..b4064e1 100755
--- a/egs/aishell/transformer/path.sh
+++ b/egs/aishell/transformer/path.sh
@@ -3,3 +3,5 @@
 # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
 export PYTHONIOENCODING=UTF-8
 export PATH=$FUNASR_DIR/funasr/bin:$PATH
+
+export OMP_NUM_THREADS=1
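
Pinning OMP_NUM_THREADS=1 keeps each DDP worker from oversubscribing CPU cores now that fbank features are computed on the fly; a quick check, assuming PyTorch picks the value up when path.sh is sourced:

    . ./path.sh
    python -c "import torch; print(torch.get_num_threads())"  # expected: 1
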
diff --git a/egs/aishell/transformer/run.sh b/egs/aishell/transformer/run.sh
index f66a338..b7ad9cd 100755
--- a/egs/aishell/transformer/run.sh
+++ b/egs/aishell/transformer/run.sh
@@ -8,33 +8,31 @@
 count=1
 gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
 # for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=8
+njob=5
 train_cmd=utils/run.pl
 infer_cmd=utils/run.pl
 
 # general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" #feature output dictionary
 exp_dir="."
 lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
 stop_stage=4
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
 
 # exp tag
-tag=""
+tag="exp1"
 
 . utils/parse_options.sh || exit 1;
 
@@ -48,8 +46,8 @@
 valid_set=dev
 test_sets="dev test"
 
-asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+asr_config=conf/train_asr_transformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
 
 inference_config=conf/decode_asr_transformer.yaml
 inference_asr_model=valid.acc.ave_10best.pb
@@ -66,10 +64,16 @@
     _ngpu=0
 fi
 
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    echo "stage -1: Data Download"
+    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     echo "stage 0: Data preparation"
     # Data preparation
-    local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
     for x in train dev test; do
         cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
         paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -79,46 +83,9 @@
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
-        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
-    utils/fix_data_feat.sh ${fbankdir}/train
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
-    utils/fix_data_feat.sh ${fbankdir}/dev
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-        ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
-    utils/fix_data_feat.sh ${fbankdir}/test
-     
-    # compute global cmvn
-    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
-        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
-    # apply cmvn 
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-    
-    cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
-    cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
-    cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
-    utils/fix_data_feat.sh ${feat_train_dir}
-    utils/fix_data_feat.sh ${feat_dev_dir}
-    utils/fix_data_feat.sh ${feat_test_dir}
-
-    #generate ark list 
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
 fi
 
 token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -126,22 +93,14 @@
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     echo "stage 2: Dictionary Preparation"
     mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-   
+
     echo "make a dictionary"
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
-    num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
-    vocab_size=$(cat ${token_list} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train 
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
 fi
 
 # Training Stage
@@ -153,7 +112,7 @@
     INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
     if [ -f $INIT_FILE ];then
         rm -f $INIT_FILE
-    fi 
+    fi
     init_method=file://$(readlink -f $INIT_FILE)
     echo "$0: init method is $init_method"
     for ((i = 0; i < $gpu_num; ++i)); do
@@ -161,26 +120,22 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train.py \
+            train.py \
+                --task_name asr \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
                 --token_type char \
                 --token_list $token_list \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char  \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
-                --input_size $feats_dim \
                 --ngpu $gpu_num \
                 --num_worker_count $count \
-                --multiprocessing_distributed true \
                 --dist_init_method $init_method \
                 --dist_world_size $world_size \
                 --dist_rank $rank \
@@ -203,7 +158,7 @@
             exit 0
         fi
         mkdir -p "${_logdir}"
-        _data="${feats_dir}/${dumpdir}/${dset}"
+        _data="${feats_dir}/data/${dset}"
         key_file=${_data}/${scp}
         num_scp_file="$(<${key_file} wc -l)"
         _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -224,6 +179,7 @@
                 --njob ${njob} \
                 --gpuid_list ${gpuid_list} \
                 --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                 --key_file "${_logdir}"/keys.JOB.scp \
                 --asr_train_config "${asr_exp}"/config.yaml \
                 --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -244,4 +200,4 @@
         tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
         cat ${_dir}/text.cer.txt
     done
-fi
+fi
\ No newline at end of file
diff --git a/egs/aishell/transformer/utils/cmvn_converter.py b/egs/aishell/transformer/utils/cmvn_converter.py
index cb978af..d405d12 100644
--- a/egs/aishell/transformer/utils/cmvn_converter.py
+++ b/egs/aishell/transformer/utils/cmvn_converter.py
@@ -9,16 +9,14 @@
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
-        "--cmvn-json",
-        "-c",
+        "--cmvn_json",
         default=False,
         required=True,
         type=str,
         help="cmvn json file",
     )
     parser.add_argument(
-        "--am-mvn",
-        "-a",
+        "--am_mvn",
         default=False,
         required=True,
         type=str,
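
With the short options removed, only the underscore-style flags remain; a sketch converting the JSON statistics into the mvn file that the recipes pass as --cmvn_file, with assumed paths:

    python utils/cmvn_converter.py \
        --cmvn_json ../DATA/data/train/cmvn/cmvn.json \
        --am_mvn ../DATA/data/train/cmvn/cmvn.mvn
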
diff --git a/egs/aishell/transformer/utils/combine_cmvn_file.py b/egs/aishell/transformer/utils/combine_cmvn_file.py
index b2974a4..c525973 100755
--- a/egs/aishell/transformer/utils/combine_cmvn_file.py
+++ b/egs/aishell/transformer/utils/combine_cmvn_file.py
@@ -1,6 +1,9 @@
 import argparse
 import json
+import os
+
 import numpy as np
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -8,15 +11,13 @@
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
-        "--dims",
-        "-d",
+        "--dim",
         default=80,
         type=int,
-        help="feature dims",
+        help="feature dim",
     )
     parser.add_argument(
-        "--cmvn-dir",
-        "-c",
+        "--cmvn_dir",
         default=False,
         required=True,
         type=str,
@@ -25,15 +26,13 @@
 
     parser.add_argument(
         "--nj",
-        "-n",
         default=1,
         required=True,
         type=int,
-        help="num of cmvn file",
+        help="num of cmvn files",
     )
     parser.add_argument(
-        "--output-dir",
-        "-o",
+        "--output_dir",
         default=False,
         required=True,
         type=str,
@@ -46,14 +45,14 @@
     parser = get_parser()
     args = parser.parse_args()
 
-    total_means = np.zeros(args.dims)
-    total_vars = np.zeros(args.dims)
+    total_means = np.zeros(args.dim)
+    total_vars = np.zeros(args.dim)
     total_frames = 0
 
-    cmvn_file = args.output_dir + "/cmvn.json"
+    cmvn_file = os.path.join(args.output_dir, "cmvn.json")
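+    # Pool the per-split statistics; the global mean and variance are derived
+    # from the summed stats and the total frame count.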
 
-    for i in range(1, args.nj+1):
-        with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
+    for i in range(1, args.nj + 1):
+        with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(i))) as fin:
             cmvn_stats = json.load(fin)
 
         total_means += np.array(cmvn_stats["mean_stats"])
diff --git a/egs/aishell/transformer/utils/compute_cmvn.py b/egs/aishell/transformer/utils/compute_cmvn.py
index 2b96e26..949cc08 100755
--- a/egs/aishell/transformer/utils/compute_cmvn.py
+++ b/egs/aishell/transformer/utils/compute_cmvn.py
@@ -1,8 +1,10 @@
-from kaldiio import ReadHelper
-
 import argparse
-import numpy as np
 import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
 
 
 def get_parser():
@@ -11,55 +13,83 @@
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
-        "--dims",
-        "-d",
+        "--dim",
         default=80,
         type=int,
-        help="feature dims",
+        help="feature dimension",
     )
     parser.add_argument(
-        "--ark-file",
-        "-a",
+        "--wav_path",
         default=False,
         required=True,
         type=str,
-        help="fbank ark file",
+        help="the path of wav scps",
     )
     parser.add_argument(
-        "--ark-index",
-        "-i",
+        "--idx",
         default=1,
         required=True,
         type=int,
-        help="ark index",
-    )
-    parser.add_argument(
-        "--output-dir",
-        "-o",
-        default=False,
-        required=True,
-        type=str,
-        help="output dir",
+        help="index",
     )
     return parser
+
+
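+# Compute Kaldi-compatible log-mel fbank features directly from a wav file,
+# with optional resampling and sox-based speed perturbation.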
+def compute_fbank(wav_file,
+                  num_mel_bins=80,
+                  frame_length=25,
+                  frame_shift=10,
+                  dither=0.0,
+                  resample_rate=16000,
+                  speed=1.0,
+                  window_type="hamming"):
+    waveform, sample_rate = torchaudio.load(wav_file)
+    if resample_rate != sample_rate:
+        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+                                                  new_freq=resample_rate)(waveform)
+    if speed != 1.0:
+        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+            waveform, resample_rate,
+            [['speed', str(speed)], ['rate', str(resample_rate)]]
+        )
+
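+    # torchaudio loads waveforms as floats in [-1, 1]; Kaldi fbank expects
+    # 16-bit PCM amplitudes, so scale up before feature extraction.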
+    waveform = waveform * (1 << 15)
+    mat = kaldi.fbank(waveform,
+                      num_mel_bins=num_mel_bins,
+                      frame_length=frame_length,
+                      frame_shift=frame_shift,
+                      dither=dither,
+                      energy_floor=0.0,
+                      window_type=window_type,
+                      sample_frequency=resample_rate)
+
+    return mat.numpy()
 
 
 def main():
     parser = get_parser()
     args = parser.parse_args()
 
-    ark_file = args.ark_file + "/feats." + str(args.ark_index) + ".ark"
-    cmvn_file = args.output_dir + "/cmvn." + str(args.ark_index) + ".json"
+    wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+    cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
 
-    mean_stats = np.zeros(args.dims)
-    var_stats = np.zeros(args.dims)
+    mean_stats = np.zeros(args.dim)
+    var_stats = np.zeros(args.dim)
     total_frames = 0
 
-    with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
-        for key, mat in ark_reader:
-            mean_stats += np.sum(mat, axis=0)
-            var_stats += np.sum(np.square(mat), axis=0)
-            total_frames += mat.shape[0]
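+    # Accumulate sufficient statistics (sum, sum of squares, frame count)
+    # over this split; utils/combine_cmvn_file.py merges all splits into a
+    # single global mean/variance.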
+    with open(wav_scp_file) as f:
+        lines = f.readlines()
+        for line in lines:
+            _, wav_file = line.strip().split()
+            fbank = compute_fbank(wav_file, num_mel_bins=args.dim)
+            mean_stats += np.sum(fbank, axis=0)
+            var_stats += np.sum(np.square(fbank), axis=0)
+            total_frames += fbank.shape[0]
 
     cmvn_info = {
         'mean_stats': list(mean_stats.tolist()),
diff --git a/egs/aishell/transformer/utils/compute_cmvn.sh b/egs/aishell/transformer/utils/compute_cmvn.sh
index 12173ee..7663df9 100755
--- a/egs/aishell/transformer/utils/compute_cmvn.sh
+++ b/egs/aishell/transformer/utils/compute_cmvn.sh
@@ -11,15 +11,24 @@
 . utils/parse_options.sh || exit 1;
 
 fbankdir=$1
-logdir=$2
 
-output_dir=${fbankdir}/cmvn; mkdir -p ${output_dir}
-mkdir -p ${logdir}
+split_dir=${fbankdir}/cmvn/split_${nj}
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+    split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp $split_scps || exit 1;
 
+logdir=${fbankdir}/cmvn/log
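+# Run nj parallel jobs, one per wav.<n>.scp split; each writes its statistics
+# to cmvn.<n>.json in the split directory.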
 $cmd JOB=1:$nj $logdir/cmvn.JOB.log \
-    python utils/compute_cmvn.py -d ${feats_dim} -a $fbankdir/ark -i JOB -o ${output_dir} \
-        || exit 1;
+    python utils/compute_cmvn.py \
+      --dim ${feats_dim} \
+      --wav_path $split_dir \
+      --idx JOB || exit 1;
 
-python utils/combine_cmvn_file.py -d ${feats_dim} -c ${output_dir} -n $nj -o $fbankdir
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/cmvn.mvn
 
 echo "$0: Succeeded compute global cmvn"
diff --git a/egs/aishell2/conformer/conf/train_asr_conformer.yaml b/egs/aishell2/conformer/conf/train_asr_conformer.yaml
index 02fc5a8..3fd034f 100644
--- a/egs/aishell2/conformer/conf/train_asr_conformer.yaml
+++ b/egs/aishell2/conformer/conf/train_asr_conformer.yaml
@@ -29,6 +29,17 @@
     self_attention_dropout_rate: 0.0
     src_attention_dropout_rate: 0.0
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
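+    # lfr_m/lfr_n control low-frame-rate stacking (stack m frames, stride n);
+    # 1/1 keeps the original frame rate.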
+
 # hybrid CTC/attention
 model_conf:
     ctc_weight: 0.3
@@ -39,7 +50,7 @@
 accum_grad: 2
 grad_clip: 5
 patience: none
-max_epoch: 50
+max_epoch: 150
 val_scheduler_criterion:
     - valid
     - acc
@@ -72,10 +83,8 @@
     - 40
     num_time_mask: 2
 
-log_interval: 50
-normalize: None
-
 dataset_conf:
+    data_types: sound,text
     shuffle: True
     shuffle_conf:
         shuffle_size: 2048
@@ -83,4 +92,7 @@
     batch_conf:
         batch_type: token
         batch_size: 25000
-    num_workers: 8
\ No newline at end of file
+    num_workers: 8
+
+log_interval: 50
+normalize: None
diff --git a/egs/aishell2/conformer/local/prepare_data.sh b/egs/aishell2/conformer/local/prepare_data.sh
index ce6ee19..77791f9 100755
--- a/egs/aishell2/conformer/local/prepare_data.sh
+++ b/egs/aishell2/conformer/local/prepare_data.sh
@@ -34,14 +34,14 @@
 # validate utt-key list, IC0803W0380 is a bad utterance
 awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
 awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
 
 # wav.scp
 awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
 
 # text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
 
 # copy prepared resources from tmp_dir to target dir
 mkdir -p $dir
diff --git a/egs/aishell2/conformer/run.sh b/egs/aishell2/conformer/run.sh
index f9ea69a..65fd666 100755
--- a/egs/aishell2/conformer/run.sh
+++ b/egs/aishell2/conformer/run.sh
@@ -9,31 +9,28 @@
 gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
 # for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
 njob=5
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
 infer_cmd=utils/run.pl
 
 # general configuration
 feats_dir="../DATA" #feature output dictionary
 exp_dir="."
 lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
 dataset_type=large
-scp=feats.scp
-type=kaldi_ark
-stage=0
+stage=3
 stop_stage=4
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-tr_dir=
-dev_tst_dir=
+tr_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-2/iOS/data
+dev_tst_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-DEV-TEST-SET
 
 # exp tag
 tag="exp1"
@@ -51,13 +48,13 @@
 test_sets="dev_ios test_ios"
 
 asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
 
 inference_config=conf/decode_asr_transformer.yaml
 inference_asr_model=valid.acc.ave_10best.pb
 
 # you can set gpu num for decoding here
-gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, e.g., gpuid_list=2,3, the same as training stage by default
+gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
 ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
 
 if ${gpu_inference}; then
@@ -73,7 +70,7 @@
     # For training set
     local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
     # # For dev and test set
-    for x in Android iOS Mic; do
+    for x in iOS; do
         local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
         local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
     done 
@@ -83,51 +80,14 @@
         paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
              | tr 'A-Z' 'a-z' | tr -d " ") \
             > ${feats_dir}/data/${x}/text
-        tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
         mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
-        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
-    tools/fix_data_feat.sh ${fbankdir}/train
-    for x in android ios mic; do
-        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
-            ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
-        tools/fix_data_feat.sh ${fbankdir}/dev_${x}
-        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
-            ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
-        tools/fix_data_feat.sh ${fbankdir}/test_${x}
-    done
-    
-    # compute global cmvn
-    steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
-    # apply cmvn 
-    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
-    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
-    for x in android ios mic; do
-        steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-            ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
-    done
-    
-    cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
-    tools/fix_data_feat.sh ${feat_train_dir}
-    cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
-    tools/fix_data_feat.sh ${feat_dev_dir}
-    for x in android ios mic; do
-        cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
-        tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
-    done
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
 fi
 
 token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -140,18 +100,12 @@
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
-    num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
-    vocab_size=$(cat ${token_list} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
     mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
     mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set} 
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-fi
+fi
 
 # Training Stage
 world_size=$gpu_num  # run on one machine
@@ -170,21 +124,23 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train.py \
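+            # train.py is the unified entry point: it reads wav.scp/text from
+            # --data_dir for the given train/valid sets and applies the global
+            # CMVN stats from --cmvn_file.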
+            train.py \
+                --task_name asr \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
-                --dataset_type $dataset_type \
                 --token_type char \
                 --token_list $token_list \
-                --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
-                --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
+                --dataset_type $dataset_type \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
-                --input_size $feats_dim \
                 --ngpu $gpu_num \
                 --num_worker_count $count \
-                --multiprocessing_distributed true \
                 --dist_init_method $init_method \
                 --dist_world_size $world_size \
                 --dist_rank $rank \
@@ -207,7 +163,7 @@
             exit 0
         fi
         mkdir -p "${_logdir}"
-        _data="${feats_dir}/${dumpdir}/${dset}"
+        _data="${feats_dir}/data/${dset}"
         key_file=${_data}/${scp}
         num_scp_file="$(<${key_file} wc -l)"
         _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -228,6 +184,7 @@
                 --njob ${njob} \
                 --gpuid_list ${gpuid_list} \
                 --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                 --key_file "${_logdir}"/keys.JOB.scp \
                 --asr_train_config "${asr_exp}"/config.yaml \
                 --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
diff --git a/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
index 4052774..b6e8808 100644
--- a/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
+++ b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
@@ -2,47 +2,52 @@
 # encoder related
 encoder: data2vec_encoder
 encoder_conf:
-  extractor_mode: layer_norm
-  encoder_layerdrop: 0.05
-  dropout_input: 0.0
-  dropout_features: 0.0
-  feature_grad_mult: 1.0
-  encoder_embed_dim: 768
+    extractor_mode: layer_norm
+    encoder_layerdrop: 0.05
+    dropout_input: 0.0
+    dropout_features: 0.0
+    feature_grad_mult: 1.0
+    encoder_embed_dim: 768
 
-  mask_prob: 0.65
-  mask_length: 10
+    mask_prob: 0.65
+    mask_length: 10
 
-  loss_beta: 0
-  loss_scale: null
+    loss_beta: 0
+    loss_scale: null
 
-  instance_norm_target_layer: true
-  average_top_k_layers: 8
+    instance_norm_target_layer: true
+    average_top_k_layers: 8
 
-  pos_conv_depth: 5
-  conv_pos: 95
+    pos_conv_depth: 5
+    conv_pos: 95
 
-  ema_decay: 0.999
-  ema_end_decay: 0.9999
-  ema_anneal_end_step: 30000
-  ema_transformer_only: true
-  ema_layers_only: true
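+    # data2vec distills from an EMA teacher; the decay anneals from ema_decay
+    # to ema_end_decay over ema_anneal_end_step updates.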
+    ema_decay: 0.999
+    ema_end_decay: 0.9999
+    ema_anneal_end_step: 30000
+    ema_transformer_only: true
+    ema_layers_only: true
 
-  require_same_masks: true
-  mask_dropout: 0
+    require_same_masks: true
+    mask_dropout: 0
 
-log_interval: 50
-normalize: None
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
 
-# minibatch related
-batch_type: length
-batch_bins: 64000
-num_workers: 16
+model: data2vec
 
 # optimization related
 accum_grad: 1
 grad_clip: 5
 patience: none
-max_epoch: 600
+max_epoch: 1800
 val_scheduler_criterion:
     - valid
     - acc
@@ -68,7 +73,7 @@
 dataset_conf:
     batch_mode: clipping
     data_names: speech,none
-    data_types: kaldi_ark,none
+    data_types: sound,none
     shuffle: true
     shuffle_conf:
         shuffle_size: 12800
@@ -76,4 +81,7 @@
     batch_conf:
         batch_type: token
         batch_size: 64000
-    num_workers: 8
\ No newline at end of file
+    num_workers: 8
+
+log_interval: 50
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell2/data2vec_pretrain/run.sh b/egs/aishell2/data2vec_pretrain/run.sh
index eceb183..2753f00 100755
--- a/egs/aishell2/data2vec_pretrain/run.sh
+++ b/egs/aishell2/data2vec_pretrain/run.sh
@@ -7,28 +7,25 @@
 gpu_num=8
 count=1
 
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
 
 # general configuration
 feats_dir="../DATA" #feature output dictionary
 exp_dir="."
 lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=char
+speed_perturb="0.9 1.0 1.1"
 dataset_type=large
-stage=0
-stop_stage=4
+stage=3
+stop_stage=3
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-tr_dir=
-dev_tst_dir=
+tr_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-2/iOS/data
+dev_tst_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-DEV-TEST-SET
 
 # exp tag
 tag="exp1"
@@ -45,68 +42,31 @@
 valid_set=dev_ios
 
 asr_config=conf/train_pretrain_transformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml) _${lang}_${token_type}_${tag}"
 
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     echo "stage 0: Data preparation"
     # For training set
     local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
     # # For dev and test set
-    for x in Android iOS Mic; do
+    for x in iOS; do
         local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
         local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
-    done 
+    done
     # Normalize text to lowercase
-    for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
+    for x in train dev_ios test_ios; do
         mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
         paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
              | tr 'A-Z' 'a-z' | tr -d " ") \
             > ${feats_dir}/data/${x}/text
-        tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
         mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
-        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
-    tools/fix_data_feat.sh ${fbankdir}/train
-    for x in android ios mic; do
-        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
-            ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
-        tools/fix_data_feat.sh ${fbankdir}/dev_${x}
-        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
-            ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
-        tools/fix_data_feat.sh ${fbankdir}/test_${x}
-    done
-    
-    # compute global cmvn
-    steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
-    # apply cmvn 
-    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
-    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
-    for x in android ios mic; do
-        steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-            ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
-    done
-    
-    cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
-    tools/fix_data_feat.sh ${feat_train_dir}
-    cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
-    tools/fix_data_feat.sh ${feat_dev_dir}
-    for x in android ios mic; do
-        cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
-        tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
-    done
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
 fi
 
 token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -114,22 +74,59 @@
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     echo "stage 2: Dictionary Preparation"
     mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-   
+
     echo "make a dictionary"
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
-    num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
-    vocab_size=$(cat ${token_list} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
     mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
     mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set} 
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
+fi
+
+# Training Stage
+world_size=$gpu_num  # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "stage 3: Training"
+    mkdir -p ${exp_dir}/exp/${model_dir}
+    mkdir -p ${exp_dir}/exp/${model_dir}/log
+    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+    if [ -f $INIT_FILE ];then
+        rm -f $INIT_FILE
+    fi
+    init_method=file://$(readlink -f $INIT_FILE)
+    echo "$0: init method is $init_method"
+    for ((i = 0; i < $gpu_num; ++i)); do
+        {
+            rank=$i
+            local_rank=$i
+            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+            train.py \
+                --task_name asr \
+                --gpu_id $gpu_id \
+                --use_preprocessor true \
+                --token_type char \
+                --token_list $token_list \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
+                --dataset_type $dataset_type \
+                --resume true \
+                --output_dir ${exp_dir}/exp/${model_dir} \
+                --config $asr_config \
+                --ngpu $gpu_num \
+                --num_worker_count $count \
+                --dist_init_method $init_method \
+                --dist_world_size $world_size \
+                --dist_rank $rank \
+                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+        } &
+    done
+    wait
 fi
 
 # Training Stage
@@ -149,12 +146,16 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            data2vec_train.py \
+            train.py \
+                --task_name pretrain \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
                 --dataset_type $dataset_type \
-                --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
-                --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
diff --git a/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml b/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
index 450f71a..7fc9794 100644
--- a/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
+++ b/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
@@ -29,6 +29,17 @@
     self_attention_dropout_rate: 0.0
     src_attention_dropout_rate: 0.0
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
 # hybrid CTC/attention
 model: paraformer
 model_conf:
@@ -42,7 +53,7 @@
 accum_grad: 2
 grad_clip: 5
 patience: none
-max_epoch: 50
+max_epoch: 150
 val_scheduler_criterion:
     - valid
     - acc
@@ -82,10 +93,8 @@
   l_order: 1
   r_order: 1
 
-log_interval: 50
-normalize: None
-
 dataset_conf:
+    data_types: sound,text
     shuffle: True
     shuffle_conf:
         shuffle_size: 2048
@@ -93,4 +102,7 @@
     batch_conf:
         batch_type: token
         batch_size: 25000
-    num_workers: 8
\ No newline at end of file
+    num_workers: 8
+
+log_interval: 50
+normalize: None
diff --git a/egs/aishell2/paraformer/local/prepare_data.sh b/egs/aishell2/paraformer/local/prepare_data.sh
index ce6ee19..77791f9 100755
--- a/egs/aishell2/paraformer/local/prepare_data.sh
+++ b/egs/aishell2/paraformer/local/prepare_data.sh
@@ -34,14 +34,14 @@
 # validate utt-key list, IC0803W0380 is a bad utterance
 awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
 awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
 
 # wav.scp
 awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
 
 # text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
 
 # copy prepared resources from tmp_dir to target dir
 mkdir -p $dir
diff --git a/egs/aishell2/paraformer/run.sh b/egs/aishell2/paraformer/run.sh
index e1ea4fe..5236af5 100755
--- a/egs/aishell2/paraformer/run.sh
+++ b/egs/aishell2/paraformer/run.sh
@@ -9,31 +9,28 @@
 gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
 # for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
 njob=1
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
 infer_cmd=utils/run.pl
 
 # general configuration
 feats_dir="../DATA" #feature output dictionary
 exp_dir="."
 lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
 dataset_type=large
-scp=feats.scp
-type=kaldi_ark
-stage=0
+stage=3
 stop_stage=4
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-tr_dir=
-dev_tst_dir=
+tr_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-2/iOS/data
+dev_tst_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-DEV-TEST-SET
 
 # exp tag
 tag="exp1"
@@ -51,7 +48,7 @@
 test_sets="dev_ios test_ios"
 
 asr_config=conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
 
 inference_config=conf/decode_asr_transformer_noctc_1best.yaml
 inference_asr_model=valid.acc.ave_10best.pb
@@ -73,61 +70,24 @@
     # For training set
     local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
     # # For dev and test set
-    for x in Android iOS Mic; do
+    for x in iOS; do
         local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
         local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
-    done 
+    done
     # Normalize text to lowercase
-    for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
+    for x in train dev_ios test_ios; do
         mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
         paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
              | tr 'A-Z' 'a-z' | tr -d " ") \
             > ${feats_dir}/data/${x}/text
-        tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
         mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
-        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
-    tools/fix_data_feat.sh ${fbankdir}/train
-    for x in android ios mic; do
-        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
-            ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
-        tools/fix_data_feat.sh ${fbankdir}/dev_${x}
-        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
-            ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
-        tools/fix_data_feat.sh ${fbankdir}/test_${x}
-    done
-    
-    # compute global cmvn
-    steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
-    # apply cmvn 
-    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
-    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
-    for x in android ios mic; do
-        steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-            ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
-    done
-    
-    cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
-    tools/fix_data_feat.sh ${feat_train_dir}
-    cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
-    tools/fix_data_feat.sh ${feat_dev_dir}
-    for x in android ios mic; do
-        cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
-        tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
-    done
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
 fi
 
 token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -135,23 +95,15 @@
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     echo "stage 2: Dictionary Preparation"
     mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-   
+
     echo "make a dictionary"
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
-    num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
-    vocab_size=$(cat ${token_list} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set} 
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-fi
+fi
 
 # Training Stage
 world_size=$gpu_num  # run on one machine
@@ -170,28 +122,30 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train_paraformer.py \
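+            # Paraformer training also goes through the unified train.py; the
+            # model type is selected by the "model: paraformer" field in the
+            # training config.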
+            train.py \
+                --task_name asr \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
-                --dataset_type $dataset_type \
                 --token_type char \
                 --token_list $token_list \
-                --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
-                --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
+                --dataset_type $dataset_type \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
-                --input_size $feats_dim \
                 --ngpu $gpu_num \
                 --num_worker_count $count \
-                --multiprocessing_distributed true \
                 --dist_init_method $init_method \
                 --dist_world_size $world_size \
                 --dist_rank $rank \
                 --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
         } &
-      done
-      wait
+    done
+    wait
 fi
 
 # Testing Stage
@@ -207,7 +161,7 @@
             exit 0
         fi
         mkdir -p "${_logdir}"
-        _data="${feats_dir}/${dumpdir}/${dset}"
+        _data="${feats_dir}/data/${dset}"
         key_file=${_data}/${scp}
         num_scp_file="$(<${key_file} wc -l)"
         _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -228,6 +182,7 @@
                 --njob ${njob} \
                 --gpuid_list ${gpuid_list} \
                 --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                 --key_file "${_logdir}"/keys.JOB.scp \
                 --asr_train_config "${asr_exp}"/config.yaml \
                 --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
diff --git a/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml b/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
index 19f123e..8968d2d 100644
--- a/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
+++ b/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
@@ -29,6 +29,17 @@
     self_attention_dropout_rate: 0.0
     src_attention_dropout_rate: 0.0
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
 # hybrid CTC/attention
 model: paraformer_bert
 model_conf:
@@ -36,7 +47,7 @@
     lsm_weight: 0.1     # label smoothing option
     length_normalized_loss: false
     predictor_weight: 1.0
-    glat_context_p: 0.4
+    sampling_ratio: 0.4
     embeds_id: 3
     embed_dims: 768
     embeds_loss_weight: 2.0
@@ -45,7 +56,7 @@
 accum_grad: 2
 grad_clip: 5
 patience: none
-max_epoch: 50
+max_epoch: 150
 val_scheduler_criterion:
     - valid
     - acc
@@ -78,24 +89,24 @@
     - 40
     num_time_mask: 2
 
-predictor: cif_predictor_sanm
+predictor: cif_predictor
 predictor_conf:
   idim: 320
   threshold: 1.0
   l_order: 1
   r_order: 1
 
-log_interval: 50
-normalize: None
-
 dataset_conf:
     data_names: speech,text,embed
-    data_types: kaldi_ark,text,kaldi_ark
+    data_types: sound,text,kaldi_ark
     shuffle: True
     shuffle_conf:
-        shuffle_size: 10240
+        shuffle_size: 2048
         sort_size: 500
     batch_conf:
         batch_type: token
         batch_size: 25000
-    num_workers: 8
\ No newline at end of file
+    num_workers: 8
+
+log_interval: 50
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell2/paraformerbert/local/extract_embeds.sh b/egs/aishell2/paraformerbert/local/extract_embeds.sh
index 5f45ff3..ee89924 100755
--- a/egs/aishell2/paraformerbert/local/extract_embeds.sh
+++ b/egs/aishell2/paraformerbert/local/extract_embeds.sh
@@ -3,20 +3,16 @@
 stage=1
 stop_stage=3
 
-bert_model_root="../../huggingface_models"
 bert_model_name="bert-base-chinese"
-#bert_model_name="chinese-roberta-wwm-ext"
-#bert_model_name="mengzi-bert-base"
 raw_dataset_path="../DATA"
-model_path=${bert_model_root}/${bert_model_name}
+nj=64
+model_path=${bert_model_name}
 
 . utils/parse_options.sh || exit 1;
 
-nj=100
-
-for data_set in train dev_ios test_ios;do
-    scp=$raw_dataset_path/dump/fbank/${data_set}/text
-    local_scp_dir_raw=$raw_dataset_path/embeds/$bert_model_name/${data_set}
+for data_set in train dev_ios;do
+    scp=$raw_dataset_path/data/${data_set}/text
+    local_scp_dir_raw=${raw_dataset_path}/data/embeds/${data_set}
     local_scp_dir=$local_scp_dir_raw/split$nj
     local_records_dir=$local_scp_dir_raw/ark
 
@@ -31,7 +27,7 @@
     utils/split_scp.pl $scp ${split_scps}
 
 
-    for num in {0..24};do
+    for num in {0..15};do
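+        # Four extraction jobs run concurrently per iteration, so this loop
+        # must cover nj/4 batches (JOB = num*4 + idx, idx in 1..4).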
         tmp=`expr $num \* 4`
 
         if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
@@ -41,20 +37,9 @@
                 {
                     beg=0
                     gpu=`expr $beg + $idx`
-                    echo $local_scp_dir_raw/log/log.${JOB}
-                    python tools/extract_embeds.py $local_scp_dir/text.$JOB.txt ${local_records_dir}/embeds.${JOB}.ark ${local_records_dir}/embeds.${JOB}.scp ${local_records_dir}/embeds.${JOB}.shape ${gpu} ${model_path} &> $local_scp_dir_raw/log/log.${JOB}
+                    echo ${local_scp_dir}/log.${JOB}
+                    python utils/extract_embeds.py $local_scp_dir/data.$JOB.text ${local_records_dir}/embeds.${JOB}.ark ${local_records_dir}/embeds.${JOB}.scp ${local_records_dir}/embeds.${JOB}.shape ${gpu} ${model_path} &> ${local_scp_dir}/log.${JOB}
             } &
-            done
-            wait
-        fi
-
-        if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
-            for idx in {1..4}; do
-                JOB=`expr $tmp + $idx`
-                echo "upload jobid=$JOB"
-                {
-                    hadoop  fs -put -f ${local_records_dir}/embeds.${JOB}.ark ${odps_des_feature_dir}/embeds.${JOB}.ark
-                } &
             done
             wait
         fi
diff --git a/egs/aishell2/paraformerbert/local/prepare_data.sh b/egs/aishell2/paraformerbert/local/prepare_data.sh
index 801dbe5..77791f9 100755
--- a/egs/aishell2/paraformerbert/local/prepare_data.sh
+++ b/egs/aishell2/paraformerbert/local/prepare_data.sh
@@ -17,7 +17,6 @@
 fi
 
 corpus=$1
-#dict_dir=$2
 tmp=$2
 dir=$3
 
@@ -35,14 +34,14 @@
 # validate utt-key list, IC0803W0380 is a bad utterance
 awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
 awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
 
 # wav.scp
 awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
 
 # text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
 
 # copy prepared resources from tmp_dir to target dir
 mkdir -p $dir
diff --git a/egs/aishell2/paraformerbert/run.sh b/egs/aishell2/paraformerbert/run.sh
index 239a7e3..26c8063 100755
--- a/egs/aishell2/paraformerbert/run.sh
+++ b/egs/aishell2/paraformerbert/run.sh
@@ -8,36 +8,32 @@
 count=1
 gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
 # for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
-train_cmd=tools/run.pl
+njob=1
+train_cmd=utils/run.pl
 infer_cmd=utils/run.pl
 
 # general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" #feature output dictionary
 exp_dir="."
 lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
 dataset_type=large
-scp=feats.scp
-type=kaldi_ark
-stage=0
-stop_stage=5
+stage=3
+stop_stage=4
 
 skip_extract_embed=false
-bert_model_root="../../huggingface_models"
 bert_model_name="bert-base-chinese"
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-tr_dir=
-dev_tst_dir=
+tr_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-2/iOS/data
+dev_tst_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-DEV-TEST-SET
 
 # exp tag
 tag="exp1"
@@ -55,7 +51,7 @@
 test_sets="dev_ios test_ios"
 
 asr_config=conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
 
 inference_config=conf/decode_asr_transformer_noctc_1best.yaml
 inference_asr_model=valid.acc.ave_10best.pb
@@ -75,86 +71,42 @@
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
     echo "stage 0: Data preparation"
     # For training set
-    local/prepare_data.sh ${tr_dir} data/local/train data/train || exit 1;
+    local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
     # # For dev and test set
-    for x in Android iOS Mic; do
-        local/prepare_data.sh ${dev_tst_dir}/${x}/dev data/local/dev_${x,,} data/dev_${x,,} || exit 1;
-        local/prepare_data.sh ${dev_tst_dir}/${x}/test data/local/test_${x,,} data/test_${x,,} || exit 1;
-    done 
+    for x in iOS; do
+        local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
+        local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
+    done
     # Normalize text to lowercase
-    for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
-        mv data/${x}/text data/${x}/text.org
-        paste <(cut -f 1 data/${x}/text.org) <(cut -f 2 data/${x}/text.org | tr '[:lower:]' '[:upper:]') \
-            > data/${x}/text
-        tools/text2token.py -n 1 -s 1 data/${x}/text > data/${x}/text.org
-        mv data/${x}/text.org data/${x}/text
+    for x in train dev_ios test_ios; do
+        mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+        paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
+             | tr 'A-Z' 'a-z' | tr -d " ") \
+            > ${feats_dir}/data/${x}/text
+        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
-        data/train exp/make_fbank/train ${fbankdir}/train
-    tools/fix_data_feat.sh ${fbankdir}/train
-    for x in android ios mic; do
-        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
-            data/dev_${x} exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
-        tools/fix_data_feat.sh ${fbankdir}/dev_${x}
-        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
-            data/test_${x} exp/make_fbank/test_${x} ${fbankdir}/test_${x}
-        tools/fix_data_feat.sh ${fbankdir}/test_${x}
-    done
-    
-    # compute global cmvn
-    steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/train exp/make_fbank/train
-
-    # apply cmvn 
-    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json exp/make_fbank/${train_set} ${feat_train_dir}
-    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json exp/make_fbank/${valid_set} ${feat_dev_dir}
-    for x in android ios mic; do
-        steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-            ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
-    done
-    
-    cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
-    tools/fix_data_feat.sh ${feat_train_dir}
-    cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
-    tools/fix_data_feat.sh ${feat_dev_dir}
-    for x in android ios mic; do
-        cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
-        tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
-    done
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
 fi
 
 token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
 echo "dictionary: ${token_list}"
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     echo "stage 2: Dictionary Preparation"
-    mkdir -p data/${lang}_token_list/char/
-   
+    mkdir -p ${feats_dir}/data/${lang}_token_list/char/
+
     echo "make a dictionary"
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    tools/text2token.py -s 1 -n 1 --space "" data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
-    num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
-    vocab_size=$(cat ${token_list} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
-    mkdir -p asr_stats_fbank_zh_char/${train_set} 
-    mkdir -p asr_stats_fbank_zh_char/${valid_set}
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char asr_stats_fbank_zh_char/${train_set} 
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char asr_stats_fbank_zh_char/${valid_set}
-fi
+fi
 
 # Training Stage
 world_size=$gpu_num  # run on one machine
@@ -163,9 +115,9 @@
     if ! "${skip_extract_embed}"; then
         echo "extract embeddings..."
         local/extract_embeds.sh \
-            --bert_model_root ${bert_model_root} \
             --bert_model_name ${bert_model_name} \
-            --raw_dataset_path ${feats_dir}
+            --raw_dataset_path ${feats_dir} \
+            --nj $nj
     fi
     mkdir -p ${exp_dir}/exp/${model_dir}
     mkdir -p ${exp_dir}/exp/${model_dir}/log
@@ -180,22 +132,24 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train_paraformer.py \
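+            # local/extract_embeds.sh writes the BERT embeddings under
+            # ${feats_dir}/data/embeds/<set>; --embed_path points at that
+            # data root.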
+            train.py \
+                --task_name asr \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
-                --dataset_type $dataset_type \
-                --token_type $token_type \
+                --token_type char \
                 --token_list $token_list \
-                --train_data_file $feats_dir/$dumpdir/${train_set}/data_bert.list \
-                --valid_data_file $feats_dir/$dumpdir/${valid_set}/data_bert.list \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --embed_path ${feats_dir}/data \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
+                --dataset_type $dataset_type \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
-                --allow_variable_data_keys true \
-                --input_size $feats_dim \
                 --ngpu $gpu_num \
                 --num_worker_count $count \
-                --multiprocessing_distributed true \
                 --dist_init_method $init_method \
                 --dist_world_size $world_size \
                 --dist_rank $rank \
@@ -218,7 +172,7 @@
             exit 0
         fi
         mkdir -p "${_logdir}"
-        _data="${feats_dir}/${dumpdir}/${dset}"
+        _data="${feats_dir}/data/${dset}"
         key_file=${_data}/${scp}
         num_scp_file="$(<${key_file} wc -l)"
         _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -239,6 +193,7 @@
                 --njob ${njob} \
                 --gpuid_list ${gpuid_list} \
                 --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                 --key_file "${_logdir}"/keys.JOB.scp \
                 --asr_train_config "${asr_exp}"/config.yaml \
                 --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -259,5 +214,4 @@
         tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
         cat ${_dir}/text.cer.txt
     done
-fi
-
+fi
\ No newline at end of file
diff --git a/egs/aishell2/transformer/conf/train_asr_transformer.yaml b/egs/aishell2/transformer/conf/train_asr_transformer.yaml
index 3e2172d..2d16c18 100644
--- a/egs/aishell2/transformer/conf/train_asr_transformer.yaml
+++ b/egs/aishell2/transformer/conf/train_asr_transformer.yaml
@@ -23,6 +23,17 @@
     self_attention_dropout_rate: 0.0
     src_attention_dropout_rate: 0.0
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
 # hybrid CTC/attention
 model_conf:
     ctc_weight: 0.3
@@ -33,7 +44,7 @@
 accum_grad: 2
 grad_clip: 5
 patience: none
-max_epoch: 50
+max_epoch: 150
 val_scheduler_criterion:
     - valid
     - acc
@@ -66,10 +77,8 @@
     - 40
     num_time_mask: 2
 
-log_interval: 50
-normalize: None
-
 dataset_conf:
+    data_types: sound,text
     shuffle: True
     shuffle_conf:
         shuffle_size: 2048
@@ -77,4 +86,7 @@
     batch_conf:
         batch_type: token
         batch_size: 25000
-    num_workers: 8
\ No newline at end of file
+    num_workers: 8
+
+log_interval: 50
+normalize: None
\ No newline at end of file
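Note on the wav_frontend block added above: it computes the 80-dim fbank features on the fly and exposes low-frame-rate (LFR) controls; with lfr_m = lfr_n = 1, as configured here, no frame stacking or subsampling takes place. The following is a minimal sketch of how LFR stacking is commonly implemented (assumed behavior; the exact padding in FunASR's frontend may differ):

    import numpy as np

    def apply_lfr(feats, lfr_m=1, lfr_n=1):
        # Stack lfr_m neighboring frames and advance lfr_n frames per step;
        # lfr_m = lfr_n = 1 (this config) returns the features unchanged.
        T, _ = feats.shape
        left = (lfr_m - 1) // 2
        padded = np.vstack([np.tile(feats[:1], (left, 1)), feats])
        out = []
        for i in range(0, T, lfr_n):
            window = padded[i:i + lfr_m]
            if window.shape[0] < lfr_m:  # pad the tail with the last frame
                pad = np.tile(padded[-1:], (lfr_m - window.shape[0], 1))
                window = np.vstack([window, pad])
            out.append(window.reshape(-1))
        return np.stack(out)  # shape: (ceil(T / lfr_n), lfr_m * n_mels)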
diff --git a/egs/aishell2/transformer/local/prepare_data.sh b/egs/aishell2/transformer/local/prepare_data.sh
index ce6ee19..77791f9 100755
--- a/egs/aishell2/transformer/local/prepare_data.sh
+++ b/egs/aishell2/transformer/local/prepare_data.sh
@@ -34,14 +34,14 @@
 # validate utt-key list, IC0803W0380 is a bad utterance
 awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
 awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
 
 # wav.scp
 awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
 
 # text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
 
 # copy prepared resources from tmp_dir to target dir
 mkdir -p $dir
diff --git a/egs/aishell2/transformer/run.sh b/egs/aishell2/transformer/run.sh
index 6f2dd4d..d8b16aa 100755
--- a/egs/aishell2/transformer/run.sh
+++ b/egs/aishell2/transformer/run.sh
@@ -9,31 +9,28 @@
 gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
 # for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
 njob=5
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
 infer_cmd=utils/run.pl
 
 # general configuration
 feats_dir="../DATA" #feature output dictionary
 exp_dir="."
 lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
 dataset_type=large
-scp=feats.scp
-type=kaldi_ark
-stage=0
+stage=3
 stop_stage=4
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-tr_dir=
-dev_tst_dir=
+tr_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-2/iOS/data
+dev_tst_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-DEV-TEST-SET
 
 # exp tag
 tag="exp1"
@@ -51,13 +48,13 @@
 test_sets="dev_ios test_ios"
 
 asr_config=conf/train_asr_transformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
 
 inference_config=conf/decode_asr_transformer.yaml
 inference_asr_model=valid.acc.ave_10best.pb
 
 # you can set gpu num for decoding here
-gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, e.g., gpuid_list=2,3, the same as training stage by default
+gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
 ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
 
 if ${gpu_inference}; then
@@ -73,61 +70,24 @@
     # For training set
     local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
     # # For dev and test set
-    for x in Android iOS Mic; do
+    for x in iOS; do
         local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
         local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
-    done 
+    done
     # Normalize text to lowercase and remove spaces
     for x in train dev_ios test_ios; do
         mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
         paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
              | tr 'A-Z' 'a-z' | tr -d " ") \
             > ${feats_dir}/data/${x}/text
-        tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
         mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
-        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
-    tools/fix_data_feat.sh ${fbankdir}/train
-    for x in android ios mic; do
-        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
-            ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
-        tools/fix_data_feat.sh ${fbankdir}/dev_${x}
-        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
-            ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
-        tools/fix_data_feat.sh ${fbankdir}/test_${x}
-    done
-    
-    # compute global cmvn
-    steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
-    # apply cmvn 
-    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
-    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
-    for x in android ios mic; do
-        steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-            ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
-    done
-    
-    cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
-    tools/fix_data_feat.sh ${feat_train_dir}
-    cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
-    tools/fix_data_feat.sh ${feat_dev_dir}
-    for x in android ios mic; do
-        cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
-        tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
-    done
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
 fi
 
 token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -135,23 +95,15 @@
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     echo "stage 2: Dictionary Preparation"
     mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-   
+
     echo "make a dictionary"
     echo "<blank>" > ${token_list}
     echo "<s>" >> ${token_list}
     echo "</s>" >> ${token_list}
-    tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
         | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
-    num_token=$(cat ${token_list} | wc -l)
     echo "<unk>" >> ${token_list}
-    vocab_size=$(cat ${token_list} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set} 
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-fi
+fi
 
 # Training Stage
 world_size=$gpu_num  # run on one machine
@@ -170,21 +122,23 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train.py \
+            train.py \
+                --task_name asr \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
-                --dataset_type $dataset_type \
                 --token_type char \
                 --token_list $token_list \
-                --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
-                --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
+                --dataset_type $dataset_type \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
-                --input_size $feats_dim \
                 --ngpu $gpu_num \
                 --num_worker_count $count \
-                --multiprocessing_distributed true \
                 --dist_init_method $init_method \
                 --dist_world_size $world_size \
                 --dist_rank $rank \
@@ -207,7 +161,7 @@
             exit 0
         fi
         mkdir -p "${_logdir}"
-        _data="${feats_dir}/${dumpdir}/${dset}"
+        _data="${feats_dir}/data/${dset}"
         key_file=${_data}/${scp}
         num_scp_file="$(<${key_file} wc -l)"
         _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -228,6 +182,7 @@
                 --njob ${njob} \
                 --gpuid_list ${gpuid_list} \
                 --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                 --key_file "${_logdir}"/keys.JOB.scp \
                 --asr_train_config "${asr_exp}"/config.yaml \
                 --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
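For reference, the stage-2 dictionary build in this run.sh (text2token | cut | tr | sort | uniq plus the special symbols) amounts to roughly the following Python; this is a simplified sketch and skips whatever extra normalization utils/text2token.py applies:

    from pathlib import Path

    def build_char_token_list(text_file, token_file):
        # Collect every non-space character in the training transcripts.
        chars = set()
        for line in Path(text_file).read_text(encoding="utf-8").splitlines():
            parts = line.split(maxsplit=1)
            if len(parts) < 2:
                continue  # utterance with an empty transcript
            chars.update(c for c in parts[1] if not c.isspace())
        # Special tokens first, then the sorted character inventory, then <unk>.
        tokens = ["<blank>", "<s>", "</s>"] + sorted(chars) + ["<unk>"]
        Path(token_file).write_text("\n".join(tokens) + "\n", encoding="utf-8")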
diff --git a/egs/aishell2/transformer/utils/cmvn_converter.py b/egs/aishell2/transformer/utils/cmvn_converter.py
new file mode 100644
index 0000000..d405d12
--- /dev/null
+++ b/egs/aishell2/transformer/utils/cmvn_converter.py
@@ -0,0 +1,51 @@
+import argparse
+import json
+import numpy as np
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="cmvn converter",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--cmvn_json",
+        default=None,
+        required=True,
+        type=str,
+        help="cmvn json file",
+    )
+    parser.add_argument(
+        "--am_mvn",
+        default=None,
+        required=True,
+        type=str,
+        help="am mvn file",
+    )
+    return parser
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    with open(args.cmvn_json, "r") as fin:
+        cmvn_dict = json.load(fin)
+
+    mean_stats = np.array(cmvn_dict["mean_stats"])
+    var_stats = np.array(cmvn_dict["var_stats"])
+    total_frame = np.array(cmvn_dict["total_frames"])
+
+    # Kaldi applies y = (x + shift) * scale, so "mean" here holds the negative
+    # mean (the <AddShift> row) and "var" holds 1/stddev (the <Rescale> row).
+    mean = -1.0 * mean_stats / total_frame
+    var = 1.0 / np.sqrt(var_stats / total_frame - mean * mean)
+    dims = mean.shape[0]
+    def kaldi_vector(vec):
+        # "[ v1 v2 ... ]" without commas, as the Kaldi nnet text format expects
+        return str(list(vec)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
+
+    with open(args.am_mvn, 'w') as fout:
+        fout.write("<Nnet>\n")
+        fout.write("<Splice> {0} {0}\n[ 0 ]\n".format(dims))
+        fout.write("<AddShift> {0} {0}\n".format(dims))
+        fout.write("<LearnRateCoef> 0 " + kaldi_vector(mean) + "\n")
+        fout.write("<Rescale> {0} {0}\n".format(dims))
+        fout.write("<LearnRateCoef> 0 " + kaldi_vector(var) + "\n")
+        fout.write("</Nnet>\n")
+
+if __name__ == '__main__':
+    main()
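A small worked example of what cmvn_converter.py encodes: Kaldi's <AddShift>/<Rescale> pair applies y = (x + shift) * scale, so storing shift = -mean and scale = 1/stddev yields standard per-dimension normalization. The numbers below are made up for illustration:

    import numpy as np

    mean_stats = np.array([10.0, 20.0])   # per-dim sums of x over all frames
    var_stats = np.array([60.0, 250.0])   # per-dim sums of x^2
    total_frames = 2.0

    shift = -mean_stats / total_frames                           # <AddShift> row
    mean = mean_stats / total_frames
    scale = 1.0 / np.sqrt(var_stats / total_frames - mean ** 2)  # <Rescale> row

    x = np.array([3.0, 5.0])
    y = (x + shift) * scale
    print(y)  # == (x - mean) / std, i.e. [-0.894..., -1.0]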
diff --git a/egs/aishell2/transformer/utils/combine_cmvn_file.py b/egs/aishell2/transformer/utils/combine_cmvn_file.py
index b2974a4..c525973 100755
--- a/egs/aishell2/transformer/utils/combine_cmvn_file.py
+++ b/egs/aishell2/transformer/utils/combine_cmvn_file.py
@@ -1,6 +1,9 @@
 import argparse
 import json
+import os
+
 import numpy as np
+
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -8,15 +11,13 @@
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
-        "--dims",
-        "-d",
+        "--dim",
         default=80,
         type=int,
-        help="feature dims",
+        help="feature dim",
     )
     parser.add_argument(
-        "--cmvn-dir",
-        "-c",
+        "--cmvn_dir",
         default=False,
         required=True,
         type=str,
@@ -25,15 +26,13 @@
 
     parser.add_argument(
         "--nj",
-        "-n",
         default=1,
         required=True,
         type=int,
-        help="num of cmvn file",
+        help="num of cmvn files",
     )
     parser.add_argument(
-        "--output-dir",
-        "-o",
+        "--output_dir",
         default=False,
         required=True,
         type=str,
@@ -46,14 +45,14 @@
     parser = get_parser()
     args = parser.parse_args()
 
-    total_means = np.zeros(args.dims)
-    total_vars = np.zeros(args.dims)
+    total_means = np.zeros(args.dim)
+    total_vars = np.zeros(args.dim)
     total_frames = 0
 
-    cmvn_file = args.output_dir + "/cmvn.json"
+    cmvn_file = os.path.join(args.output_dir, "cmvn.json")
 
-    for i in range(1, args.nj+1):
-        with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
+    for i in range(1, args.nj + 1):
+        with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(str(i)))) as fin:
             cmvn_stats = json.load(fin)
 
         total_means += np.array(cmvn_stats["mean_stats"])
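The reason this per-file accumulation works: sums of x and x^2 are additive across shards, so the global mean and variance fall out of the nj partial JSON files without touching the audio again. A toy check with made-up numbers:

    import numpy as np

    shards = [
        {"mean_stats": [4.0], "var_stats": [10.0], "total_frames": 2},  # frames 1, 3
        {"mean_stats": [5.0], "var_stats": [25.0], "total_frames": 1},  # frame 5
    ]
    s1 = sum(np.array(s["mean_stats"]) for s in shards)  # sum of x
    s2 = sum(np.array(s["var_stats"]) for s in shards)   # sum of x^2
    n = sum(s["total_frames"] for s in shards)
    mean = s1 / n
    var = s2 / n - mean ** 2
    print(mean, var)  # [3.] [2.666...], the stats of frames {1, 3, 5}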
diff --git a/egs/aishell2/transformer/utils/compute_cmvn.py b/egs/aishell2/transformer/utils/compute_cmvn.py
index 2b96e26..949cc08 100755
--- a/egs/aishell2/transformer/utils/compute_cmvn.py
+++ b/egs/aishell2/transformer/utils/compute_cmvn.py
@@ -1,8 +1,10 @@
-from kaldiio import ReadHelper
-
 import argparse
-import numpy as np
 import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
 
 
 def get_parser():
@@ -11,55 +13,83 @@
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument(
-        "--dims",
-        "-d",
+        "--dim",
         default=80,
         type=int,
-        help="feature dims",
+        help="feature dimension",
     )
     parser.add_argument(
-        "--ark-file",
-        "-a",
+        "--wav_path",
         default=False,
         required=True,
         type=str,
-        help="fbank ark file",
+        help="the path of wav scps",
     )
     parser.add_argument(
-        "--ark-index",
-        "-i",
+        "--idx",
         default=1,
         required=True,
         type=int,
-        help="ark index",
-    )
-    parser.add_argument(
-        "--output-dir",
-        "-o",
-        default=False,
-        required=True,
-        type=str,
-        help="output dir",
+        help="index",
     )
     return parser
+
+
+def compute_fbank(wav_file,
+                  num_mel_bins=80,
+                  frame_length=25,
+                  frame_shift=10,
+                  dither=0.0,
+                  resample_rate=16000,
+                  speed=1.0,
+                  window_type="hamming"):
+    waveform, sample_rate = torchaudio.load(wav_file)
+    if resample_rate != sample_rate:
+        waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+                                                  new_freq=resample_rate)(waveform)
+    if speed != 1.0:
+        waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+            waveform, resample_rate,
+            [['speed', str(speed)], ['rate', str(resample_rate)]]
+        )
+
+    waveform = waveform * (1 << 15)
+    mat = kaldi.fbank(waveform,
+                      num_mel_bins=num_mel_bins,
+                      frame_length=frame_length,
+                      frame_shift=frame_shift,
+                      dither=dither,
+                      energy_floor=0.0,
+                      window_type=window_type,
+                      sample_frequency=resample_rate)
+
+    return mat.numpy()
 
 
 def main():
     parser = get_parser()
     args = parser.parse_args()
 
-    ark_file = args.ark_file + "/feats." + str(args.ark_index) + ".ark"
-    cmvn_file = args.output_dir + "/cmvn." + str(args.ark_index) + ".json"
+    wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+    cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
 
-    mean_stats = np.zeros(args.dims)
-    var_stats = np.zeros(args.dims)
+    mean_stats = np.zeros(args.dim)
+    var_stats = np.zeros(args.dim)
     total_frames = 0
 
-    with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
-        for key, mat in ark_reader:
-            mean_stats += np.sum(mat, axis=0)
-            var_stats += np.sum(np.square(mat), axis=0)
-            total_frames += mat.shape[0]
+    with open(wav_scp_file) as f:
+        lines = f.readlines()
+        for line in lines:
+            _, wav_file = line.strip().split(maxsplit=1)
+            fbank = compute_fbank(wav_file, num_mel_bins=args.dim)
+            mean_stats += np.sum(fbank, axis=0)
+            var_stats += np.sum(np.square(fbank), axis=0)
+            total_frames += fbank.shape[0]
 
     cmvn_info = {
         'mean_stats': list(mean_stats.tolist()),
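One detail worth flagging in the new fbank path: torchaudio.load returns float waveforms in [-1, 1], while Kaldi's fbank assumes 16-bit PCM sample values, hence the waveform * (1 << 15) rescaling above. A minimal standalone check (random audio stands in for a real wav):

    import torch
    import torchaudio.compliance.kaldi as kaldi

    waveform = torch.rand(1, 16000) * 2 - 1  # stand-in for torchaudio.load(...)
    mat = kaldi.fbank(waveform * (1 << 15),
                      num_mel_bins=80, frame_length=25, frame_shift=10,
                      dither=0.0, energy_floor=0.0,
                      window_type="hamming", sample_frequency=16000)
    print(mat.shape)  # (num_frames, 80); ~98 frames for 1 s at a 10 ms shift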
diff --git a/egs/aishell2/transformer/utils/compute_cmvn.sh b/egs/aishell2/transformer/utils/compute_cmvn.sh
index 12173ee..7663df9 100755
--- a/egs/aishell2/transformer/utils/compute_cmvn.sh
+++ b/egs/aishell2/transformer/utils/compute_cmvn.sh
@@ -11,15 +11,24 @@
 . utils/parse_options.sh || exit 1;
 
 fbankdir=$1
-logdir=$2
 
-output_dir=${fbankdir}/cmvn; mkdir -p ${output_dir}
-mkdir -p ${logdir}
+split_dir=${fbankdir}/cmvn/split_${nj};
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+    split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp $split_scps || exit 1;
 
+logdir=${fbankdir}/cmvn/log
 $cmd JOB=1:$nj $logdir/cmvn.JOB.log \
-    python utils/compute_cmvn.py -d ${feats_dim} -a $fbankdir/ark -i JOB -o ${output_dir} \
-        || exit 1;
+    python utils/compute_cmvn.py \
+      --dim ${feats_dim} \
+      --wav_path $split_dir \
+      --idx JOB || exit 1;
 
-python utils/combine_cmvn_file.py -d ${feats_dim} -c ${output_dir} -n $nj -o $fbankdir
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/cmvn.mvn
 
 echo "$0: Succeeded compute global cmvn"
diff --git a/egs/aishell2/transformer/utils/compute_fbank.py b/egs/aishell2/transformer/utils/compute_fbank.py
index d03b5a8..9c3904f 100755
--- a/egs/aishell2/transformer/utils/compute_fbank.py
+++ b/egs/aishell2/transformer/utils/compute_fbank.py
@@ -14,7 +14,8 @@
                   frame_shift=10,
                   dither=0.0,
                   resample_rate=16000,
-                  speed=1.0):
+                  speed=1.0,
+                  window_type="hamming"):
 
     waveform, sample_rate = torchaudio.load(wav_file)
     if resample_rate != sample_rate:
@@ -33,7 +34,7 @@
                       frame_shift=frame_shift,
                       dither=dither,
                       energy_floor=0.0,
-                      window_type='hamming',
+                      window_type=window_type,
                       sample_frequency=resample_rate)
 
     return mat.numpy()
@@ -68,6 +69,13 @@
         help="feature dims",
     )
     parser.add_argument(
+        "--max-lengths",
+        "-m",
+        default=1500,
+        type=int,
+        help="max frame numbers",
+    )
+    parser.add_argument(
         "--sample-frequency",
         "-s",
         default=16000,
@@ -96,6 +104,13 @@
         required=True,
         type=str,
         help="output dir",
+    )
+    parser.add_argument(
+        "--window-type",
+        default="hamming",
+        required=False,
+        type=str,
+        help="window type"
     )
     return parser
 
@@ -131,10 +146,13 @@
                     fbank = compute_fbank(wav_file,
                                           num_mel_bins=args.dims,
                                           resample_rate=args.sample_frequency,
-                                          speed=float(speed)
+                                          speed=float(speed),
+                                          window_type=args.window_type
                                           )
                     feats_dims = fbank.shape[1]
                     feats_lens = fbank.shape[0]
+                    if feats_lens >= args.max_lengths:
+                        continue
                     txt_lens = len(txt)
                     if speed == "1.0":
                         wav_id_sp = wav_id
diff --git a/egs/aishell2/transformer/utils/compute_fbank.sh b/egs/aishell2/transformer/utils/compute_fbank.sh
index 92a4fe6..8704b31 100755
--- a/egs/aishell2/transformer/utils/compute_fbank.sh
+++ b/egs/aishell2/transformer/utils/compute_fbank.sh
@@ -9,6 +9,8 @@
 feats_dim=80
 sample_frequency=16000
 speed_perturb="1.0"
+window_type="hamming"
+max_lengths=1500
 
 echo "$0 $@"
 
@@ -29,7 +31,8 @@
 
 $cmd JOB=1:$nj $logdir/make_fbank.JOB.log \
     python utils/compute_fbank.py -w $data/split${nj}/JOB/wav.scp -t $data/split${nj}/JOB/text \
-        -d $feats_dim -s $sample_frequency -p ${speed_perturb} -a JOB -o ${fbankdir} \
+        -d $feats_dim -s $sample_frequency -m ${max_lengths} -p ${speed_perturb} -a JOB -o ${fbankdir} \
+        --window-type ${window_type} \
         || exit 1;
 
 for n in $(seq $nj); do
diff --git a/egs/aishell2/transformer/utils/compute_wer.py b/egs/aishell2/transformer/utils/compute_wer.py
index 349a3f6..26a9f49 100755
--- a/egs/aishell2/transformer/utils/compute_wer.py
+++ b/egs/aishell2/transformer/utils/compute_wer.py
@@ -45,8 +45,8 @@
            if out_item['wrong'] > 0:
                rst['wrong_sentences'] += 1
            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
-           cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
-           cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
+           cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
+           cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
 
     if rst['Wrd'] > 0:
         rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
diff --git a/egs/aishell2/transformer/utils/download_model.py b/egs/aishell2/transformer/utils/download_model.py
new file mode 100755
index 0000000..70ea179
--- /dev/null
+++ b/egs/aishell2/transformer/utils/download_model.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+import argparse
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description="download model configs",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument("--model_name",
+                        type=str,
+                        default="damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch",
+                        help="model name in ModelScope")
+    args = parser.parse_args()
+
+    inference_pipeline = pipeline(
+        task=Tasks.auto_speech_recognition,
+        model=args.model_name)
diff --git a/egs/aishell2/transformer/utils/fix_data.sh b/egs/aishell2/transformer/utils/fix_data.sh
index 32cdde5..b1a2bb8 100755
--- a/egs/aishell2/transformer/utils/fix_data.sh
+++ b/egs/aishell2/transformer/utils/fix_data.sh
@@ -28,8 +28,8 @@
 mv ${data_dir}/wav.scp ${data_dir}/wav.scp.bak
 mv ${data_dir}/text ${data_dir}/text.bak
 
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak > ${data_dir}/wav.scp
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak | sort -k1,1 -u > ${data_dir}/wav.scp
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
 
 rm ${data_dir}/wav.scp.bak
 rm ${data_dir}/text.bak
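The sort -k1,1 -u added above both sorts by utterance id and drops duplicate ids. In Python terms, this is roughly the following (keeping one line per id, as sort -u does):

    def dedup_by_key(lines):
        kept = {}
        for line in lines:
            key = line.split(maxsplit=1)[0]
            kept.setdefault(key, line)  # keep the first line seen per id
        return [kept[k] for k in sorted(kept)]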
diff --git a/egs/aishell2/transformer/utils/fix_data_feat.sh b/egs/aishell2/transformer/utils/fix_data_feat.sh
index 2c92d7f..84eea36 100755
--- a/egs/aishell2/transformer/utils/fix_data_feat.sh
+++ b/egs/aishell2/transformer/utils/fix_data_feat.sh
@@ -40,10 +40,10 @@
 mv ${data_dir}/speech_shape ${data_dir}/speech_shape.bak
 mv ${data_dir}/text_shape ${data_dir}/text_shape.bak
 
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak > ${data_dir}/feats.scp
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak > ${data_dir}/speech_shape
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak > ${data_dir}/text_shape
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak | sort -k1,1 -u > ${data_dir}/feats.scp
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak | sort -k1,1 -u > ${data_dir}/speech_shape
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak | sort -k1,1 -u > ${data_dir}/text_shape
 
 rm ${data_dir}/feats.scp.bak
 rm ${data_dir}/text.bak
diff --git a/egs/librispeech/conformer/conf/train_asr_conformer.yaml b/egs/librispeech/conformer/conf/train_asr_conformer.yaml
index 68b127f..2bd3db4 100644
--- a/egs/librispeech/conformer/conf/train_asr_conformer.yaml
+++ b/egs/librispeech/conformer/conf/train_asr_conformer.yaml
@@ -27,13 +27,25 @@
     self_attention_dropout_rate: 0.1
     src_attention_dropout_rate: 0.1
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
+# hybrid CTC/attention
 model_conf:
     ctc_weight: 0.3
     lsm_weight: 0.1
     length_normalized_loss: false
 
 accum_grad: 2
-max_epoch: 50
+max_epoch: 150
 patience: none
 init: none
 best_model_criterion:
diff --git a/egs/librispeech/conformer/local/data_prep_librispeech.sh b/egs/librispeech/conformer/local/data_prep.sh
similarity index 100%
rename from egs/librispeech/conformer/local/data_prep_librispeech.sh
rename to egs/librispeech/conformer/local/data_prep.sh
diff --git a/egs/librispeech/conformer/local/download_and_untar.sh b/egs/librispeech/conformer/local/download_and_untar.sh
new file mode 100755
index 0000000..fe322e4
--- /dev/null
+++ b/egs/librispeech/conformer/local/download_and_untar.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
+  echo "          train-clean-100, train-clean-360, train-other-500."
+  exit 1
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1
+fi
+
+part_ok=false
+list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1
+fi
+
+if [ -f $data/LibriSpeech/$part/.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0
+fi
+
+
+# Sizes of the archive files in bytes. These are from some older versions.
+sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
+# sizes_new contains the archive file sizes of the final release.  Some of these
+# sizes are of things we probably won't download.
+sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
+
+if [ -f $data/$part.tar.gz ]; then
+  size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tar.gz
+  else
+    echo "$data/$part.tar.gz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tar.gz ]; then
+  if ! which wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1
+  fi
+  full_url=$url/$part.tar.gz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  if ! wget -P $data --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1
+  fi
+fi
+
+if ! tar -C $data -xvzf $data/$part.tar.gz; then
+  echo "$0: error un-tarring archive $data/$part.tar.gz"
+  exit 1
+fi
+
+touch $data/LibriSpeech/$part/.complete
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
+  rm $data/$part.tar.gz
+fi
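The archive handling above trusts an existing tarball only if its byte size matches a known release size, deleting and re-downloading it otherwise. The same idea in Python (illustrative; KNOWN_SIZES is a stand-in subset):

    import os
    import urllib.request

    KNOWN_SIZES = {371012589, 347390293}  # stand-in subset of sizes_old/sizes_new

    def ensure_archive(path, url):
        if os.path.isfile(path):
            if os.path.getsize(path) in KNOWN_SIZES:
                return  # size matches a known archive: assume it is complete
            os.remove(path)  # unexpected size: assume a partial download
        urllib.request.urlretrieve(url, path)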
diff --git a/egs/librispeech/conformer/local/spm_encode.py b/egs/librispeech/conformer/local/spm_encode.py
new file mode 100755
index 0000000..9e1c15f
--- /dev/null
+++ b/egs/librispeech/conformer/local/spm_encode.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+
+
+import argparse
+import contextlib
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", required=True,
+                        help="sentencepiece model to use for encoding")
+    parser.add_argument("--inputs", nargs="+", default=['-'],
+                        help="input files to filter/encode")
+    parser.add_argument("--outputs", nargs="+", default=['-'],
+                        help="path to save encoded outputs")
+    parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
+    parser.add_argument("--min-len", type=int, metavar="N",
+                        help="filter sentence pairs with fewer than N tokens")
+    parser.add_argument("--max-len", type=int, metavar="N",
+                        help="filter sentence pairs with more than N tokens")
+    args = parser.parse_args()
+
+    assert len(args.inputs) == len(args.outputs), \
+        "number of input and output paths should match"
+
+    sp = spm.SentencePieceProcessor()
+    sp.Load(args.model)
+
+    if args.output_format == "piece":
+        def encode(l):
+            return sp.EncodeAsPieces(l)
+    elif args.output_format == "id":
+        def encode(l):
+            return list(map(str, sp.EncodeAsIds(l)))
+    else:
+        raise NotImplementedError
+
+    if args.min_len is not None or args.max_len is not None:
+        def valid(line):
+            return (
+                (args.min_len is None or len(line) >= args.min_len) and
+                (args.max_len is None or len(line) <= args.max_len)
+            )
+    else:
+        def valid(line):
+            return True
+
+    with contextlib.ExitStack() as stack:
+        inputs = [
+            stack.enter_context(open(input, "r", encoding="utf-8"))
+            if input != "-" else sys.stdin
+            for input in args.inputs
+        ]
+        outputs = [
+            stack.enter_context(open(output, "w", encoding="utf-8"))
+            if output != "-" else sys.stdout
+            for output in args.outputs
+        ]
+
+        stats = {
+            "num_empty": 0,
+            "num_filtered": 0,
+        }
+
+        def encode_line(line):
+            line = line.strip()
+            if len(line) > 0:
+                line = encode(line)
+                if valid(line):
+                    return line
+                else:
+                    stats["num_filtered"] += 1
+            else:
+                stats["num_empty"] += 1
+            return None
+
+        for i, lines in enumerate(zip(*inputs), start=1):
+            enc_lines = list(map(encode_line, lines))
+            if not any(enc_line is None for enc_line in enc_lines):
+                for enc_line, output_h in zip(enc_lines, outputs):
+                    print(" ".join(enc_line), file=output_h)
+            if i % 10000 == 0:
+                print("processed {} lines".format(i), file=sys.stderr)
+
+        print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
+        print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech/conformer/local/spm_train.py b/egs/librispeech/conformer/local/spm_train.py
new file mode 100755
index 0000000..134a0b1
--- /dev/null
+++ b/egs/librispeech/conformer/local/spm_train.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+import sys
+
+import sentencepiece as spm
+
+if __name__ == "__main__":
+    spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
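spm_train.py simply forwards its argument string to SentencePieceTrainer, so the stage-2 invocation in run.sh maps onto the sentencepiece API like this (file names below are hypothetical placeholders):

    import sentencepiece as spm

    spm.SentencePieceTrainer.Train(
        "--input=input.txt --model_prefix=train_960_unigram5000 "
        "--vocab_size=5000 --model_type=unigram --input_sentence_size=100000000"
    )

    sp = spm.SentencePieceProcessor()
    sp.Load("train_960_unigram5000.model")
    print(sp.EncodeAsPieces("speech recognition"))  # e.g. ['▁speech', '▁recognition']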
diff --git a/egs/librispeech/conformer/run.sh b/egs/librispeech/conformer/run.sh
index 93d1b46..cc4029c 100755
--- a/egs/librispeech/conformer/run.sh
+++ b/egs/librispeech/conformer/run.sh
@@ -16,30 +16,27 @@
 feats_dir="../DATA" #feature output dictionary
 exp_dir="."
 lang=en
-dumpdir=dump/fbank
-feats_type=fbank
 token_type=bpe
-dataset_type=large
-scp=feats.scp
-type=kaldi_ark
-stage=3
-stop_stage=4
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=0
+stop_stage=2
 
 # feature configuration
 feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
 
 # data
-data_librispeech=
+raw_data=
+data_url=www.openslr.org/resources/12
 
 # bpe model
 nbpe=5000
 bpemode=unigram
 
 # exp tag
-tag=""
+tag="exp1"
 
 . utils/parse_options.sh || exit 1;
 
@@ -54,8 +51,7 @@
 test_sets="test_clean test_other dev_clean dev_other"
 
 asr_config=conf/train_asr_conformer.yaml
-#asr_config=conf/train_asr_conformer_uttnorm.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
 
 inference_config=conf/decode_asr_transformer.yaml
 #inference_config=conf/decode_asr_transformer_beam60_ctc0.3.yaml
@@ -73,96 +69,52 @@
     _ngpu=0
 fi
 
-if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
-    echo "stage 0: Data preparation"
-    # Data preparation
-    for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
-        local/data_prep_librispeech.sh ${data_librispeech}/LibriSpeech/${x} ${feats_dir}/data/${x//-/_}
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    echo "stage -1: Data Download"
+    for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
+        local/download_and_untar.sh ${raw_data} ${data_url} ${part}
     done
 fi
 
-feat_train_dir=${feats_dir}/${dumpdir}/$train_set; mkdir -p ${feat_train_dir}
-feat_dev_clean_dir=${feats_dir}/${dumpdir}/dev_clean; mkdir -p ${feat_dev_clean_dir}
-feat_dev_other_dir=${feats_dir}/${dumpdir}/dev_other; mkdir -p ${feat_dev_other_dir}
-feat_test_clean_dir=${feats_dir}/${dumpdir}/test_clean; mkdir -p ${feat_test_clean_dir}
-feat_test_other_dir=${feats_dir}/${dumpdir}/test_other; mkdir -p ${feat_test_other_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/$valid_set; mkdir -p ${feat_dev_dir}
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
-    echo "stage 1: Feature Generation"
-    # compute fbank features
-    fbankdir=${feats_dir}/fbank
-    for x in dev_clean dev_other test_clean test_other; do
-        utils/compute_fbank.sh --cmd "$train_cmd" --nj 1 --max_lengths 3000 --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
-            ${feats_dir}/data/${x} ${exp_dir}/exp/make_fbank/${x} ${fbankdir}/${x}
-        utils/fix_data_feat.sh ${fbankdir}/${x}
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    echo "stage 0: Data preparation"
+    # Data preparation
+    for x in dev-clean dev-other test-clean test-other train-clean-100 train-clean-360 train-other-500; do
+        local/data_prep.sh ${raw_data}/LibriSpeech/${x} ${feats_dir}/data/${x//-/_}
     done
-
-    mkdir ${feats_dir}/data/$train_set
+    mkdir -p $feats_dir/data/$valid_set
+    dev_sets="dev_clean dev_other"
+    for file in wav.scp text; do
+        ( for f in $dev_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$valid_set/$file || exit 1;
+    done
+    mkdir -p $feats_dir/data/$train_set
     train_sets="train_clean_100 train_clean_360 train_other_500"
     for file in wav.scp text; do
         ( for f in $train_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$train_set/$file || exit 1;
     done
-    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --max_lengths 3000 --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
-    ${feats_dir}/data/$train_set ${exp_dir}/exp/make_fbank/$train_set ${fbankdir}/$train_set
-    utils/fix_data_feat.sh ${fbankdir}/$train_set
-
-    # compute global cmvn
-    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
-        ${fbankdir}/$train_set ${exp_dir}/exp/make_fbank/$train_set
-
-    # apply cmvn
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
-        ${fbankdir}/$train_set ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/$train_set ${feat_train_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1 \
-        ${fbankdir}/dev_clean ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/dev_clean ${feat_dev_clean_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1\
-        ${fbankdir}/dev_other ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/dev_other ${feat_dev_other_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1 \
-        ${fbankdir}/test_clean ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/test_clean ${feat_test_clean_dir}
-    utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1 \
-        ${fbankdir}/test_other ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/test_other ${feat_test_other_dir}
-
-    cp ${fbankdir}/$train_set/text ${fbankdir}/$train_set/speech_shape ${fbankdir}/$train_set/text_shape ${feat_train_dir}
-    cp ${fbankdir}/dev_clean/text ${fbankdir}/dev_clean/speech_shape ${fbankdir}/dev_clean/text_shape ${feat_dev_clean_dir}
-    cp ${fbankdir}/dev_other/text ${fbankdir}/dev_other/speech_shape ${fbankdir}/dev_other/text_shape ${feat_dev_other_dir}
-    cp ${fbankdir}/test_clean/text ${fbankdir}/test_clean/speech_shape ${fbankdir}/test_clean/text_shape ${feat_test_clean_dir}
-    cp ${fbankdir}/test_other/text ${fbankdir}/test_other/speech_shape ${fbankdir}/test_other/text_shape ${feat_test_other_dir}
-
-    dev_sets="dev_clean dev_other"
-    for file in feats.scp text speech_shape text_shape; do
-        ( for f in $dev_sets; do cat $feats_dir/${dumpdir}/$f/$file; done ) | sort -k1 > $feat_dev_dir/$file || exit 1;
-    done
-
-    #generate ark list
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/${train_set} ${feat_train_dir}
-    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/${valid_set} ${feat_dev_dir}
 fi
 
-dict=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
+fi
+
+token_list=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
 bpemodel=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}
-echo "dictionary: ${dict}"
+echo "dictionary: ${token_list}"
 if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
     ### Task dependent. You have to check non-linguistic symbols used in the corpus.
     echo "stage 2: Dictionary and Json Data Preparation"
     mkdir -p ${feats_dir}/data/lang_char/
-    echo "<blank>" > ${dict}
-    echo "<s>" >> ${dict}
-    echo "</s>" >> ${dict}
+    echo "<blank>" > ${token_list}
+    echo "<s>" >> ${token_list}
+    echo "</s>" >> ${token_list}
     cut -f 2- -d" " ${feats_dir}/data/${train_set}/text > ${feats_dir}/data/lang_char/input.txt
-    spm_train --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
-    spm_encode --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0}' >> ${dict}
-    echo "<unk>" >> ${dict}
-    wc -l ${dict}
-
-    vocab_size=$(cat ${dict} | wc -l)
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
-    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/$train_set
-    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/$valid_set
-    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/$train_set
-    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/$valid_set
+    local/spm_train.py --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
+    local/spm_encode.py --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0}' >> ${token_list}
+    echo "<unk>" >> ${token_list}
 fi
-
 
 # Training Stage
 world_size=$gpu_num  # run on one machine
@@ -181,20 +133,22 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train.py \
+            train.py \
+                --task_name asr \
                 --gpu_id $gpu_id \
                 --use_preprocessor true \
                 --split_with_space false \
                 --bpemodel ${bpemodel}.model \
                 --token_type $token_type \
-                --dataset_type $dataset_type \
-                --token_list $dict \
-                --train_data_file $feats_dir/$dumpdir/${train_set}/ark_txt.scp \
-                --valid_data_file $feats_dir/$dumpdir/${valid_set}/ark_txt.scp \
+                --token_list $token_list \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
                 --resume true \
                 --output_dir ${exp_dir}/exp/${model_dir} \
                 --config $asr_config \
-                --input_size $feats_dim \
                 --ngpu $gpu_num \
                 --num_worker_count $count \
                 --multiprocessing_distributed true \
@@ -220,7 +174,7 @@
             exit 0
         fi
         mkdir -p "${_logdir}"
-        _data="${feats_dir}/${dumpdir}/${dset}"
+        _data="${feats_dir}/data/${dset}"
         key_file=${_data}/${scp}
         num_scp_file="$(<${key_file} wc -l)"
         _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -241,6 +195,7 @@
                 --njob ${njob} \
                 --gpuid_list ${gpuid_list} \
                 --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
                 --key_file "${_logdir}"/keys.JOB.scp \
                 --asr_train_config "${asr_exp}"/config.yaml \
                 --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
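The training stage above launches one train.py process per visible GPU, giving worker i the i-th GPU id and rank i in a world of gpu_num processes. Schematically, in Python (only the distributed flags from run.sh are shown; the data and config flags are omitted):

    import os
    import subprocess

    gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0,1").split(",")
    world_size = len(gpus)
    procs = [
        subprocess.Popen([
            "python", "train.py", "--task_name", "asr",
            "--gpu_id", gpu_id,
            "--dist_rank", str(rank),
            "--dist_world_size", str(world_size),
        ])  # remaining data/config flags omitted
        for rank, gpu_id in enumerate(gpus)
    ]
    for p in procs:
        p.wait()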
diff --git a/egs/librispeech_100h/conformer/conf/decode_asr_transformer.yaml b/egs/librispeech_100h/conformer/conf/decode_asr_transformer.yaml
new file mode 100644
index 0000000..a147fa7
--- /dev/null
+++ b/egs/librispeech_100h/conformer/conf/decode_asr_transformer.yaml
@@ -0,0 +1,6 @@
+beam_size: 10
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.5
+lm_weight: 0.7
diff --git a/egs/librispeech/conformer/conf/train_asr_conformer_uttnorm.yaml b/egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml
similarity index 77%
rename from egs/librispeech/conformer/conf/train_asr_conformer_uttnorm.yaml
rename to egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml
index 16b7cc0..bd92bb0 100644
--- a/egs/librispeech/conformer/conf/train_asr_conformer_uttnorm.yaml
+++ b/egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml
@@ -1,8 +1,8 @@
 encoder: conformer
 encoder_conf:
-    output_size: 512
-    attention_heads: 8
-    linear_units: 2048
+    output_size: 256
+    attention_heads: 4
+    linear_units: 1024
     num_blocks: 12
     dropout_rate: 0.1
     positional_dropout_rate: 0.1
@@ -19,7 +19,7 @@
 
 decoder: transformer
 decoder_conf:
-    attention_heads: 8
+    attention_heads: 4
     linear_units: 2048
     num_blocks: 6
     dropout_rate: 0.1
@@ -27,13 +27,25 @@
     self_attention_dropout_rate: 0.1
     src_attention_dropout_rate: 0.1
 
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
+# hybrid CTC/attention
 model_conf:
     ctc_weight: 0.3
     lsm_weight: 0.1
     length_normalized_loss: false
 
-accum_grad: 2
-max_epoch: 50
+accum_grad: 1
+max_epoch: 210
 patience: none
 init: none
 best_model_criterion:
@@ -44,11 +56,11 @@
 
 optim: adam
 optim_conf:
-    lr: 0.0025
+    lr: 0.002
     weight_decay: 0.000001
 scheduler: warmuplr
 scheduler_conf:
-    warmup_steps: 40000
+    warmup_steps: 15000
 
 specaug: specaug
 specaug_conf:
@@ -64,7 +76,7 @@
     time_mask_width_ratio_range:
     - 0.
     - 0.05
-    num_time_mask: 10
+    num_time_mask: 5
 
 dataset_conf:
     shuffle: True
@@ -77,4 +89,4 @@
     num_workers: 8
 
 log_interval: 50
-normalize: utterance_mvn
\ No newline at end of file
+normalize: None
\ No newline at end of file
diff --git a/egs/librispeech/conformer/local/data_prep_librispeech.sh b/egs/librispeech_100h/conformer/local/data_prep.sh
similarity index 100%
copy from egs/librispeech/conformer/local/data_prep_librispeech.sh
copy to egs/librispeech_100h/conformer/local/data_prep.sh
diff --git a/egs/librispeech_100h/conformer/local/download_and_untar.sh b/egs/librispeech_100h/conformer/local/download_and_untar.sh
new file mode 100755
index 0000000..fe322e4
--- /dev/null
+++ b/egs/librispeech_100h/conformer/local/download_and_untar.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
+  echo "          train-clean-100, train-clean-360, train-other-500."
+  exit 1
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1
+fi
+
+part_ok=false
+list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1
+fi
+
+if [ -f $data/LibriSpeech/$part/.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0
+fi
+
+
+# Sizes of the archive files in bytes. These are from some older versions.
+sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
+# sizes_new contains the archive file sizes of the final release.  Some of these
+# sizes are of things we probably won't download.
+sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
+
+if [ -f $data/$part.tar.gz ]; then
+  size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tar.gz
+  else
+    echo "$data/$part.tar.gz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tar.gz ]; then
+  if ! which wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1
+  fi
+  full_url=$url/$part.tar.gz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  if ! wget -P $data --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1
+  fi
+fi
+
+if ! tar -C $data -xvzf $data/$part.tar.gz; then
+  echo "$0: error un-tarring archive $data/$part.tar.gz"
+  exit 1
+fi
+
+touch $data/LibriSpeech/$part/.complete
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
+  rm $data/$part.tar.gz
+fi
diff --git a/egs/librispeech_100h/conformer/local/spm_encode.py b/egs/librispeech_100h/conformer/local/spm_encode.py
new file mode 100755
index 0000000..9e1c15f
--- /dev/null
+++ b/egs/librispeech_100h/conformer/local/spm_encode.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+
+
+import argparse
+import contextlib
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model", required=True,
+                        help="sentencepiece model to use for encoding")
+    parser.add_argument("--inputs", nargs="+", default=['-'],
+                        help="input files to filter/encode")
+    parser.add_argument("--outputs", nargs="+", default=['-'],
+                        help="path to save encoded outputs")
+    parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
+    parser.add_argument("--min-len", type=int, metavar="N",
+                        help="filter sentence pairs with fewer than N tokens")
+    parser.add_argument("--max-len", type=int, metavar="N",
+                        help="filter sentence pairs with more than N tokens")
+    args = parser.parse_args()
+
+    assert len(args.inputs) == len(args.outputs), \
+        "number of input and output paths should match"
+
+    sp = spm.SentencePieceProcessor()
+    sp.Load(args.model)
+
+    if args.output_format == "piece":
+        def encode(l):
+            return sp.EncodeAsPieces(l)
+    elif args.output_format == "id":
+        def encode(l):
+            return list(map(str, sp.EncodeAsIds(l)))
+    else:
+        raise NotImplementedError
+
+    if args.min_len is not None or args.max_len is not None:
+        def valid(line):
+            return (
+                (args.min_len is None or len(line) >= args.min_len) and
+                (args.max_len is None or len(line) <= args.max_len)
+            )
+    else:
+        def valid(line):
+            return True
+
+    with contextlib.ExitStack() as stack:
+        inputs = [
+            stack.enter_context(open(input, "r", encoding="utf-8"))
+            if input != "-" else sys.stdin
+            for input in args.inputs
+        ]
+        outputs = [
+            stack.enter_context(open(output, "w", encoding="utf-8"))
+            if output != "-" else sys.stdout
+            for output in args.outputs
+        ]
+
+        stats = {
+            "num_empty": 0,
+            "num_filtered": 0,
+        }
+
+        def encode_line(line):
+            line = line.strip()
+            if len(line) > 0:
+                line = encode(line)
+                if valid(line):
+                    return line
+                else:
+                    stats["num_filtered"] += 1
+            else:
+                stats["num_empty"] += 1
+            return None
+
+        for i, lines in enumerate(zip(*inputs), start=1):
+            enc_lines = list(map(encode_line, lines))
+            if not any(enc_line is None for enc_line in enc_lines):
+                for enc_line, output_h in zip(enc_lines, outputs):
+                    print(" ".join(enc_line), file=output_h)
+            if i % 10000 == 0:
+                print("processed {} lines".format(i), file=sys.stderr)
+
+        print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
+        print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/librispeech_100h/conformer/local/spm_train.py b/egs/librispeech_100h/conformer/local/spm_train.py
new file mode 100755
index 0000000..134a0b1
--- /dev/null
+++ b/egs/librispeech_100h/conformer/local/spm_train.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+import sys
+
+import sentencepiece as spm
+
+if __name__ == "__main__":
+    spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
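+
+# Usage sketch (as in run.sh stage 2): all flags are passed through verbatim
+# to SentencePieceTrainer, e.g.
+#   local/spm_train.py --input=input.txt --vocab_size=5000 \
+#       --model_type=unigram --model_prefix=bpemodel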
diff --git a/egs/mars/sd/path.sh b/egs/librispeech_100h/conformer/path.sh
similarity index 100%
rename from egs/mars/sd/path.sh
rename to egs/librispeech_100h/conformer/path.sh
diff --git a/egs/librispeech_100h/conformer/run.sh b/egs/librispeech_100h/conformer/run.sh
new file mode 100755
index 0000000..e980806
--- /dev/null
+++ b/egs/librispeech_100h/conformer/run.sh
@@ -0,0 +1,212 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=2
+count=1
+gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="../DATA" #feature output dictionary
+exp_dir="."
+lang=en
+token_type=bpe
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
+stop_stage=4
+
+# feature configuration
+feats_dim=80
+nj=64
+
+# data
+raw_data=
+data_url=www.openslr.org/resources/12
+
+# bpe model
+nbpe=5000
+bpemode=unigram
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode: it will exit on
+# -e 'error', -u 'undefined variable', and -o pipefail 'error in pipeline'.
+set -e
+set -u
+set -o pipefail
+
+train_set=train_clean_100
+valid_set=dev
+test_sets="test_clean test_other dev_clean dev_other"
+
+asr_config=conf/train_asr_conformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_asr_transformer.yaml
+#inference_config=conf/decode_asr_transformer_beam60_ctc0.3.yaml
+inference_asr_model=valid.acc.ave_10best.pth
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+    inference_nj=$((ngpu * njob))
+    _ngpu=1
+else
+    inference_nj=$njob
+    _ngpu=0
+fi
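+# Example: with gpuid_list="0,1" and njob=5, GPU decoding launches
+# inference_nj=10 single-GPU jobs (_ngpu=1); CPU decoding launches 5 (_ngpu=0).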
+
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    echo "stage -1: Data Download"
+    for part in dev-clean test-clean dev-other test-other train-clean-100; do
+        local/download_and_untar.sh ${raw_data} ${data_url} ${part}
+    done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    echo "stage 0: Data preparation"
+    # Data preparation
+    for x in dev-clean dev-other test-clean test-other train-clean-100; do
+        local/data_prep.sh ${raw_data}/LibriSpeech/${x} ${feats_dir}/data/${x//-/_}
+    done
+    mkdir $feats_dir/data/$valid_set
+    dev_sets="dev_clean dev_other"
+    for file in wav.scp text; do
+        ( for f in $dev_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$valid_set/$file || exit 1;
+    done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
+fi
+
+token_list=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
+bpemodel=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    ### Task dependent. You have to check non-linguistic symbols used in the corpus.
+    echo "stage 2: Dictionary and Json Data Preparation"
+    mkdir -p ${feats_dir}/data/lang_char/
+    echo "<blank>" > ${token_list}
+    echo "<s>" >> ${token_list}
+    echo "</s>" >> ${token_list}
+    cut -f 2- -d" " ${feats_dir}/data/${train_set}/text > ${feats_dir}/data/lang_char/input.txt
+    local/spm_train.py --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
+    local/spm_encode.py --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort -u >> ${token_list}
+    echo "<unk>" >> ${token_list}
+fi
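+# token_list layout after stage 2: <blank>, <s>, </s>, the sorted BPE pieces
+# observed in the training text, then <unk>.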
+
+# Training Stage
+world_size=$gpu_num  # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "stage 3: Training"
+    mkdir -p ${exp_dir}/exp/${model_dir}
+    mkdir -p ${exp_dir}/exp/${model_dir}/log
+    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+    if [ -f $INIT_FILE ];then
+        rm -f $INIT_FILE
+    fi
+    init_method=file://$(readlink -f $INIT_FILE)
+    echo "$0: init method is $init_method"
+    for ((i = 0; i < $gpu_num; ++i)); do
+        {
+            rank=$i
+            local_rank=$i
+            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
+            train.py \
+                --task_name asr \
+                --gpu_id $gpu_id \
+                --use_preprocessor true \
+                --split_with_space false \
+                --bpemodel ${bpemodel}.model \
+                --token_type $token_type \
+                --token_list $token_list \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --speed_perturb ${speed_perturb} \
+                --resume true \
+                --output_dir ${exp_dir}/exp/${model_dir} \
+                --config $asr_config \
+                --ngpu $gpu_num \
+                --num_worker_count $count \
+                --multiprocessing_distributed true \
+                --dist_init_method $init_method \
+                --dist_world_size $world_size \
+                --dist_rank $rank \
+                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+        } &
+    done
+    wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    echo "stage 4: Inference"
+    for dset in ${test_sets}; do
+        asr_exp=${exp_dir}/exp/${model_dir}
+        inference_tag="$(basename "${inference_config}" .yaml)"
+        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+        _logdir="${_dir}/logdir"
+        if [ -d ${_dir} ]; then
+            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+            exit 0
+        fi
+        mkdir -p "${_logdir}"
+        _data="${feats_dir}/data/${dset}"
+        key_file=${_data}/${scp}
+        num_scp_file="$(<${key_file} wc -l)"
+        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+        split_scps=
+        for n in $(seq "${_nj}"); do
+            split_scps+=" ${_logdir}/keys.${n}.scp"
+        done
+        # shellcheck disable=SC2086
+        utils/split_scp.pl "${key_file}" ${split_scps}
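+        # keys.*.scp now partition ${key_file} across the ${_nj} decoding jobs;
+        # each job below reads only its own key file (keys.JOB.scp).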
+        _opts=
+        if [ -n "${inference_config}" ]; then
+            _opts+="--config ${inference_config} "
+        fi
+        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+            python -m funasr.bin.asr_inference_launch \
+                --batch_size 1 \
+                --ngpu "${_ngpu}" \
+                --njob ${njob} \
+                --gpuid_list ${gpuid_list} \
+                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+                --key_file "${_logdir}"/keys.JOB.scp \
+                --asr_train_config "${asr_exp}"/config.yaml \
+                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+                --output_dir "${_logdir}"/output.JOB \
+                --mode asr \
+                ${_opts}
+
+        for f in token token_int score text; do
+            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+                for i in $(seq "${_nj}"); do
+                    cat "${_logdir}/output.${i}/1best_recog/${f}"
+                done | sort -k1 >"${_dir}/${f}"
+            fi
+        done
+        python utils/compute_wer.py ${_data}/text ${_dir}/text ${_dir}/text.cer
+        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+        cat ${_dir}/text.cer.txt
+    done
+fi
\ No newline at end of file
diff --git a/egs/librispeech_100h/conformer/utils b/egs/librispeech_100h/conformer/utils
new file mode 120000
index 0000000..fe070dd
--- /dev/null
+++ b/egs/librispeech_100h/conformer/utils
@@ -0,0 +1 @@
+../../aishell/transformer/utils
\ No newline at end of file
diff --git a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml b/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
deleted file mode 100644
index 459a741..0000000
--- a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
+++ /dev/null
@@ -1,121 +0,0 @@
-model: sond
-model_conf:
-    lsm_weight: 0.0
-    length_normalized_loss: true
-    max_spk_num: 16
-
-# speech encoder
-encoder: ecapa_tdnn
-encoder_conf:
-    # pass by model, equal to feature dim
-    # input_size: 80
-    pool_size: 20
-    stride: 1
-speaker_encoder: conv
-speaker_encoder_conf:
-    input_units: 256
-    num_layers: 3
-    num_units: 256
-    kernel_size: 1
-    dropout_rate: 0.0
-    position_encoder: null
-    out_units: 256
-    out_norm: false
-    auxiliary_states: false
-    tf2torch_tensor_name_prefix_torch: speaker_encoder
-    tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder
-ci_scorer: dot
-ci_scorer_conf: {}
-cd_scorer: san
-cd_scorer_conf:
-    input_size: 512
-    output_size: 512
-    out_units: 1
-    attention_heads: 4
-    linear_units: 1024
-    num_blocks: 4
-    dropout_rate: 0.0
-    positional_dropout_rate: 0.0
-    attention_dropout_rate: 0.0
-    # use string "null" to remove input layer
-    input_layer: "null"
-    pos_enc_class: null
-    normalize_before: true
-    tf2torch_tensor_name_prefix_torch: cd_scorer
-    tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer
-# post net
-decoder: fsmn
-decoder_conf:
-    in_units: 32
-    out_units: 2517
-    filter_size: 31
-    fsmn_num_layers: 6
-    dnn_num_layers: 1
-    num_memory_units: 512
-    ffn_inner_dim: 512
-    dropout_rate: 0.0
-    tf2torch_tensor_name_prefix_torch: decoder
-    tf2torch_tensor_name_prefix_tf: EAND/post_net
-frontend: wav_frontend
-frontend_conf:
-    fs: 16000
-    window: povey
-    n_mels: 80
-    frame_length: 25
-    frame_shift: 10
-    filter_length_min: -1
-    filter_length_max: -1
-    lfr_m: 1
-    lfr_n: 1
-    dither: 0.0
-    snip_edges: false
-
-# minibatch related
-batch_type: length
-# 16s * 16k * 16 samples
-batch_bins: 4096000
-num_workers: 8
-
-# optimization related
-accum_grad: 1
-grad_clip: 5
-max_epoch: 50
-val_scheduler_criterion:
-    - valid
-    - acc
-best_model_criterion:
--   - valid
-    - der
-    - min
--   - valid
-    - forward_steps
-    - max
-keep_nbest_models: 10
-
-optim: adam
-optim_conf:
-   lr: 0.001
-scheduler: warmuplr
-scheduler_conf:
-   warmup_steps: 10000
-
-# without spec aug
-specaug: null
-specaug_conf:
-    apply_time_warp: true
-    time_warp_window: 5
-    time_warp_mode: bicubic
-    apply_freq_mask: true
-    freq_mask_width_range:
-    - 0
-    - 30
-    num_freq_mask: 2
-    apply_time_mask: true
-    time_mask_width_range:
-    - 0
-    - 40
-    num_time_mask: 2
-
-log_interval: 50
-# without normalize
-normalize: None
diff --git a/egs/mars/sd/local_run.sh b/egs/mars/sd/local_run.sh
deleted file mode 100755
index 4516e9f..0000000
--- a/egs/mars/sd/local_run.sh
+++ /dev/null
@@ -1,171 +0,0 @@
-#!/usr/bin/env bash
-
-. ./path.sh || exit 1;
-
-# machines configuration
-CUDA_VISIBLE_DEVICES="6,7"
-gpu_num=2
-count=1
-gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
-# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
-train_cmd=utils/run.pl
-infer_cmd=utils/run.pl
-
-# general configuration
-feats_dir="." #feature output dictionary
-exp_dir="."
-lang=zh
-dumpdir=dump/raw
-feats_type=raw
-token_type=char
-scp=wav.scp
-type=kaldi_ark
-stage=3
-stop_stage=4
-
-# feature configuration
-feats_dim=
-sample_frequency=16000
-nj=32
-speed_perturb=
-
-# exp tag
-tag="exp1"
-
-. utils/parse_options.sh || exit 1;
-
-# Set bash to 'debug' mode, it will exit on :
-# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
-set -e
-set -u
-set -o pipefail
-
-train_set=train
-valid_set=dev
-test_sets="dev test"
-
-asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
-
-inference_config=conf/decode_asr_transformer.yaml
-inference_asr_model=valid.acc.ave_10best.pb
-
-# you can set gpu num for decoding here
-gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
-ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
-
-if ${gpu_inference}; then
-    inference_nj=$[${ngpu}*${njob}]
-    _ngpu=1
-else
-    inference_nj=$njob
-    _ngpu=0
-fi
-
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
-
-# Training Stage
-world_size=$gpu_num  # run on one machine
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
-    echo "stage 3: Training"
-    mkdir -p ${exp_dir}/exp/${model_dir}
-    mkdir -p ${exp_dir}/exp/${model_dir}/log
-    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
-    if [ -f $INIT_FILE ];then
-        rm -f $INIT_FILE
-    fi 
-    init_method=file://$(readlink -f $INIT_FILE)
-    echo "$0: init method is $init_method"
-    for ((i = 0; i < $gpu_num; ++i)); do
-        {
-            rank=$i
-            local_rank=$i
-            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
-            asr_train.py \
-                --gpu_id $gpu_id \
-                --use_preprocessor true \
-                --token_type char \
-                --token_list $token_list \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
-                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
-                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
-                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
-                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char  \
-                --resume true \
-                --output_dir ${exp_dir}/exp/${model_dir} \
-                --config $asr_config \
-                --input_size $feats_dim \
-                --ngpu $gpu_num \
-                --num_worker_count $count \
-                --multiprocessing_distributed true \
-                --dist_init_method $init_method \
-                --dist_world_size $world_size \
-                --dist_rank $rank \
-                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
-        } &
-        done
-        wait
-fi
-
-# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
-    echo "stage 4: Inference"
-    for dset in ${test_sets}; do
-        asr_exp=${exp_dir}/exp/${model_dir}
-        inference_tag="$(basename "${inference_config}" .yaml)"
-        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
-        _logdir="${_dir}/logdir"
-        if [ -d ${_dir} ]; then
-            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
-            exit 0
-        fi
-        mkdir -p "${_logdir}"
-        _data="${feats_dir}/${dumpdir}/${dset}"
-        key_file=${_data}/${scp}
-        num_scp_file="$(<${key_file} wc -l)"
-        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
-        split_scps=
-        for n in $(seq "${_nj}"); do
-            split_scps+=" ${_logdir}/keys.${n}.scp"
-        done
-        # shellcheck disable=SC2086
-        utils/split_scp.pl "${key_file}" ${split_scps}
-        _opts=
-        if [ -n "${inference_config}" ]; then
-            _opts+="--config ${inference_config} "
-        fi
-        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1: "${_nj}" "${_logdir}"/asr_inference.JOB.log \
-            python -m funasr.bin.asr_inference_launch \
-                --batch_size 1 \
-                --ngpu "${_ngpu}" \
-                --njob ${njob} \
-                --gpuid_list ${gpuid_list} \
-                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
-                --key_file "${_logdir}"/keys.JOB.scp \
-                --asr_train_config "${asr_exp}"/config.yaml \
-                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
-                --output_dir "${_logdir}"/output.JOB \
-                --mode asr \
-                ${_opts}
-
-        for f in token token_int score text; do
-            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
-                for i in $(seq "${_nj}"); do
-                    cat "${_logdir}/output.${i}/1best_recog/${f}"
-                done | sort -k1 >"${_dir}/${f}"
-            fi
-        done
-        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
-        python utils/proce_text.py ${_data}/text ${_data}/text.proc
-        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
-        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
-        cat ${_dir}/text.cer.txt
-    done
-fi
-
diff --git a/egs/mars/sd/scripts/calculate_shapes.py b/egs/mars/sd/scripts/calculate_shapes.py
deleted file mode 100644
index b207f2d..0000000
--- a/egs/mars/sd/scripts/calculate_shapes.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-
-
-class MyRunner(MultiProcessRunnerV3):
-
-    def prepare(self, parser: argparse.ArgumentParser):
-        parser.add_argument("--input_scp", type=str, required=True)
-        parser.add_argument("--out_path")
-        args = parser.parse_args()
-
-        if not os.path.exists(os.path.dirname(args.out_path)):
-            os.makedirs(os.path.dirname(args.out_path))
-
-        task_list = load_scp_as_list(args.input_scp)
-        return task_list, None, args
-
-    def post(self, result_list, args):
-        fd = open(args.out_path, "wt", encoding="utf-8")
-        for results in result_list:
-            for uttid, shape in results:
-                fd.write("{} {}\n".format(uttid, ",".join(shape)))
-        fd.close()
-
-
-def process(task_args):
-    task_idx, task_list, _, args = task_args
-    rst = []
-    for uttid, file_path in task_list:
-        data = kaldiio.load_mat(file_path)
-        shape = [str(x) for x in data.shape]
-        rst.append((uttid, shape))
-    return rst
-
-
-if __name__ == '__main__':
-    my_runner = MyRunner(process)
-    my_runner.run()
diff --git a/egs/mars/sd/scripts/dump_rttm_to_labels.py b/egs/mars/sd/scripts/dump_rttm_to_labels.py
deleted file mode 100644
index ec1c765..0000000
--- a/egs/mars/sd/scripts/dump_rttm_to_labels.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-
-
-class MyRunner(MultiProcessRunnerV3):
-
-    def prepare(self, parser: argparse.ArgumentParser):
-        parser.add_argument("--rttm_list", type=str, required=True)
-        parser.add_argument("--wav_scp_list", type=str, required=True)
-        parser.add_argument("--out_dir", type=str, required=True)
-        parser.add_argument("--n_spk", type=int, default=8)
-        parser.add_argument("--remove_sil", default=False, action="store_true")
-        parser.add_argument("--max_overlap", default=0, type=int)
-        parser.add_argument("--frame_shift", type=float, default=0.01)
-        args = parser.parse_args()
-
-        rttm_list = [x.strip() for x in open(args.rttm_list, "rt", encoding="utf-8").readlines()]
-        meeting2rttm = OrderedDict()
-        for rttm_path in rttm_list:
-            meeting2rttm.update(self.load_rttm(rttm_path))
-
-        wav_scp_list = [x.strip() for x in open(args.wav_scp_list, "rt", encoding="utf-8").readlines()]
-        meeting_scp = OrderedDict()
-        for scp_path in wav_scp_list:
-            meeting_scp.update(load_scp_as_dict(scp_path))
-
-        if len(meeting_scp) != len(meeting2rttm):
-            logging.warning("Number of wav and rttm mismatch {} != {}".format(
-                len(meeting_scp), len(meeting2rttm)))
-            common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys())
-            logging.warning("Keep {} records.".format(len(common_keys)))
-            new_meeting_scp = OrderedDict()
-            rm_keys = []
-            for key in meeting_scp:
-                if key not in common_keys:
-                    rm_keys.append(key)
-                else:
-                    new_meeting_scp[key] = meeting_scp[key]
-            logging.warning("Keys are removed from wav scp: {}".format(" ".join(rm_keys)))
-
-            new_meeting2rttm = OrderedDict()
-            rm_keys = []
-            for key in meeting2rttm:
-                if key not in common_keys:
-                    rm_keys.append(key)
-                else:
-                    new_meeting2rttm[key] = meeting2rttm[key]
-            logging.warning("Keys are removed from rttm scp: {}".format(" ".join(rm_keys)))
-            meeting_scp, meeting2rttm = new_meeting_scp, new_meeting2rttm
-        if not os.path.exists(args.out_dir):
-            os.makedirs(args.out_dir)
-
-        task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()]
-        return task_list, None, args
-
-    @staticmethod
-    def load_rttm(rttm_path):
-        meeting2rttm = OrderedDict()
-        for one_line in open(rttm_path, "rt", encoding="utf-8"):
-            mid = one_line.strip().split(" ")[1]
-            if mid not in meeting2rttm:
-                meeting2rttm[mid] = []
-            meeting2rttm[mid].append(one_line.strip())
-
-        return meeting2rttm
-
-    def post(self, results_list, args):
-        pass
-
-
-def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, max_overlap=0,
-                sr=None, frame_shift=0.01):
-    frame_shift = int(frame_shift * sr)
-    num_frame = int((float(length) + (float(frame_shift) / 2)) / frame_shift)
-    multi_label = np.zeros([n_spk, num_frame], dtype=np.float32)
-    for _, st, dur, spk in spk_turns:
-        idx = spk_list.index(spk)
-
-        st, dur = int(st * sr), int(dur * sr)
-        frame_st = int((float(st) + (float(frame_shift) / 2)) / frame_shift)
-        frame_ed = int((float(st+dur) + (float(frame_shift) / 2)) / frame_shift)
-        multi_label[idx, frame_st:frame_ed] = 1
-
-    if remove_sil:
-        speech_count = np.sum(multi_label, axis=0)
-        idx = np.nonzero(speech_count)[0]
-        multi_label = multi_label[:, idx]
-
-    if max_overlap > 0:
-        speech_count = np.sum(multi_label, axis=0)
-        idx = np.nonzero(speech_count <= max_overlap)[0]
-        multi_label = multi_label[:, idx]
-
-    label = multi_label.T
-    return label  # (T, N)
-
-
-def build_labels(wav_path, rttms, n_spk, remove_sil=False, max_overlap=0,
-                 sr=16000, frame_shift=0.01):
-    wav, sr = soundfile.read(wav_path)
-    wav_len = len(wav)
-    spk_turns = []
-    spk_list = []
-    for one_line in rttms:
-        parts = one_line.strip().split(" ")
-        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), parts[7]
-        if spk not in spk_list:
-            spk_list.append(spk)
-        spk_turns.append((mid, st, dur, spk))
-    labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil, max_overlap, sr, frame_shift)
-    return labels, spk_list
-
-
-def process(task_args):
-    task_idx, task_list, _, args = task_args
-    spk_list_writer = open(os.path.join(args.out_dir, "spk_list.{}.txt".format(task_idx+1)),
-                           "wt", encoding="utf-8")
-    out_path = os.path.join(args.out_dir, "labels.{}".format(task_idx + 1))
-    label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-    for mid, wav_path, rttms in task_list:
-        meeting_labels, spk_list = build_labels(wav_path, rttms, args.n_spk, args.remove_sil, args.max_overlap,
-                                                args.sr, args.frame_shift)
-        label_writer(mid, meeting_labels)
-        spk_list_writer.write("{} {}\n".format(mid, " ".join(spk_list)))
-
-    spk_list_writer.close()
-    label_writer.close()
-    return None
-
-
-if __name__ == '__main__':
-    my_runner = MyRunner(process)
-    my_runner.run()
diff --git a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
deleted file mode 100644
index cd1ec7b..0000000
--- a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import numpy as np
-import os
-import argparse
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import soundfile as sf
-from tqdm import tqdm
-
-
-class MyRunner(MultiProcessRunnerV3):
-    def prepare(self, parser):
-        assert isinstance(parser, argparse.ArgumentParser)
-        parser.add_argument("wav_scp", type=str)
-        parser.add_argument("rttm", type=str)
-        parser.add_argument("out_dir", type=str)
-        parser.add_argument("--min_dur", type=float, default=2.0)
-        parser.add_argument("--max_spk_num", type=int, default=4)
-        args = parser.parse_args()
-
-        if not os.path.exists(args.out_dir):
-            os.makedirs(args.out_dir)
-
-        wav_scp = load_scp_as_list(args.wav_scp)
-        meeting2rttms = {}
-        for one_line in open(args.rttm, "rt"):
-            parts = [x for x in one_line.strip().split(" ") if x != ""]
-            mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
-            if mid not in meeting2rttms:
-                meeting2rttms[mid] = []
-            meeting2rttms[mid].append(one_line)
-
-        task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
-        return task_list, None, args
-
-    def post(self, result_list, args):
-        count = [0, 0]
-        for result in result_list:
-            count[0] += result[0]
-            count[1] += result[1]
-        print("Found {} speakers, extracted {}.".format(count[1], count[0]))
-
-
-# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
-def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
-    labels = np.zeros([max_spk_num, length], int)
-    spk_list = []
-    for one_line in rttms:
-        parts = [x for x in one_line.strip().split(" ") if x != ""]
-        mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
-        spk_name = spk_name.replace("spk", "").replace(mid, "").replace("-", "")
-        if spk_name.isdigit():
-            spk_name = "{}_S{:03d}".format(mid, int(spk_name))
-        else:
-            spk_name = "{}_{}".format(mid, spk_name)
-        if spk_name not in spk_list:
-            spk_list.append(spk_name)
-        st, dur = int(st*sr), int(dur*sr)
-        idx = spk_list.index(spk_name)
-        labels[idx, st:st+dur] = 1
-    return labels, spk_list
-
-
-def get_nonoverlap_turns(multi_label, spk_list):
-    turns = []
-    label = np.sum(multi_label, axis=0) == 1
-    spk, in_turn, st = None, False, 0
-    for i in range(len(label)):
-        if not in_turn and label[i]:
-            st, in_turn = i, True
-            spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
-        if in_turn:
-            if not label[i]:
-                in_turn = False
-                turns.append([st, i, spk])
-            elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
-                turns.append([st, i, spk])
-                st, in_turn = i, True
-                spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
-    if in_turn:
-        turns.append([st, len(label), spk])
-    return turns
-
-
-def process(task_args):
-    task_id, task_list, _, args = task_args
-    spk_count = [0, 0]
-    for mid, wav_path, rttms in task_list:
-        wav, sr = sf.read(wav_path, dtype="int16")
-        assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
-        multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
-        turns = get_nonoverlap_turns(multi_label, spk_list)
-        extracted_spk = []
-        count = 1
-        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
-            if (ed - st) >= args.min_dur * args.sr:
-                seg = wav[st: ed]
-                save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
-                if not os.path.exists(os.path.dirname(save_path)):
-                    os.makedirs(os.path.dirname(save_path))
-                sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
-                count += 1
-                if spk not in extracted_spk:
-                    extracted_spk.append(spk)
-        if len(extracted_spk) != len(spk_list):
-            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
-                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
-            ))
-        spk_count[0] += len(extracted_spk)
-        spk_count[1] += len(spk_list)
-    return spk_count
-
-
-if __name__ == '__main__':
-    my_runner = MyRunner(process)
-    my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py b/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py
deleted file mode 100644
index e579f51..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
-    def prepare(self, parser):
-        parser.add_argument("dir", type=str)
-        parser.add_argument("out_dir", type=str)
-        parser.add_argument("--n_spk", type=int, default=4)
-        parser.add_argument("--remove_sil", default=False, action="store_true")
-        args = parser.parse_args()
-
-        meeting_scp = load_scp_as_dict(os.path.join(args.dir, "meeting.scp"))
-        rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
-
-        if not os.path.exists(args.out_dir):
-            os.makedirs(args.out_dir)
-
-        task_list = [(mid, meeting_scp[mid], rttm_path) for mid, rttm_path in rttm_scp]
-        return task_list, None, args
-
-    def post(self, results_list, args):
-        pass
-
-
-def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, sr=16000):
-    multi_label = np.zeros([n_spk, length], dtype=int)
-    for _, st, dur, spk in spk_turns:
-        st, dur = int(st * sr), int(dur * sr)
-        idx = spk_list.index(spk)
-        multi_label[idx, st:st+dur] = 1
-    if not remove_sil:
-        return multi_label.T
-
-    speech_count = np.sum(multi_label, axis=0)
-    idx = np.nonzero(speech_count)[0]
-    label = multi_label[:, idx].T
-    return label  # (T, N)
-
-
-def build_labels(wav_path, rttm_path, n_spk, remove_sil=False, sr=16000):
-    wav_len = int(librosa.get_duration(filename=wav_path, sr=sr) * sr)
-    spk_turns = []
-    spk_list = []
-    for one_line in open(rttm_path, "rt"):
-        parts = one_line.strip().split(" ")
-        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
-        spk = "{}_S{:03d}".format(mid, spk)
-        if spk not in spk_list:
-            spk_list.append(spk)
-        spk_turns.append((mid, st, dur, spk))
-    labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil)
-    return labels
-
-
-def process(task_args):
-    _, task_list, _, args = task_args
-    for mid, wav_path, rttm_path in task_list:
-        meeting_labels = build_labels(wav_path, rttm_path, args.n_spk, args.remove_sil)
-        save_path = os.path.join(args.out_dir, "{}.lbl".format(mid))
-        np.save(save_path, meeting_labels.astype(bool))
-        print(mid)
-    return None
-
-
-if __name__ == '__main__':
-    my_runner = MyRunner(process)
-    my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py b/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py
deleted file mode 100644
index 11bc395..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import soundfile as sf
-from tqdm import tqdm
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
-    def prepare(self, parser):
-        parser.add_argument("wav_scp", type=str)
-        parser.add_argument("out_dir", type=str)
-        parser.add_argument("--chunk_dur", type=float, default=16)
-        parser.add_argument("--shift_dur", type=float, default=4)
-        args = parser.parse_args()
-
-        if not os.path.exists(args.out_dir):
-            os.makedirs(args.out_dir)
-
-        wav_scp = load_scp_as_list(args.wav_scp)
-        return wav_scp, None, args
-
-    def post(self, results_list, args):
-        pass
-
-
-def process(task_args):
-    _, task_list, _, args = task_args
-    chunk_len, shift_len = int(args.chunk_dur * args.sr), int(args.shift_dur * args.sr)
-    for mid, wav_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
-        if not os.path.exists(os.path.join(args.out_dir, mid)):
-            os.makedirs(os.path.join(args.out_dir, mid))
-
-        wav = librosa.load(wav_path, args.sr, True)[0] * 32767
-        n_chunk = (len(wav) - chunk_len) // shift_len + 1
-        if (len(wav) - chunk_len) % shift_len > 0:
-            n_chunk += 1
-        for i in range(n_chunk):
-            seg = wav[i*shift_len: i*shift_len + chunk_len]
-            st = int(float(i*shift_len)/args.sr * 100)
-            dur = int(float(len(seg))/args.sr * 100)
-            file_name = "{}_S{:04d}_{:07d}_{:07d}.wav".format(mid, i, st, st+dur)
-            save_path = os.path.join(args.out_dir, mid, file_name)
-            sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
-    return None
-
-
-if __name__ == '__main__':
-    my_runner = MyRunner(process)
-    my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py b/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py
deleted file mode 100644
index 011bd7c..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
-    def prepare(self, parser):
-        parser.add_argument("--rttm_scp", type=str)
-        parser.add_argument("--seg_file", type=str)
-        args = parser.parse_args()
-
-        if not os.path.exists(os.path.dirname(args.seg_file)):
-            os.makedirs(os.path.dirname(args.seg_file))
-
-        task_list = load_scp_as_list(args.rttm_scp)
-        return task_list, None, args
-
-    def post(self, results_list, args):
-        with open(args.seg_file, "wt", encoding="utf-8") as fd:
-            for results in results_list:
-                fd.writelines(results)
-
-
-def process(task_args):
-    _, task_list, _, args = task_args
-    outputs = []
-    for mid, rttm_path in task_list:
-        spk_turns = []
-        length = 0
-        for one_line in open(rttm_path, 'rt', encoding="utf-8"):
-            parts = one_line.strip().split(" ")
-            _, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
-            st, ed = int(st*100), int((st + dur)*100)
-            length = ed if ed > length else length
-            spk_turns.append([mid, st, ed, spk_name])
-        is_sph = np.zeros((length+1, ), dtype=bool)
-        for _, st, ed, _ in spk_turns:
-            is_sph[st:ed] = True
-
-        st, in_speech = 0, False
-        for i in range(length+1):
-            if not in_speech and is_sph[i]:
-                st, in_speech = i, True
-            if in_speech and not is_sph[i]:
-                in_speech = False
-                outputs.append("{}-{:07d}-{:07d} {} {:.2f} {:.2f}\n".format(
-                    mid, st, i, mid, float(st)/100, float(i)/100
-                ))
-    return outputs
-
-
-if __name__ == '__main__':
-    my_runner = MyRunner(process)
-    my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py b/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py
deleted file mode 100644
index a2bcd39..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import soundfile
-import kaldiio
-from tqdm import tqdm
-import json
-import os
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import numpy as np
-import argparse
-import random
-
-short_spk_list = []
-def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
-    all_utts = spk2utt[spk]
-    idx_list = list(range(len(all_utts)))
-    random.shuffle(idx_list)
-    count = 0
-    utt_list = []
-    for i in idx_list:
-        utt_id = all_utts[i]
-        utt_list.append(utt_id)
-        count += int(utt2frames[utt_id])
-        if count >= total_len:
-            break
-    if count < 300 and spk not in short_spk_list:
-        print("Speaker {} has only {} frames, but expect {} frames at least, use them all.".format(spk, count, 300))
-        short_spk_list.append(spk)
-
-    ivc_list = [kaldiio.load_mat(utt2ivc[utt]) for utt in utt_list]
-    ivc_list = [x/np.linalg.norm(x, axis=-1) for x in ivc_list]
-    ivc = np.concatenate(ivc_list, axis=0)
-    ivc = np.mean(ivc, axis=0, keepdims=False)
-    return ivc
-
-
-def process(meeting_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args):
-    out_prefix = args.out
-
-    ivc_dim = 192
-    win_len, win_shift = 400, 160
-    label_weights = 2 ** np.array(list(range(args.n_spk)))
-    wav_writer = kaldiio.WriteHelper("ark,scp:{}_wav.ark,{}_wav.scp".format(out_prefix, out_prefix))
-    ivc_writer = kaldiio.WriteHelper("ark,scp:{}_profile.ark,{}_profile.scp".format(out_prefix, out_prefix))
-    label_writer = kaldiio.WriteHelper("ark,scp:{}_label.ark,{}_label.scp".format(out_prefix, out_prefix))
-
-
-    frames_list = []
-    chunk_size = int(args.chunk_size * args.sr)
-    chunk_shift = int(args.chunk_shift * args.sr)
-    for mid, meeting_wav_path in tqdm(meeting_scp, total=len(meeting_scp), ascii=True, disable=args.no_pbar):
-        meeting_wav, sr = soundfile.read(meeting_wav_path, dtype='float32')
-        num_chunk = (len(meeting_wav) - chunk_size) // chunk_shift + 1
-        meeting_labels = np.load(labels_scp[mid])
-        for i in range(num_chunk):
-            st, ed = i*chunk_shift, i*chunk_shift+chunk_size
-            seg_id = "{}-{:03d}-{:06d}-{:06d}".format(mid, i, int(st/args.sr*100), int(ed/args.sr*100))
-            wav_writer(seg_id, meeting_wav[st: ed])
-
-            xvec_list = []
-            for spk in meeting2spk_list[mid]:
-                spk_xvec = calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 1000)
-                xvec_list.append(spk_xvec)
-            for _ in range(args.n_spk - len(xvec_list)):
-                xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
-            xvec = np.row_stack(xvec_list)
-            ivc_writer(seg_id, xvec)
-
-            wav_label = meeting_labels[st:ed, :]
-            frame_num = (ed-st) // win_shift
-            # wav_label = np.pad(wav_label, ((win_len/2, win_len/2), (0, 0)), "constant")
-            feat_label = np.zeros((frame_num, wav_label.shape[1]), dtype=np.float32)
-            for i in range(frame_num):
-                frame_label = wav_label[i*win_shift: (i+1)*win_shift, :]
-                feat_label[i, :] = (np.sum(frame_label, axis=0) > 0).astype(np.float32)
-            label_writer(seg_id, feat_label)
-
-            frames_list.append((mid, feat_label.shape[0]))
-    return frames_list
-
-
-def calc_spk_list(rttm_path):
-    spk_list = []
-    for one_line in open(rttm_path, "rt"):
-        parts = one_line.strip().split(" ")
-        mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
-        spk = "{}_S{:03d}".format(mid, spk)
-        if spk not in spk_list:
-            spk_list.append(spk)
-
-    return spk_list
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--dir", required=True, type=str, default=None,
-                        help="feats.scp")
-    parser.add_argument("--out", required=True, type=str, default=None,
-                        help="The prefix of dumpped files.")
-    parser.add_argument("--n_spk", type=int, default=4)
-    parser.add_argument("--use_lfr", default=False, action="store_true")
-    parser.add_argument("--no_pbar", default=False, action="store_true")
-    parser.add_argument("--sr", type=int, default=16000)
-    parser.add_argument("--chunk_size", type=int, default=16)
-    parser.add_argument("--chunk_shift", type=int, default=4)
-    args = parser.parse_args()
-
-    if not os.path.exists(os.path.dirname(args.out)):
-        os.makedirs(os.path.dirname(args.out))
-
-    meetings_scp = load_scp_as_list(os.path.join(args.dir, "meetings_rmsil.scp"))
-    labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
-    rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
-    utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
-    utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
-    utt2wav = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
-    utt2frames = {}
-    for uttid, wav_path in utt2wav.items():
-        wav, sr = soundfile.read(wav_path, dtype="int16")
-        utt2frames[uttid] = int(len(wav) / sr * 100)
-
-    meeting2spk_list = {}
-    for mid, rttm_path in rttm_scp:
-        meeting2spk_list[mid] = calc_spk_list(rttm_path)
-
-    spk2utt = {}
-    for utt, spk in utt2spk.items():
-        if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
-            if spk not in spk2utt:
-                spk2utt[spk] = []
-            spk2utt[spk].append(utt)
-
-    # random.shuffle(feat_scp)
-    meeting_lens = process(meetings_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args)
-    total_frames = sum([x[1] for x in meeting_lens])
-    print("Total chunks: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))
-
-
-if __name__ == '__main__':
-    main()
diff --git a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py b/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
deleted file mode 100644
index 1d6f53e..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
+++ /dev/null
@@ -1,110 +0,0 @@
-from __future__ import print_function
-import numpy as np
-import os
-import sys
-import argparse
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import librosa
-import soundfile as sf
-from copy import deepcopy
-import json
-from tqdm import tqdm
-
-
-class MyRunner(MultiProcessRunnerV3):
-    def prepare(self, parser):
-        assert isinstance(parser, argparse.ArgumentParser)
-        parser.add_argument("wav_scp", type=str)
-        parser.add_argument("rttm_scp", type=str)
-        parser.add_argument("out_dir", type=str)
-        parser.add_argument("--min_dur", type=float, default=2.0)
-        parser.add_argument("--max_spk_num", type=int, default=4)
-        args = parser.parse_args()
-
-        if not os.path.exists(args.out_dir):
-            os.makedirs(args.out_dir)
-
-        wav_scp = load_scp_as_list(args.wav_scp)
-        rttm_scp = load_scp_as_dict(args.rttm_scp)
-        task_list = [(mid, wav_path, rttm_scp[mid]) for (mid, wav_path) in wav_scp]
-        return task_list, None, args
-
-    def post(self, result_list, args):
-        count = [0, 0]
-        for result in result_list:
-            count[0] += result[0]
-            count[1] += result[1]
-        print("Found {} speakers, extracted {}.".format(count[1], count[0]))
-
-
-# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
-def calc_multi_label(rttm_path, length, sr=16000, max_spk_num=4):
-    labels = np.zeros([max_spk_num, length], int)
-    spk_list = []
-    for one_line in open(rttm_path, 'rt'):
-        parts = one_line.strip().split(" ")
-        mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
-        if spk_name.isdigit():
-            spk_name = "{}_S{:03d}".format(mid, int(spk_name))
-        if spk_name not in spk_list:
-            spk_list.append(spk_name)
-        st, dur = int(st*sr), int(dur*sr)
-        idx = spk_list.index(spk_name)
-        labels[idx, st:st+dur] = 1
-    return labels, spk_list
-
-
-def get_nonoverlap_turns(multi_label, spk_list):
-    turns = []
-    label = np.sum(multi_label, axis=0) == 1
-    spk, in_turn, st = None, False, 0
-    for i in range(len(label)):
-        if not in_turn and label[i]:
-            st, in_turn = i, True
-            spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
-        if in_turn:
-            if not label[i]:
-                in_turn = False
-                turns.append([st, i, spk])
-            elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
-                turns.append([st, i, spk])
-                st, in_turn = i, True
-                spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
-    if in_turn:
-        turns.append([st, len(label), spk])
-    return turns
-
-
-def process(task_args):
-    task_id, task_list, _, args = task_args
-    spk_count = [0, 0]
-    for mid, wav_path, rttm_path in task_list:
-        wav, sr = sf.read(wav_path, dtype="int16")
-        assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
-        multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num)
-        turns = get_nonoverlap_turns(multi_label, spk_list)
-        extracted_spk = []
-        count = 1
-        for st, ed, spk in tqdm(turns, total=len(turns), ascii=True):
-            if (ed - st) >= args.min_dur * args.sr:
-                seg = wav[st: ed]
-                save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
-                if not os.path.exists(os.path.dirname(save_path)):
-                    os.makedirs(os.path.dirname(save_path))
-                sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
-                count += 1
-                if spk not in extracted_spk:
-                    extracted_spk.append(spk)
-        if len(extracted_spk) != len(spk_list):
-            print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
-                mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
-            ))
-        spk_count[0] += len(extracted_spk)
-        spk_count[1] += len(spk_list)
-    return spk_count
-
-
-if __name__ == '__main__':
-    my_runner = MyRunner(process)
-    my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py b/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py
deleted file mode 100644
index 8b3195f..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import soundfile as sf
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
-    def prepare(self, parser):
-        parser.add_argument("dir", type=str)
-        parser.add_argument("out_dir", type=str)
-        args = parser.parse_args()
-
-        meeting_scp = load_scp_as_list(os.path.join(args.dir, "meeting.scp"))
-        vad_file = open(os.path.join(args.dir, "segments"), encoding="utf-8")
-        meeting2vad = {}
-        for one_line in vad_file:
-            uid, mid, st, ed = one_line.strip().split(" ")
-            st, ed = int(float(st) * args.sr), int(float(ed) * args.sr)
-            if mid not in meeting2vad:
-                meeting2vad[mid] = []
-            meeting2vad[mid].append((uid, st, ed))
-
-        if not os.path.exists(args.out_dir):
-            os.makedirs(args.out_dir)
-
-        task_list = [(mid, wav_path, meeting2vad[mid]) for mid, wav_path in meeting_scp]
-        return task_list, None, args
-
-    def post(self, results_list, args):
-        pass
-
-
-def process(task_args):
-    _, task_list, _, args = task_args
-    for mid, wav_path, vad_list in task_list:
-        wav = librosa.load(wav_path, args.sr, True)[0] * 32767
-        seg_list = []
-        pos_map = []
-        offset = 0
-        for uid, st, ed in vad_list:
-            seg_list.append(wav[st: ed])
-            pos_map.append("{} {} {} {} {}\n".format(uid, st, ed, offset, offset+ed-st))
-            offset = offset + ed - st
-        out = np.concatenate(seg_list, axis=0)
-        save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
-        sf.write(save_path, out.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
-        map_path = os.path.join(args.out_dir, "{}.pos".format(mid))
-        with open(map_path, "wt", encoding="utf-8") as fd:
-            fd.writelines(pos_map)
-        print(mid)
-    return None
-
-
-if __name__ == '__main__':
-    my_runner = MyRunner(process)
-    my_runner.run()
diff --git a/egs/mars/sd/scripts/simu_chunk_with_labels.py b/egs/mars/sd/scripts/simu_chunk_with_labels.py
deleted file mode 100644
index f61b808..0000000
--- a/egs/mars/sd/scripts/simu_chunk_with_labels.py
+++ /dev/null
@@ -1,261 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-import random
-from typing import List, Dict
-from copy import deepcopy
-import json
-logging.basicConfig(
-    level="INFO",
-    format=f"[{os.uname()[1].split('.')[0]}]"
-           f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-)
-
-
-class MyRunner(MultiProcessRunnerV3):
-
-    def prepare(self, parser: argparse.ArgumentParser):
-        parser.add_argument("--label_scp", type=str, required=True)
-        parser.add_argument("--wav_scp", type=str, required=True)
-        parser.add_argument("--utt2spk", type=str, required=True)
-        parser.add_argument("--spk2meeting", type=str, required=True)
-        parser.add_argument("--utt2xvec", type=str, required=True)
-        parser.add_argument("--out_dir", type=str, required=True)
-        parser.add_argument("--chunk_size", type=float, default=16)
-        parser.add_argument("--chunk_shift", type=float, default=4)
-        parser.add_argument("--frame_shift", type=float, default=0.01)
-        parser.add_argument("--embedding_dim", type=int, default=None)
-        parser.add_argument("--average_emb_num", type=int, default=0)
-        parser.add_argument("--subset", type=int, default=0)
-        parser.add_argument("--data_json", type=str, default=None)
-        parser.add_argument("--seed", type=int, default=1234)
-        parser.add_argument("--log_interval", type=int, default=100)
-        args = parser.parse_args()
-        random.seed(args.seed)
-        np.random.seed(args.seed)
-
-        logging.info("Loading data...")
-        if not os.path.exists(args.data_json):
-            label_list = load_scp_as_list(args.label_scp)
-            wav_scp = load_scp_as_dict(args.wav_scp)
-            utt2spk = load_scp_as_dict(args.utt2spk)
-            utt2xvec = load_scp_as_dict(args.utt2xvec)
-            spk2meeting = load_scp_as_dict(args.spk2meeting)
-
-            meeting2spks = OrderedDict()
-            for spk, meeting in spk2meeting.items():
-                if meeting not in meeting2spks:
-                    meeting2spks[meeting] = []
-                meeting2spks[meeting].append(spk)
-
-            spk2utts = OrderedDict()
-            for utt, spk in utt2spk.items():
-                if spk not in spk2utts:
-                    spk2utts[spk] = []
-                spk2utts[spk].append(utt)
-
-            os.makedirs(os.path.dirname(args.data_json), exist_ok=True)
-            logging.info("Dump data...")
-            json.dump({
-                "label_list": label_list, "wav_scp": wav_scp, "utt2xvec": utt2xvec,
-                "spk2utts": spk2utts, "meeting2spks": meeting2spks
-            }, open(args.data_json, "wt", encoding="utf-8"), ensure_ascii=False, indent=4)
-        else:
-            data_dict = json.load(open(args.data_json, "rt", encoding="utf-8"))
-            label_list = data_dict["label_list"]
-            wav_scp = data_dict["wav_scp"]
-            utt2xvec = data_dict["utt2xvec"]
-            spk2utts = data_dict["spk2utts"]
-            meeting2spks = data_dict["meeting2spks"]
-
-        if not os.path.exists(args.out_dir):
-            os.makedirs(args.out_dir)
-
-        args.chunk_size = int(args.chunk_size / args.frame_shift)
-        args.chunk_shift = int(args.chunk_shift / args.frame_shift)
-
-        if args.embedding_dim is None:
-            args.embedding_dim = kaldiio.load_mat(next(iter(utt2xvec.values()))).shape[1]
-            logging.info("Embedding dim is detected as {}.".format(args.embedding_dim))
-
-        logging.info("Number utt: {}, Number speaker: {}, Number meetings: {}".format(
-            len(wav_scp), len(spk2utts), len(meeting2spks)
-        ))
-        return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args
-
-    def post(self, results_list, args):
-        logging.info("[main]: Got {} chunks.".format(sum(results_list)))
-
-
-def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length):
-    utt_list = spk2utts[spk]
-    wav_list = []
-    cur_length = 0
-    while cur_length < sample_length:
-        uttid = random.choice(utt_list)
-        wav, fs = soundfile.read(wav_scp[uttid], dtype='float32')
-        wav_list.append(wav)
-        cur_length += len(wav)
-    concat_wav = np.concatenate(wav_list, axis=0)
-    start = random.randint(0, len(concat_wav) - sample_length)
-    return concat_wav[start: start+sample_length]
-
-
-def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num):
-    # process for dummy speaker
-    if spk == "None":
-        return np.zeros((1, embedding_dim), dtype=np.float32)
-
-    # calculate averaged speaker embeddings
-    utt_list = spk2utts[spk]
-    if average_emb_num == 0 or average_emb_num > len(utt_list):
-        xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in utt_list]
-    else:
-        xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in random.sample(utt_list, average_emb_num)]
-    xvec = np.concatenate(xvec_list, axis=0)
-    xvec = xvec / np.linalg.norm(xvec, axis=-1, keepdims=True)
-    xvec = np.mean(xvec, axis=0)
-
-    return xvec
-
-
-def simu_chunk(
-        frame_label: np.ndarray,
-        sample_label: np.ndarray,
-        wav_scp: Dict[str, str],
-        utt2xvec: Dict[str, str],
-        spk2utts: Dict[str, List[str]],
-        meeting2spks: Dict[str, List[str]],
-        all_speaker_list: List[str],
-        meeting_list: List[str],
-        embedding_dim: int,
-        average_emb_num: int,
-):
-    frame_length, max_spk_num = frame_label.shape
-    sample_length = sample_label.shape[0]
-    positive_speaker_num = int(np.sum(frame_label.sum(axis=0) > 0))
-    pos_speaker_list = deepcopy(meeting2spks[random.choice(meeting_list)])
-
-    # get positive speakers
-    if len(pos_speaker_list) >= positive_speaker_num:
-        pos_speaker_list = random.sample(pos_speaker_list, positive_speaker_num)
-    else:
-        while len(pos_speaker_list) < positive_speaker_num:
-            _spk = random.choice(all_speaker_list)
-            if _spk not in pos_speaker_list:
-                pos_speaker_list.append(_spk)
-
-    # get negative speakers
-    negative_speaker_num = random.randint(0, max_spk_num - positive_speaker_num)
-    neg_speaker_list = []
-    while len(neg_speaker_list) < negative_speaker_num:
-        _spk = random.choice(all_speaker_list)
-        if _spk not in pos_speaker_list and _spk not in neg_speaker_list:
-            neg_speaker_list.append(_spk)
-    neg_speaker_list.extend(["None"] * (max_spk_num - positive_speaker_num - negative_speaker_num))
-
-    random.shuffle(pos_speaker_list)
-    random.shuffle(neg_speaker_list)
-    seperated_wav = np.zeros(sample_label.shape, dtype=np.float32)
-    this_spk_list = []
-    for idx, frame_num in enumerate(frame_label.sum(axis=0)):
-        if frame_num > 0:
-            spk = pos_speaker_list.pop(0)
-            this_spk_list.append(spk)
-            simu_spk_wav = simu_wav_chunk(spk, spk2utts, wav_scp, sample_length)
-            seperated_wav[:, idx] = simu_spk_wav
-        else:
-            spk = neg_speaker_list.pop(0)
-            this_spk_list.append(spk)
-
-    # calculate mixed wav
-    mixed_wav = np.sum(seperated_wav * sample_label, axis=1)
-
-    # shuffle the order of speakers
-    shuffle_idx = list(range(max_spk_num))
-    random.shuffle(shuffle_idx)
-    this_spk_list = [this_spk_list[x] for x in shuffle_idx]
-    seperated_wav = seperated_wav.transpose()[shuffle_idx].transpose()
-    frame_label = frame_label.transpose()[shuffle_idx].transpose()
-
-    # calculate profile
-    profile = [calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num)
-               for spk in this_spk_list]
-    profile = np.vstack(profile)
-    # pse_weights = 2 ** np.arange(max_spk_num)
-    # pse_label = np.sum(frame_label * pse_weights[np.newaxis, :], axis=1)
-    # pse_label = pse_label.astype(str).tolist()
-
-    return mixed_wav, seperated_wav, profile, frame_label
-
-
-def process(task_args):
-    task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
-    logging.info("{:02d}/{:02d}: Start simulation...".format(task_idx+1, args.nj))
-
-    out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1))
-    wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
-    # out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
-    # wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
-    out_path = os.path.join(args.out_dir, "profile.{}".format(task_idx + 1))
-    profile_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
-    out_path = os.path.join(args.out_dir, "frame_label.{}".format(task_idx + 1))
-    label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
-    speaker_list, meeting_list = list(spk2utts.keys()), list(meeting2spks.keys())
-
-    labels_list = []
-    total_chunks = 0
-    for org_mid, label_path in task_list:
-        whole_label = kaldiio.load_mat(label_path)
-        # random offset to keep diversity
-        rand_shift = random.randint(0, args.chunk_shift)
-        num_chunk = (whole_label.shape[0] - rand_shift - args.chunk_size) // args.chunk_shift + 1
-        labels_list.append((org_mid, whole_label, rand_shift, num_chunk))
-        total_chunks += num_chunk
-
-    idx = 0
-    simu_chunk_count = 0
-    for org_mid, whole_label, rand_shift, num_chunk in labels_list:
-        for i in range(num_chunk):
-            idx = idx + 1
-            st = i * args.chunk_shift + rand_shift
-            ed = i * args.chunk_shift + args.chunk_size + rand_shift
-            utt_id = "subset{}_part{}_{}_{:06d}_{:06d}".format(
-                args.subset + 1, task_idx + 1, org_mid, st, ed
-            )
-            frame_label = whole_label[st: ed, :]
-            sample_label = frame_label.repeat(int(args.sr * args.frame_shift), axis=0)
-            mix_wav, seg_wav, profile, frame_label = simu_chunk(
-                frame_label, sample_label, wav_scp, utt2xvec, spk2utts, meeting2spks,
-                speaker_list, meeting_list, args.embedding_dim, args.average_emb_num
-            )
-            wav_mix_writer(utt_id, mix_wav)
-            # wav_sep_writer(utt_id, seg_wav)
-            profile_writer(utt_id, profile)
-            label_writer(utt_id, frame_label)
-
-            simu_chunk_count += 1
-            if simu_chunk_count % args.log_interval == 0:
-                logging.info("{:02d}/{:02d}: Complete {}/{} simulation, {}.".format(
-                    task_idx + 1, args.nj, simu_chunk_count, total_chunks, utt_id))
-    wav_mix_writer.close()
-    # wav_sep_writer.close()
-    profile_writer.close()
-    label_writer.close()
-    logging.info("[{}/{}]: Simulate {} chunks.".format(task_idx+1, args.nj, simu_chunk_count))
-    return simu_chunk_count
-
-
-if __name__ == '__main__':
-    my_runner = MyRunner(process)
-    my_runner.run()
diff --git a/funasr/bin/asr_test.py b/funasr/bin/asr_test.py
new file mode 100644
index 0000000..389bd22
--- /dev/null
+++ b/funasr/bin/asr_test.py
@@ -0,0 +1,226 @@
+import argparse
+import logging
+import os
+import sys
+
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="ASR Decoding",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument(
+        "--njob",
+        type=int,
+        default=1,
+        help="The number of jobs for each gpu",
+    )
+    parser.add_argument(
+        "--gpuid_list",
+        type=str,
+        default="",
+        help="The visible gpus",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        required=True,
+        action="append",
+    )
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument(
+        "--vad_infer_config",
+        type=str,
+        help="VAD infer configuration",
+    )
+    group.add_argument(
+        "--vad_model_file",
+        type=str,
+        help="VAD model parameter file",
+    )
+    group.add_argument(
+        "--cmvn_file",
+        type=str,
+        help="Global CMVN file",
+    )
+    group.add_argument(
+        "--asr_train_config",
+        type=str,
+        help="ASR training configuration",
+    )
+    group.add_argument(
+        "--asr_model_file",
+        type=str,
+        help="ASR model parameter file",
+    )
+    group.add_argument(
+        "--lm_train_config",
+        type=str,
+        help="LM training configuration",
+    )
+    group.add_argument(
+        "--lm_file",
+        type=str,
+        help="LM parameter file",
+    )
+    group.add_argument(
+        "--word_lm_train_config",
+        type=str,
+        help="Word LM training configuration",
+    )
+    group.add_argument(
+        "--word_lm_file",
+        type=str,
+        help="Word LM parameter file",
+    )
+    group.add_argument(
+        "--ngram_file",
+        type=str,
+        help="N-gram parameter file",
+    )
+    group.add_argument(
+        "--model_tag",
+        type=str,
+        help="Pretrained model tag. If specify this option, *_train_config and "
+             "*_file will be overwritten",
+    )
+
+    group = parser.add_argument_group("Beam-search related")
+    group.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
+    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+    group.add_argument(
+        "--maxlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain max output length. "
+             "If maxlenratio=0.0 (default), it uses a end-detect "
+             "function "
+             "to automatically find maximum hypothesis lengths."
+             "If maxlenratio<0.0, its absolute value is interpreted"
+             "as a constant max output length",
+    )
+    group.add_argument(
+        "--minlenratio",
+        type=float,
+        default=0.0,
+        help="Input length ratio to obtain min output length",
+    )
+    group.add_argument(
+        "--ctc_weight",
+        type=float,
+        default=0.0,
+        help="CTC weight in joint decoding",
+    )
+    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+    group.add_argument("--streaming", type=str2bool, default=False)
+
+    group = parser.add_argument_group("Text converter related")
+    group.add_argument(
+        "--token_type",
+        type=str_or_none,
+        default=None,
+        choices=["char", "bpe", None],
+        help="The token type for ASR model. "
+             "If not given, refers from the training args",
+    )
+    group.add_argument(
+        "--bpemodel",
+        type=str_or_none,
+        default=None,
+        help="The model path of sentencepiece. "
+             "If not given, refers from the training args",
+    )
+    group.add_argument("--token_num_relax", type=int, default=1, help="")
+    group.add_argument("--decoding_ind", type=int, default=0, help="")
+    group.add_argument("--decoding_mode", type=str, default="model1", help="")
+    group.add_argument(
+        "--ctc_weight2",
+        type=float,
+        default=0.0,
+        help="CTC weight in joint decoding",
+    )
+    return parser
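+
+# Usage sketch (hypothetical paths, for illustration only):
+#   python -m funasr.bin.asr_test \
+#       --asr_train_config exp/asr/config.yaml \
+#       --asr_model_file exp/asr/model.pb \
+#       --data_path_and_name_and_type data/test/wav.scp,speech,sound \
+#       --output_dir exp/asr/decode_test.1 \
+#       --nbest 1 --beam_size 20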
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    parser.add_argument(
+        "--mode",
+        type=str,
+        default="asr",
+        help="The decoding mode",
+    )
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    kwargs.pop("config", None)
+
+    # set logging messages
+    logging.basicConfig(
+        level=args.log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    logging.info("Decoding args: {}".format(kwargs))
+
+    # gpu setting
+    if args.ngpu > 0:
+        jobid = int(args.output_dir.split(".")[-1])
+        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
+        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+
+    # inference_launch_funasr(**kwargs)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
deleted file mode 100644
index df3434f..0000000
--- a/funasr/bin/build_trainer.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import os
-
-import yaml
-
-
-def update_dct(fin_configs, root):
-    if root == {}:
-        return {}
-    for root_key, root_value in root.items():
-        if not isinstance(root[root_key], dict):
-            fin_configs[root_key] = root[root_key]
-        else:
-            if root_key in fin_configs.keys():
-                result = update_dct(fin_configs[root_key], root[root_key])
-                fin_configs[root_key] = result
-            else:
-                fin_configs[root_key] = root[root_key]
-    return fin_configs
-
-
-def parse_args(mode):
-    if mode == "asr":
-        from funasr.tasks.asr import ASRTask as ASRTask
-    elif mode == "paraformer":
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-    elif mode == "paraformer_vad_punc":
-        from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-    elif mode == "uniasr":
-        from funasr.tasks.asr import ASRTaskUniASR as ASRTask
-    elif mode == "mfcca":
-        from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
-    elif mode == "tp":
-        from funasr.tasks.asr import ASRTaskAligner as ASRTask
-    else:
-        raise ValueError("Unknown mode: {}".format(mode))
-    parser = ASRTask.get_parser()
-    args = parser.parse_args()
-    return args, ASRTask
-
-
-def build_trainer(modelscope_dict,
-                  data_dir,
-                  output_dir,
-                  train_set="train",
-                  dev_set="validation",
-                  distributed=False,
-                  dataset_type="small",
-                  batch_bins=None,
-                  max_epoch=None,
-                  optim=None,
-                  lr=None,
-                  scheduler=None,
-                  scheduler_conf=None,
-                  specaug=None,
-                  specaug_conf=None,
-                  param_dict=None,
-                  **kwargs):
-    mode = modelscope_dict['mode']
-    args, ASRTask = parse_args(mode=mode)
-    # ddp related
-    if args.local_rank is not None:
-        distributed = True
-    else:
-        distributed = False
-    args.local_rank = args.local_rank if args.local_rank is not None else 0
-    local_rank = args.local_rank
-    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
-        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
-    else:
-        os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
-
-    config = modelscope_dict['am_model_config']
-    finetune_config = modelscope_dict['finetune_config']
-    init_param = modelscope_dict['init_model']
-    cmvn_file = modelscope_dict['cmvn_file']
-    seg_dict_file = modelscope_dict['seg_dict']
-
-    # overwrite parameters
-    with open(config) as f:
-        configs = yaml.safe_load(f)
-    with open(finetune_config) as f:
-        finetune_configs = yaml.safe_load(f)
-        # set data_types
-        if dataset_type == "large":
-            if 'data_types' not in finetune_configs['dataset_conf']:
-                finetune_configs["dataset_conf"]["data_types"] = "sound,text"
-    finetune_configs = update_dct(configs, finetune_configs)
-    for key, value in finetune_configs.items():
-        if hasattr(args, key):
-            setattr(args, key, value)
-
-    # prepare data
-    args.dataset_type = dataset_type
-    if args.dataset_type == "small":
-        args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
-                                                  ["{}/{}/text".format(data_dir, train_set), "text", "text"]]
-        args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
-                                                  ["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
-    elif args.dataset_type == "large":
-        args.train_data_file = None
-        args.valid_data_file = None
-    else:
-        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
-    args.init_param = [init_param]
-    args.cmvn_file = cmvn_file
-    if os.path.exists(seg_dict_file):
-        args.seg_dict_file = seg_dict_file
-    else:
-        args.seg_dict_file = None
-    args.data_dir = data_dir
-    args.train_set = train_set
-    args.dev_set = dev_set
-    args.output_dir = output_dir
-    args.gpu_id = args.local_rank
-    args.config = finetune_config
-    if optim is not None:
-        args.optim = optim
-    if lr is not None:
-        args.optim_conf["lr"] = lr
-    if scheduler is not None:
-        args.scheduler = scheduler
-    if scheduler_conf is not None:
-        args.scheduler_conf = scheduler_conf
-    if specaug is not None:
-        args.specaug = specaug
-    if specaug_conf is not None:
-        args.specaug_conf = specaug_conf
-    if max_epoch is not None:
-        args.max_epoch = max_epoch
-    if batch_bins is not None:
-        if args.dataset_type == "small":
-            args.batch_bins = batch_bins
-        elif args.dataset_type == "large":
-            args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
-        else:
-            raise ValueError(f"Not supported dataset_type={args.dataset_type}")
-    if args.normalize in ["null", "none", "None"]:
-        args.normalize = None
-    if args.patience in ["null", "none", "None"]:
-        args.patience = None
-    args.local_rank = local_rank
-    args.distributed = distributed
-    ASRTask.finetune_args = args
-
-    return ASRTask
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
new file mode 100755
index 0000000..53e5bde
--- /dev/null
+++ b/funasr/bin/train.py
@@ -0,0 +1,572 @@
+#!/usr/bin/env python3
+
+import argparse
+import logging
+import os
+import sys
+from io import BytesIO
+
+import torch
+
+from funasr.build_utils.build_args import build_args
+from funasr.build_utils.build_dataloader import build_dataloader
+from funasr.build_utils.build_distributed import build_distributed
+from funasr.build_utils.build_model import build_model
+from funasr.build_utils.build_optimizer import build_optimizer
+from funasr.build_utils.build_scheduler import build_scheduler
+from funasr.build_utils.build_trainer import build_trainer
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.load_pretrained_model import load_pretrained_model
+from funasr.torch_utils.model_summary import model_summary
+from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.prepare_data import prepare_data
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="FunASR Common Training Parser",
+    )
+
+    # common configuration
+    parser.add_argument("--output_dir", help="model save path")
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
+
+    # ddp related
+    parser.add_argument(
+        "--dist_backend",
+        default="nccl",
+        type=str,
+        help="distributed backend",
+    )
+    parser.add_argument(
+        "--dist_init_method",
+        type=str,
+        default="env://",
+        help='if init_method="env://", the env values of "MASTER_PORT", "MASTER_ADDR", '
+             '"WORLD_SIZE", and "RANK" are referenced.',
+    )
+    parser.add_argument(
+        "--dist_world_size",
+        type=int,
+        default=1,
+        help="number of nodes for distributed training",
+    )
+    parser.add_argument(
+        "--dist_rank",
+        type=int,
+        default=None,
+        help="node rank for distributed training",
+    )
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=None,
+        help="local rank for distributed training",
+    )
+    parser.add_argument(
+        "--dist_master_addr",
+        default=None,
+        type=str_or_none,
+        help="The master address for distributed training. "
+             "This value is used when dist_init_method == 'env://'",
+    )
+    parser.add_argument(
+        "--dist_master_port",
+        default=None,
+        type=int_or_none,
+        help="The master port for distributed training"
+             "This value is used when dist_init_method == 'env://'",
+    )
+    parser.add_argument(
+        "--dist_launcher",
+        default=None,
+        type=str_or_none,
+        choices=["slurm", "mpi", None],
+        help="The launcher type for distributed training",
+    )
+    parser.add_argument(
+        "--multiprocessing_distributed",
+        default=True,
+        type=str2bool,
+        help="Use multi-processing distributed training to launch "
+             "N processes per node, which has N GPUs. This is the "
+             "fastest way to use PyTorch for either single node or "
+             "multi node data parallel training",
+    )
+    parser.add_argument(
+        "--unused_parameters",
+        type=str2bool,
+        default=False,
+        help="Whether to use the find_unused_parameters in "
+             "torch.nn.parallel.DistributedDataParallel ",
+    )
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+
+    # cudnn related
+    parser.add_argument(
+        "--cudnn_enabled",
+        type=str2bool,
+        default=torch.backends.cudnn.enabled,
+        help="Enable CUDNN",
+    )
+    parser.add_argument(
+        "--cudnn_benchmark",
+        type=str2bool,
+        default=torch.backends.cudnn.benchmark,
+        help="Enable cudnn-benchmark mode",
+    )
+    parser.add_argument(
+        "--cudnn_deterministic",
+        type=str2bool,
+        default=True,
+        help="Enable cudnn-deterministic mode",
+    )
+
+    # trainer related
+    parser.add_argument(
+        "--max_epoch",
+        type=int,
+        default=40,
+        help="The maximum number epoch to train",
+    )
+    parser.add_argument(
+        "--max_update",
+        type=int,
+        default=sys.maxsize,
+        help="The maximum number update step to train",
+    )
+    parser.add_argument(
+        "--batch_interval",
+        type=int,
+        default=10000,
+        help="The batch interval for saving model.",
+    )
+    parser.add_argument(
+        "--patience",
+        type=int_or_none,
+        default=None,
+        help="Number of epochs to wait without improvement "
+             "before stopping the training",
+    )
+    parser.add_argument(
+        "--val_scheduler_criterion",
+        type=str,
+        nargs=2,
+        default=("valid", "loss"),
+        help="The criterion used for the value given to the lr scheduler. "
+             'Give a pair referring the phase, "train" or "valid",'
+             'and the criterion name. The mode specifying "min" or "max" can '
+             "be changed by --scheduler_conf",
+    )
+    parser.add_argument(
+        "--early_stopping_criterion",
+        type=str,
+        nargs=3,
+        default=("valid", "loss", "min"),
+        help="The criterion used for judging of early stopping. "
+             'Give a pair referring the phase, "train" or "valid",'
+             'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
+    )
+    parser.add_argument(
+        "--best_model_criterion",
+        nargs="+",
+        default=[
+            ("train", "loss", "min"),
+            ("valid", "loss", "min"),
+            ("train", "acc", "max"),
+            ("valid", "acc", "max"),
+        ],
+        help="The criterion used for judging of the best model. "
+             'Give a pair referring the phase, "train" or "valid",'
+             'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
+    )
+    parser.add_argument(
+        "--keep_nbest_models",
+        type=int,
+        nargs="+",
+        default=[10],
+        help="Remove previous snapshots excluding the n-best scored epochs",
+    )
+    parser.add_argument(
+        "--nbest_averaging_interval",
+        type=int,
+        default=0,
+        help="The epoch interval to apply model averaging and save nbest models",
+    )
+    parser.add_argument(
+        "--grad_clip",
+        type=float,
+        default=5.0,
+        help="Gradient norm threshold to clip",
+    )
+    parser.add_argument(
+        "--grad_clip_type",
+        type=float,
+        default=2.0,
+        help="The type of the used p-norm for gradient clip. Can be inf",
+    )
+    parser.add_argument(
+        "--grad_noise",
+        type=str2bool,
+        default=False,
+        help="The flag to switch to use noise injection to "
+             "gradients during training",
+    )
+    parser.add_argument(
+        "--accum_grad",
+        type=int,
+        default=1,
+        help="The number of gradient accumulation",
+    )
+    parser.add_argument(
+        "--resume",
+        type=str2bool,
+        default=False,
+        help="Enable resuming if checkpoint is existing",
+    )
+    parser.add_argument(
+        "--train_dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type for training.",
+    )
+    parser.add_argument(
+        "--use_amp",
+        type=str2bool,
+        default=False,
+        help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
+    )
+    parser.add_argument(
+        "--log_interval",
+        default=None,
+        help="Show the logs every the number iterations in each epochs at the "
+             "training phase. If None is given, it is decided according the number "
+             "of training samples automatically .",
+    )
+    parser.add_argument(
+        "--use_tensorboard",
+        type=str2bool,
+        default=True,
+        help="Enable tensorboard logging",
+    )
+
+    # pretrained model related
+    parser.add_argument(
+        "--init_param",
+        type=str,
+        default=[],
+        nargs="*",
+        help="Specify the file path used for initialization of parameters. "
+             "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
+             "where file_path is the model file path, "
+             "src_key specifies the key of model states to be used in the model file, "
+             "dst_key specifies the attribute of the model to be initialized, "
+             "and exclude_keys excludes keys of model states for the initialization."
+             "e.g.\n"
+             "  # Load all parameters"
+             "  --init_param some/where/model.pb\n"
+             "  # Load only decoder parameters"
+             "  --init_param some/where/model.pb:decoder:decoder\n"
+             "  # Load only decoder parameters excluding decoder.embed"
+             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
+             "  --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
+    )
+    parser.add_argument(
+        "--ignore_init_mismatch",
+        type=str2bool,
+        default=False,
+        help="Ignore size mismatch when loading pre-trained model",
+    )
+    parser.add_argument(
+        "--freeze_param",
+        type=str,
+        default=[],
+        nargs="*",
+        help="Freeze parameters",
+    )
+
+    # dataset related
+    parser.add_argument(
+        "--dataset_type",
+        type=str,
+        default="small",
+        help="whether to use dataloader for large dataset",
+    )
+    parser.add_argument(
+        "--dataset_conf",
+        action=NestedDictAction,
+        default=dict(),
+        help=f"The keyword arguments for dataset",
+    )
+    parser.add_argument(
+        "--data_dir",
+        type=str,
+        default=None,
+        help="root path of data",
+    )
+    parser.add_argument(
+        "--train_set",
+        type=str,
+        default="train",
+        help="train dataset",
+    )
+    parser.add_argument(
+        "--valid_set",
+        type=str,
+        default="validation",
+        help="dev dataset",
+    )
+    parser.add_argument(
+        "--speed_perturb",
+        type=float,
+        nargs="+",
+        default=None,
+        help="speed perturb",
+    )
+    parser.add_argument(
+        "--use_preprocessor",
+        type=str2bool,
+        default=True,
+        help="Apply preprocessing to data or not",
+    )
+    parser.add_argument(
+        "--embed_path",
+        type=str,
+        default=None,
+        help="for model which requires embeds",
+    )
+
+    # optimization related
+    parser.add_argument(
+        "--optim",
+        type=lambda x: x.lower(),
+        default="adam",
+        help="The optimizer type",
+    )
+    parser.add_argument(
+        "--optim_conf",
+        action=NestedDictAction,
+        default=dict(),
+        help="The keyword arguments for optimizer",
+    )
+    parser.add_argument(
+        "--scheduler",
+        type=lambda x: str_or_none(x.lower()),
+        default=None,
+        help="The lr scheduler type",
+    )
+    parser.add_argument(
+        "--scheduler_conf",
+        action=NestedDictAction,
+        default=dict(),
+        help="The keyword arguments for lr scheduler",
+    )
+
+    # most task related
+    parser.add_argument(
+        "--init",
+        type=lambda x: str_or_none(x.lower()),
+        default=None,
+        help="The initialization method",
+        choices=[
+            "chainer",
+            "xavier_uniform",
+            "xavier_normal",
+            "kaiming_uniform",
+            "kaiming_normal",
+            None,
+        ],
+    )
+    parser.add_argument(
+        "--token_list",
+        type=str_or_none,
+        default=None,
+        help="A text mapping int-id to token",
+    )
+    parser.add_argument(
+        "--token_type",
+        type=str,
+        default="bpe",
+        choices=["bpe", "char", "word"],
+        help="",
+    )
+    parser.add_argument(
+        "--bpemodel",
+        type=str_or_none,
+        default=None,
+        help="The model file fo sentencepiece",
+    )
+    parser.add_argument(
+        "--cleaner",
+        type=str_or_none,
+        choices=[None, "tacotron", "jaconv", "vietnamese"],
+        default=None,
+        help="Apply text cleaning",
+    )
+    parser.add_argument(
+        "--g2p",
+        type=str_or_none,
+        choices=g2p_choices,
+        default=None,
+        help="Specify g2p method if --token_type=phn",
+    )
+
+    # pai related
+    parser.add_argument(
+        "--use_pai",
+        type=str2bool,
+        default=False,
+        help="flag to indicate whether training on PAI",
+    )
+    parser.add_argument(
+        "--simple_ddp",
+        type=str2bool,
+        default=False,
+    )
+    parser.add_argument(
+        "--num_worker_count",
+        type=int,
+        default=1,
+        help="The number of machines on PAI.",
+    )
+    parser.add_argument(
+        "--access_key_id",
+        type=str,
+        default=None,
+        help="The username for oss.",
+    )
+    parser.add_argument(
+        "--access_key_secret",
+        type=str,
+        default=None,
+        help="The password for oss.",
+    )
+    parser.add_argument(
+        "--endpoint",
+        type=str,
+        default=None,
+        help="The endpoint for oss.",
+    )
+    parser.add_argument(
+        "--bucket_name",
+        type=str,
+        default=None,
+        help="The bucket name for oss.",
+    )
+    parser.add_argument(
+        "--oss_bucket",
+        default=None,
+        help="oss bucket.",
+    )
+
+    return parser
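+
+# Usage sketch (hypothetical paths, for illustration only):
+#   python -m funasr.bin.train \
+#       --task_name asr --dataset_type small \
+#       --data_dir data --train_set train --valid_set validation \
+#       --token_list data/lang_char/tokens.txt --token_type char \
+#       --output_dir exp/asr_train --ngpu 1 --gpu_id 0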
+
+
+if __name__ == '__main__':
+    parser = get_parser()
+    args, extra_task_params = parser.parse_known_args()
+    if extra_task_params:
+        args = build_args(args, parser, extra_task_params)
+
+    # set random seed
+    set_all_random_seed(args.seed)
+    torch.backends.cudnn.enabled = args.cudnn_enabled
+    torch.backends.cudnn.benchmark = args.cudnn_benchmark
+    torch.backends.cudnn.deterministic = args.cudnn_deterministic
+
+    # ddp init
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+    args.distributed = args.ngpu > 1 or args.dist_world_size > 1
+    distributed_option = build_distributed(args)
+
+    # for logging
+    if not distributed_option.distributed or distributed_option.dist_rank == 0:
+        logging.basicConfig(
+            level="INFO",
+            format=f"[{os.uname()[1].split('.')[0]}]"
+                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    else:
+        logging.basicConfig(
+            level="ERROR",
+            format=f"[{os.uname()[1].split('.')[0]}]"
+                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+
+    # prepare files for dataloader
+    prepare_data(args, distributed_option)
+
+    model = build_model(args)
+    model = model.to(
+        dtype=getattr(torch, args.train_dtype),
+        device="cuda" if args.ngpu > 0 else "cpu",
+    )
+    optimizers = build_optimizer(args, model=model)
+    schedulers = build_scheduler(args, optimizers)
+
+    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+                                                                   distributed_option.dist_rank,
+                                                                   distributed_option.local_rank))
+    logging.info(pytorch_cudnn_version())
+    logging.info("Args: {}".format(args))
+    logging.info(model_summary(model))
+    logging.info("Optimizer: {}".format(optimizers))
+    logging.info("Scheduler: {}".format(schedulers))
+
+    # dump args to config.yaml
+    if not distributed_option.distributed or distributed_option.dist_rank == 0:
+        os.makedirs(args.output_dir, exist_ok=True)
+        with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
+            logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
+            if args.use_pai:
+                buffer = BytesIO()
+                torch.save({"config": vars(args)}, buffer)
+                args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
+            else:
+                yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
+
+    for p in args.init_param:
+        logging.info(f"Loading pretrained params from {p}")
+        load_pretrained_model(
+            model=model,
+            init_param=p,
+            ignore_init_mismatch=args.ignore_init_mismatch,
+            map_location=f"cuda:{torch.cuda.current_device()}"
+            if args.ngpu > 0
+            else "cpu",
+            oss_bucket=args.oss_bucket,
+        )
+
+    # dataloader for training/validation
+    train_dataloader, valid_dataloader = build_dataloader(args)
+
+    # Trainer, including model, optimizers, etc.
+    trainer = build_trainer(
+        args=args,
+        model=model,
+        optimizers=optimizers,
+        schedulers=schedulers,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        distributed_option=distributed_option
+    )
+
+    trainer.run()
diff --git a/funasr/build_utils/__init__.py b/funasr/build_utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/build_utils/__init__.py
diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py
new file mode 100644
index 0000000..517c85b
--- /dev/null
+++ b/funasr/build_utils/build_args.py
@@ -0,0 +1,93 @@
+from funasr.models.ctc import CTC
+from funasr.utils import config_argparse
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+
+def build_args(args, parser, extra_task_params):
+    task_parser = config_argparse.ArgumentParser("Task related config")
+    if args.task_name == "asr":
+        from funasr.build_utils.build_asr_model import class_choices_list
+        for class_choices in class_choices_list:
+            class_choices.add_arguments(task_parser)
+        task_parser.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using <space>",
+        )
+        task_parser.add_argument(
+            "--seg_dict_file",
+            type=str,
+            default=None,
+            help="seg_dict_file for text processing",
+        )
+        task_parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+        task_parser.add_argument(
+            "--ctc_conf",
+            action=NestedDictAction,
+            default=get_default_kwargs(CTC),
+            help="The keyword arguments for CTC class.",
+        )
+        task_parser.add_argument(
+            "--cmvn_file",
+            type=str_or_none,
+            default=None,
+            help="The file path of noise scp file.",
+        )
+
+    elif args.task_name == "pretrain":
+        from funasr.build_utils.build_pretrain_model import class_choices_list
+        for class_choices in class_choices_list:
+            class_choices.add_arguments(task_parser)
+        task_parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+
+    elif args.task_name == "lm":
+        from funasr.build_utils.build_lm_model import class_choices_list
+        for class_choices in class_choices_list:
+            class_choices.add_arguments(task_parser)
+
+    elif args.task_name == "punc":
+        from funasr.build_utils.build_punc_model import class_choices_list
+        for class_choices in class_choices_list:
+            class_choices.add_arguments(task_parser)
+
+    elif args.task_name == "vad":
+        from funasr.build_utils.build_vad_model import class_choices_list
+        for class_choices in class_choices_list:
+            class_choices.add_arguments(task_parser)
+        task_parser.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+
+    elif args.task_name == "diar":
+        from funasr.build_utils.build_diar_model import class_choices_list
+        for class_choices in class_choices_list:
+            class_choices.add_arguments(task_parser)
+
+    else:
+        raise NotImplementedError("Not supported task: {}".format(args.task_name))
+
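+    # Merge the common options from the parent parser into the task parser so a
+    # single command line can carry both common and task-specific flags; the
+    # values already parsed into `args` become the defaults for this parse.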
+    for action in parser._actions:
+        if not any(action.dest == a.dest for a in task_parser._actions):
+            task_parser._add_action(action)
+
+    task_parser.set_defaults(**vars(args))
+    task_args = task_parser.parse_args(extra_task_params)
+    return task_args
diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
new file mode 100644
index 0000000..d8cbba5
--- /dev/null
+++ b/funasr/build_utils/build_asr_model.py
@@ -0,0 +1,352 @@
+import logging
+
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.rnn_decoder import RNNDecoder
+from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
+from funasr.models.decoder.transformer_decoder import (
+    DynamicConvolution2DTransformerDecoder,  # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
+from funasr.models.decoder.transformer_decoder import (
+    LightweightConvolution2DTransformerDecoder,  # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import (
+    LightweightConvolutionTransformerDecoder,  # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
+from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.e2e_asr import ASRModel
+from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_tp import TimestampPredictor
+from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.default import MultiChannelFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.subsampling import Conv1dSubsampling
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(
+        default=DefaultFrontend,
+        sliding_window=SlidingWindow,
+        s3prl=S3prlFrontend,
+        fused=FusedFrontends,
+        wav_frontend=WavFrontend,
+        multichannelfrontend=MultiChannelFrontend,
+    ),
+    default="default",
+)
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes=dict(
+        specaug=SpecAug,
+        specaug_lfr=SpecAugLFR,
+    ),
+    default=None,
+    optional=True,
+)
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    default=None,
+    optional=True,
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        asr=ASRModel,
+        uniasr=UniASR,
+        paraformer=Paraformer,
+        paraformer_bert=ParaformerBert,
+        bicif_paraformer=BiCifParaformer,
+        contextual_paraformer=ContextualParaformer,
+        mfcca=MFCCA,
+        timestamp_prediction=TimestampPredictor,
+    ),
+    default="asr",
+)
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        mfcca_enc=MFCCAEncoder,
+    ),
+    default="rnn",
+)
+encoder_choices2 = ClassChoices(
+    "encoder2",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+    ),
+    default="rnn",
+)
+decoder_choices = ClassChoices(
+    "decoder",
+    classes=dict(
+        transformer=TransformerDecoder,
+        lightweight_conv=LightweightConvolutionTransformerDecoder,
+        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+        dynamic_conv=DynamicConvolutionTransformerDecoder,
+        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+        rnn=RNNDecoder,
+        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+        paraformer_decoder_sanm=ParaformerSANMDecoder,
+        paraformer_decoder_san=ParaformerDecoderSAN,
+        contextual_paraformer_decoder=ContextualParaformerDecoder,
+    ),
+    default="rnn",
+)
+decoder_choices2 = ClassChoices(
+    "decoder2",
+    classes=dict(
+        transformer=TransformerDecoder,
+        lightweight_conv=LightweightConvolutionTransformerDecoder,
+        lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+        dynamic_conv=DynamicConvolutionTransformerDecoder,
+        dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+        rnn=RNNDecoder,
+        fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+        paraformer_decoder_sanm=ParaformerSANMDecoder,
+    ),
+    type_check=AbsDecoder,
+    default="rnn",
+)
+predictor_choices = ClassChoices(
+    name="predictor",
+    classes=dict(
+        cif_predictor=CifPredictor,
+        ctc_predictor=None,
+        cif_predictor_v2=CifPredictorV2,
+        cif_predictor_v3=CifPredictorV3,
+    ),
+    default="cif_predictor",
+    optional=True,
+)
+predictor_choices2 = ClassChoices(
+    name="predictor2",
+    classes=dict(
+        cif_predictor=CifPredictor,
+        ctc_predictor=None,
+        cif_predictor_v2=CifPredictorV2,
+    ),
+    default="cif_predictor",
+    optional=True,
+)
+stride_conv_choices = ClassChoices(
+    name="stride_conv",
+    classes=dict(
+        stride_conv1d=Conv1dSubsampling
+    ),
+    default="stride_conv1d",
+    optional=True,
+)
+class_choices_list = [
+    # --frontend and --frontend_conf
+    frontend_choices,
+    # --specaug and --specaug_conf
+    specaug_choices,
+    # --normalize and --normalize_conf
+    normalize_choices,
+    # --model and --model_conf
+    model_choices,
+    # --encoder and --encoder_conf
+    encoder_choices,
+    # --decoder and --decoder_conf
+    decoder_choices,
+    # --predictor and --predictor_conf
+    predictor_choices,
+    # --encoder2 and --encoder2_conf
+    encoder_choices2,
+    # --decoder2 and --decoder2_conf
+    decoder_choices2,
+    # --predictor2 and --predictor2_conf
+    predictor_choices2,
+    # --stride_conv and --stride_conv_conf
+    stride_conv_choices,
+]
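+
+# Each ClassChoices above contributes a "--<name>" and "--<name>_conf" pair of
+# CLI options (e.g. "--encoder conformer" with "--encoder_conf ..."), which
+# build_asr_model() below resolves into concrete module instances via get_class().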
+
+
+def build_asr_model(args):
+    # token_list
+    if args.token_list is not None:
+        with open(args.token_list) as f:
+            token_list = [line.rstrip() for line in f]
+        args.token_list = list(token_list)
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+    else:
+        vocab_size = None
+
+    # frontend
+    if args.input_size is None:
+        frontend_class = frontend_choices.get_class(args.frontend)
+        if args.frontend == 'wav_frontend':
+            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+        else:
+            frontend = frontend_class(**args.frontend_conf)
+        input_size = frontend.output_size()
+    else:
+        args.frontend = None
+        args.frontend_conf = {}
+        frontend = None
+        input_size = args.input_size
+
+    # data augmentation for spectrogram
+    if args.specaug is not None:
+        specaug_class = specaug_choices.get_class(args.specaug)
+        specaug = specaug_class(**args.specaug_conf)
+    else:
+        specaug = None
+
+    # normalization layer
+    if args.normalize is not None:
+        normalize_class = normalize_choices.get_class(args.normalize)
+        normalize = normalize_class(**args.normalize_conf)
+    else:
+        normalize = None
+
+    # encoder
+    encoder_class = encoder_choices.get_class(args.encoder)
+    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+    # decoder
+    decoder_class = decoder_choices.get_class(args.decoder)
+    decoder = decoder_class(
+        vocab_size=vocab_size,
+        encoder_output_size=encoder.output_size(),
+        **args.decoder_conf,
+    )
+
+    # ctc
+    ctc = CTC(
+        odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
+    )
+
+    if args.model in ["asr", "mfcca"]:
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            token_list=token_list,
+            **args.model_conf,
+        )
+    elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
+        # predictor
+        predictor_class = predictor_choices.get_class(args.predictor)
+        predictor = predictor_class(**args.predictor_conf)
+
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            token_list=token_list,
+            predictor=predictor,
+            **args.model_conf,
+        )
+    elif args.model == "uniasr":
+        # stride_conv
+        stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
+        stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
+                                        odim=input_size + encoder.output_size())
+        stride_conv_output_size = stride_conv.output_size()
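+        # UniASR's second pass consumes the raw features concatenated with the
+        # first-pass encoder output, hence idim/odim = input_size + encoder.output_size().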
+
+        # encoder2
+        encoder_class2 = encoder_choices2.get_class(args.encoder2)
+        encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
+
+        # decoder2
+        decoder_class2 = decoder_choices2.get_class(args.decoder2)
+        decoder2 = decoder_class2(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder2.output_size(),
+            **args.decoder2_conf,
+        )
+
+        # ctc2
+        ctc2 = CTC(
+            odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
+        )
+
+        # predictor
+        predictor_class = predictor_choices.get_class(args.predictor)
+        predictor = predictor_class(**args.predictor_conf)
+
+        # predictor2
+        predictor_class = predictor_choices2.get_class(args.predictor2)
+        predictor2 = predictor_class(**args.predictor2_conf)
+
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            encoder=encoder,
+            decoder=decoder,
+            ctc=ctc,
+            token_list=token_list,
+            predictor=predictor,
+            ctc2=ctc2,
+            encoder2=encoder2,
+            decoder2=decoder2,
+            predictor2=predictor2,
+            stride_conv=stride_conv,
+            **args.model_conf,
+        )
+    elif args.model == "timestamp_prediction":
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            frontend=frontend,
+            encoder=encoder,
+            token_list=token_list,
+            **args.model_conf,
+        )
+    else:
+        raise NotImplementedError("Not supported model: {}".format(args.model))
+
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model
\ No newline at end of file
diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
new file mode 100644
index 0000000..c95c40d
--- /dev/null
+++ b/funasr/build_utils/build_dataloader.py
@@ -0,0 +1,15 @@
+from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
+from funasr.datasets.small_datasets.sequence_iter_factory import SequenceIterFactory
+
+
+def build_dataloader(args):
+    if args.dataset_type == "small":
+        train_iter_factory = SequenceIterFactory(args, mode="train")
+        valid_iter_factory = SequenceIterFactory(args, mode="valid")
+    elif args.dataset_type == "large":
+        train_iter_factory = LargeDataLoader(args, mode="train")
+        valid_iter_factory = LargeDataLoader(args, mode="valid")
+    else:
+        raise ValueError(f"Not supported dataset_type={args.dataset_type}")
+
+    return train_iter_factory, valid_iter_factory
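+
+# Minimal usage sketch (assuming `args` was built by funasr/bin/train.py):
+#   train_loader, valid_loader = build_dataloader(args)
+# The returned iter factories are then handed to the trainer built in
+# funasr/build_utils/build_trainer.py.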
diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py
new file mode 100644
index 0000000..6406404
--- /dev/null
+++ b/funasr/build_utils/build_diar_model.py
@@ -0,0 +1,296 @@
+import logging
+
+import torch
+
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.label_aggregation import LabelAggregate
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
+from funasr.models.e2e_diar_sond import DiarSondModel
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
+from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
+from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
+from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
+from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
+from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(
+        default=DefaultFrontend,
+        sliding_window=SlidingWindow,
+        s3prl=S3prlFrontend,
+        fused=FusedFrontends,
+        wav_frontend=WavFrontend,
+        wav_frontend_mel23=WavFrontendMel23,
+    ),
+    default="default",
+)
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes=dict(
+        specaug=SpecAug,
+        specaug_lfr=SpecAugLFR,
+    ),
+    default=None,
+    optional=True,
+)
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    default=None,
+    optional=True,
+)
+label_aggregator_choices = ClassChoices(
+    "label_aggregator",
+    classes=dict(
+        label_aggregator=LabelAggregate
+    ),
+    default=None,
+    optional=True,
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        sond=DiarSondModel,
+        eend_ola=DiarEENDOLAModel,
+    ),
+    default="sond",
+)
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        san=SelfAttentionEncoder,
+        fsmn=FsmnEncoder,
+        conv=ConvEncoder,
+        resnet34=ResNet34Diar,
+        resnet34_sp_l2reg=ResNet34SpL2RegDiar,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+        ecapa_tdnn=ECAPA_TDNN,
+        eend_ola_transformer=EENDOLATransformerEncoder,
+    ),
+    default="resnet34",
+)
+speaker_encoder_choices = ClassChoices(
+    "speaker_encoder",
+    classes=dict(
+        conformer=ConformerEncoder,
+        transformer=TransformerEncoder,
+        rnn=RNNEncoder,
+        sanm=SANMEncoder,
+        san=SelfAttentionEncoder,
+        fsmn=FsmnEncoder,
+        conv=ConvEncoder,
+        sanm_chunk_opt=SANMEncoderChunkOpt,
+        data2vec_encoder=Data2VecEncoder,
+    ),
+    default=None,
+    optional=True
+)
+cd_scorer_choices = ClassChoices(
+    "cd_scorer",
+    classes=dict(
+        san=SelfAttentionEncoder,
+    ),
+    default=None,
+    optional=True,
+)
+ci_scorer_choices = ClassChoices(
+    "ci_scorer",
+    classes=dict(
+        dot=DotScorer,
+        cosine=CosScorer,
+        conv=ConvEncoder,
+    ),
+    type_check=torch.nn.Module,
+    default=None,
+    optional=True,
+)
+# decoder is used for output (e.g. post_net in SOND)
+decoder_choices = ClassChoices(
+    "decoder",
+    classes=dict(
+        rnn=RNNEncoder,
+        fsmn=FsmnEncoder,
+    ),
+    type_check=torch.nn.Module,
+    default="fsmn",
+)
+# encoder_decoder_attractor is used for EEND-OLA
+encoder_decoder_attractor_choices = ClassChoices(
+    "encoder_decoder_attractor",
+    classes=dict(
+        eda=EncoderDecoderAttractor,
+    ),
+    type_check=torch.nn.Module,
+    default="eda",
+)
+class_choices_list = [
+    # --frontend and --frontend_conf
+    frontend_choices,
+    # --specaug and --specaug_conf
+    specaug_choices,
+    # --normalize and --normalize_conf
+    normalize_choices,
+    # --label_aggregator and --label_aggregator_conf
+    label_aggregator_choices,
+    # --model and --model_conf
+    model_choices,
+    # --encoder and --encoder_conf
+    encoder_choices,
+    # --speaker_encoder and --speaker_encoder_conf
+    speaker_encoder_choices,
+    # --cd_scorer and --cd_scorer_conf
+    cd_scorer_choices,
+    # --ci_scorer and ci_scorer_conf
+    ci_scorer_choices,
+    # --decoder and --decoder_conf
+    decoder_choices,
+    # --encoder_decoder_attractor and --encoder_decoder_attractor_conf
+    encoder_decoder_attractor_choices,
+]
+
+
+def build_diar_model(args):
+    # token_list
+    if args.token_list is not None:
+        with open(args.token_list) as f:
+            token_list = [line.rstrip() for line in f]
+        args.token_list = list(token_list)
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+    else:
+        token_list = None
+        vocab_size = None
+
+    # frontend
+    if args.input_size is None:
+        frontend_class = frontend_choices.get_class(args.frontend)
+        if args.frontend == 'wav_frontend':
+            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+        else:
+            frontend = frontend_class(**args.frontend_conf)
+        input_size = frontend.output_size()
+    else:
+        args.frontend = None
+        args.frontend_conf = {}
+        frontend = None
+        input_size = args.input_size
+
+    # encoder
+    encoder_class = encoder_choices.get_class(args.encoder)
+    encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+    if args.model_name == "sond":
+        # data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # speaker encoder
+        if getattr(args, "speaker_encoder", None) is not None:
+            speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
+            speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
+        else:
+            speaker_encoder = None
+
+        # ci scorer
+        if getattr(args, "ci_scorer", None) is not None:
+            ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
+            ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
+        else:
+            ci_scorer = None
+
+        # cd scorer
+        if getattr(args, "cd_scorer", None) is not None:
+            cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
+            cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
+        else:
+            cd_scorer = None
+
+        # decoder
+        decoder_class = decoder_choices.get_class(args.decoder)
+        decoder = decoder_class(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder.output_size(),
+            **args.decoder_conf,
+        )
+
+        # label aggregator
+        if getattr(args, "label_aggregator", None) is not None:
+            label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
+            label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
+        else:
+            label_aggregator = None
+
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            vocab_size=vocab_size,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            label_aggregator=label_aggregator,
+            encoder=encoder,
+            speaker_encoder=speaker_encoder,
+            ci_scorer=ci_scorer,
+            cd_scorer=cd_scorer,
+            decoder=decoder,
+            token_list=token_list,
+            **args.model_conf,
+        )
+
+    elif args.model_name == "eend_ola":
+        # encoder-decoder attractor
+        encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
+        encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
+
+        # build model
+        model_class = model_choices.get_class(args.model)
+        model = model_class(
+            frontend=frontend,
+            encoder=encoder,
+            encoder_decoder_attractor=encoder_decoder_attractor,
+            **args.model_conf,
+        )
+
+    else:
+        raise NotImplementedError("Not supported model_name: {}".format(args.model_name))
+
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model
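+
+
+# Hedged usage sketch: the fields below are the ones this builder reads; the
+# actual values come from the training YAML/CLI and are omitted here.
+#
+#     args.model_name = "sond"      # selects the SOND branch above
+#     args.model = "sond"           # registry key resolving to DiarSondModel
+#     args.encoder = "resnet34"     # plus args.encoder_conf, frontend, scorers, ...
+#     model = build_diar_model(args)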
diff --git a/funasr/build_utils/build_distributed.py b/funasr/build_utils/build_distributed.py
new file mode 100644
index 0000000..b64b4c0
--- /dev/null
+++ b/funasr/build_utils/build_distributed.py
@@ -0,0 +1,38 @@
+import logging
+import os
+
+import torch
+
+from funasr.train.distributed_utils import DistributedOption
+from funasr.utils.build_dataclass import build_dataclass
+
+
+def build_distributed(args):
+    distributed_option = build_dataclass(DistributedOption, args)
+    if args.use_pai:
+        distributed_option.init_options_pai()
+        distributed_option.init_torch_distributed_pai(args)
+    elif not args.simple_ddp:
+        distributed_option.init_torch_distributed(args)
+    elif args.distributed and args.simple_ddp:
+        distributed_option.init_torch_distributed_pai(args)
+        args.ngpu = torch.distributed.get_world_size()
+
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+    if not distributed_option.distributed or distributed_option.dist_rank == 0:
+        logging.basicConfig(
+            level="INFO",
+            format=f"[{os.uname()[1].split('.')[0]}]"
+                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    else:
+        logging.basicConfig(
+            level="ERROR",
+            format=f"[{os.uname()[1].split('.')[0]}]"
+                   f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+        )
+    logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+                                                                   distributed_option.dist_rank,
+                                                                   distributed_option.local_rank))
+    return distributed_option
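+
+
+# Hedged call-site sketch: ``args`` must carry every DistributedOption field
+# (dist_rank, dist_world_size, local_rank, ...) plus the use_pai/simple_ddp/
+# distributed flags read above.
+#
+#     distributed_option = build_distributed(args)
+#     if not distributed_option.distributed or distributed_option.dist_rank == 0:
+#         ...  # rank-0-only work such as checkpointing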
diff --git a/funasr/build_utils/build_lm_model.py b/funasr/build_utils/build_lm_model.py
new file mode 100644
index 0000000..aaa4fb7
--- /dev/null
+++ b/funasr/build_utils/build_lm_model.py
@@ -0,0 +1,57 @@
+import logging
+
+from funasr.lm.abs_model import AbsLM
+from funasr.lm.abs_model import LanguageModel
+from funasr.lm.seq_rnn_lm import SequentialRNNLM
+from funasr.lm.transformer_lm import TransformerLM
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+lm_choices = ClassChoices(
+    "lm",
+    classes=dict(
+        seq_rnn=SequentialRNNLM,
+        transformer=TransformerLM,
+    ),
+    type_check=AbsLM,
+    default="seq_rnn",
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        lm=LanguageModel,
+    ),
+    default="lm",
+)
+
+class_choices_list = [
+    # --lm and --lm_conf
+    lm_choices,
+    # --model and --model_conf
+    model_choices
+]
+
+
+def build_lm_model(args):
+    # token_list
+    if args.token_list is not None:
+        with open(args.token_list) as f:
+            token_list = [line.rstrip() for line in f]
+        args.token_list = list(token_list)
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size}")
+    else:
+        vocab_size = None
+
+    # lm
+    lm_class = lm_choices.get_class(args.lm)
+    lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
+
+    model_class = model_choices.get_class(args.model)
+    model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
+
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model
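+
+
+if __name__ == "__main__":
+    # Hedged sketch: assumes the default seq_rnn/lm confs are valid as-is;
+    # real runs take these fields from the training config.
+    import tempfile
+    from argparse import Namespace
+
+    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
+        f.write("<blank>\n<unk>\nhello\nworld\n<sos/eos>\n")
+    args = Namespace(token_list=f.name, lm="seq_rnn", lm_conf={},
+                     model="lm", model_conf={}, init=None)
+    print(type(build_lm_model(args)).__name__)  # LanguageModel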
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
new file mode 100644
index 0000000..13a6faa
--- /dev/null
+++ b/funasr/build_utils/build_model.py
@@ -0,0 +1,25 @@
+from funasr.build_utils.build_asr_model import build_asr_model
+from funasr.build_utils.build_lm_model import build_lm_model
+from funasr.build_utils.build_pretrain_model import build_pretrain_model
+from funasr.build_utils.build_punc_model import build_punc_model
+from funasr.build_utils.build_vad_model import build_vad_model
+from funasr.build_utils.build_diar_model import build_diar_model
+
+
+def build_model(args):
+    if args.task_name == "asr":
+        model = build_asr_model(args)
+    elif args.task_name == "pretrain":
+        model = build_pretrain_model(args)
+    elif args.task_name == "lm":
+        model = build_lm_model(args)
+    elif args.task_name == "punc":
+        model = build_punc_model(args)
+    elif args.task_name == "vad":
+        model = build_vad_model(args)
+    elif args.task_name == "diar":
+        model = build_diar_model(args)
+    else:
+        raise NotImplementedError("Not supported task: {}".format(args.task_name))
+
+    return model
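+
+
+if __name__ == "__main__":
+    # Hedged sketch: exercises only the task_name dispatch; each branch
+    # forwards the full namespace to its task-specific builder.
+    from argparse import Namespace
+
+    try:
+        build_model(Namespace(task_name="tts"))
+    except NotImplementedError as e:
+        print(e)  # Not supported task: tts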
diff --git a/funasr/build_utils/build_optimizer.py b/funasr/build_utils/build_optimizer.py
new file mode 100644
index 0000000..bd0b73d
--- /dev/null
+++ b/funasr/build_utils/build_optimizer.py
@@ -0,0 +1,28 @@
+import torch
+
+from funasr.optimizers.fairseq_adam import FairseqAdam
+from funasr.optimizers.sgd import SGD
+
+
+def build_optimizer(args, model):
+    optim_classes = dict(
+        adam=torch.optim.Adam,
+        fairseq_adam=FairseqAdam,
+        adamw=torch.optim.AdamW,
+        sgd=SGD,
+        adadelta=torch.optim.Adadelta,
+        adagrad=torch.optim.Adagrad,
+        adamax=torch.optim.Adamax,
+        asgd=torch.optim.ASGD,
+        lbfgs=torch.optim.LBFGS,
+        rmsprop=torch.optim.RMSprop,
+        rprop=torch.optim.Rprop,
+    )
+
+    optim_class = optim_classes.get(args.optim)
+    if optim_class is None:
+        raise ValueError(f"optim must be one of {list(optim_classes)}, but got: {args.optim}")
+    optimizer = optim_class(model.parameters(), **args.optim_conf)
+
+    optimizers = [optimizer]
+    return optimizers
\ No newline at end of file
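+
+
+if __name__ == "__main__":
+    # Hedged sketch: a toy module stands in for a FunASR model; optim and
+    # optim_conf normally come from the training config.
+    from argparse import Namespace
+
+    toy = torch.nn.Linear(4, 2)
+    args = Namespace(optim="adam", optim_conf={"lr": 1e-3})
+    (optimizer,) = build_optimizer(args, toy)
+    print(type(optimizer).__name__)  # Adam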
diff --git a/funasr/build_utils/build_pretrain_model.py b/funasr/build_utils/build_pretrain_model.py
new file mode 100644
index 0000000..629937f
--- /dev/null
+++ b/funasr/build_utils/build_pretrain_model.py
@@ -0,0 +1,107 @@
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.data2vec import Data2VecPretrainModel
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.specaug import SpecAug
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
+    default="default",
+)
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes=dict(specaug=SpecAug),
+    default=None,
+    optional=True,
+)
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    default=None,
+    optional=True,
+)
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        data2vec_encoder=Data2VecEncoder,
+    ),
+    default="data2vec_encoder",
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        data2vec=Data2VecPretrainModel,
+    ),
+    default="data2vec",
+)
+class_choices_list = [
+    # --frontend and --frontend_conf
+    frontend_choices,
+    # --specaug and --specaug_conf
+    specaug_choices,
+    # --normalize and --normalize_conf
+    normalize_choices,
+    # --encoder and --encoder_conf
+    encoder_choices,
+    # --model and --model_conf
+    model_choices,
+]
+
+
+def build_pretrain_model(args):
+    # frontend
+    if args.input_size is None:
+        frontend_class = frontend_choices.get_class(args.frontend)
+        frontend = frontend_class(**args.frontend_conf)
+        input_size = frontend.output_size()
+    else:
+        args.frontend = None
+        args.frontend_conf = {}
+        frontend = None
+        input_size = args.input_size
+
+    # data augmentation for spectrogram
+    if args.specaug is not None:
+        specaug_class = specaug_choices.get_class(args.specaug)
+        specaug = specaug_class(**args.specaug_conf)
+    else:
+        specaug = None
+
+    # normalization layer
+    if args.normalize is not None:
+        normalize_class = normalize_choices.get_class(args.normalize)
+        normalize = normalize_class(**args.normalize_conf)
+    else:
+        normalize = None
+
+    # encoder
+    encoder_class = encoder_choices.get_class(args.encoder)
+    encoder = encoder_class(
+        input_size=input_size,
+        **args.encoder_conf,
+    )
+
+    if args.model == "data2vec":
+        model_class = model_choices.get_class("data2vec")
+        model = model_class(
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            encoder=encoder,
+        )
+    else:
+        raise NotImplementedError("Not supported model: {}".format(args.model))
+
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model
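+
+
+# Hedged usage sketch: field names mirror the data2vec pretraining recipes;
+# values are omitted here.
+#
+#     args.input_size = None        # derive the input size from the frontend
+#     args.frontend = "default"
+#     args.encoder = "data2vec_encoder"
+#     args.model = "data2vec"
+#     model = build_pretrain_model(args)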
diff --git a/funasr/build_utils/build_punc_model.py b/funasr/build_utils/build_punc_model.py
new file mode 100644
index 0000000..62ccaf2
--- /dev/null
+++ b/funasr/build_utils/build_punc_model.py
@@ -0,0 +1,68 @@
+import logging
+
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.torch_utils.initialize import initialize
+from funasr.train.abs_model import PunctuationModel
+from funasr.train.class_choices import ClassChoices
+
+punc_choices = ClassChoices(
+    "punctuation",
+    classes=dict(
+        target_delay=TargetDelayTransformer,
+        vad_realtime=VadRealtimeTransformer
+    ),
+    default="target_delay",
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        punc=PunctuationModel,
+    ),
+    default="punc",
+)
+class_choices_list = [
+    # --punc and --punc_conf
+    punc_choices,
+    # --model and --model_conf
+    model_choices
+]
+
+
+def build_punc_model(args):
+    # token_list and punc list
+    if isinstance(args.token_list, str):
+        with open(args.token_list, encoding="utf-8") as f:
+            token_list = [line.rstrip() for line in f]
+        args.token_list = token_list.copy()
+    if isinstance(args.punc_list, str):
+        with open(args.punc_list, encoding="utf-8") as f2:
+            pairs = [line.rstrip().split(":") for line in f2]
+        punc_list = [pair[0] for pair in pairs]
+        punc_weight_list = [float(pair[1]) for pair in pairs]
+        args.punc_list = punc_list.copy()
+    elif isinstance(args.punc_list, list):
+        punc_list = args.punc_list.copy()
+        punc_weight_list = [1.0] * len(punc_list)
+    else:
+        raise RuntimeError("punc_list must be str or list")
+    if isinstance(args.token_list, (tuple, list)):
+        token_list = args.token_list.copy()
+    else:
+        raise RuntimeError("token_list must be str or list")
+
+    vocab_size = len(token_list)
+    punc_size = len(punc_list)
+    logging.info(f"Vocabulary size: {vocab_size}")
+
+    # punc
+    punc_class = punc_choices.get_class(args.punctuation)
+    punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
+
+    if "punc_weight" in args.model_conf:
+        args.model_conf.pop("punc_weight")
+    model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
+
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model
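+
+
+# Hedged sketch of the expected punc_list file: one "<punc>:<loss-weight>"
+# pair per line, matching the split(":")/float() parsing above, e.g.:
+#
+#     _:1.0
+#     ,:1.0
+#     。:1.0
+#     ?:1.0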
diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py
new file mode 100644
index 0000000..4b9990e
--- /dev/null
+++ b/funasr/build_utils/build_scheduler.py
@@ -0,0 +1,44 @@
+import torch
+import torch.multiprocessing
+import torch.nn
+import torch.optim
+
+from funasr.schedulers.noam_lr import NoamLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
+
+
+def build_scheduler(args, optimizers):
+    scheduler_classes = dict(
+        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
+        lambdalr=torch.optim.lr_scheduler.LambdaLR,
+        steplr=torch.optim.lr_scheduler.StepLR,
+        multisteplr=torch.optim.lr_scheduler.MultiStepLR,
+        exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
+        CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
+        noamlr=NoamLR,
+        warmuplr=WarmupLR,
+        tri_stage=TriStageLR,
+        cycliclr=torch.optim.lr_scheduler.CyclicLR,
+        onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
+        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
+    )
+
+    schedulers = []
+    for i, optim in enumerate(optimizers, 1):
+        suf = "" if i == 1 else str(i)
+        name = getattr(args, f"scheduler{suf}")
+        conf = getattr(args, f"scheduler{suf}_conf")
+        if name is not None:
+            cls_ = scheduler_classes.get(name)
+            if cls_ is None:
+                raise ValueError(
+                    f"must be one of {list(scheduler_classes)}: {name}"
+                )
+            scheduler = cls_(optim, **conf)
+        else:
+            scheduler = None
+
+        schedulers.append(scheduler)
+
+    return schedulers
\ No newline at end of file
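+
+
+if __name__ == "__main__":
+    # Hedged sketch: with one optimizer only args.scheduler/scheduler_conf are
+    # read; a second optimizer would read args.scheduler2/scheduler2_conf.
+    from argparse import Namespace
+
+    toy = torch.nn.Linear(4, 2)
+    optimizers = [torch.optim.SGD(toy.parameters(), lr=0.1)]
+    args = Namespace(scheduler="steplr", scheduler_conf={"step_size": 10})
+    (scheduler,) = build_scheduler(args, optimizers)
+    print(type(scheduler).__name__)  # StepLR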
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
new file mode 100644
index 0000000..aff99b5
--- /dev/null
+++ b/funasr/build_utils/build_trainer.py
@@ -0,0 +1,820 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Trainer module."""
+import argparse
+import dataclasses
+import logging
+import os
+import time
+from contextlib import contextmanager
+from dataclasses import is_dataclass
+from distutils.version import LooseVersion
+from io import BytesIO
+from pathlib import Path
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+import humanfriendly
+import oss2
+import torch
+import torch.nn
+import torch.optim
+from typeguard import check_argument_types
+
+from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.main_funcs.average_nbest_models import average_nbest_models
+from funasr.models.base_model import FunASRModel
+from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
+from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
+from funasr.schedulers.abs_scheduler import AbsScheduler
+from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
+from funasr.torch_utils.add_gradient_noise import add_gradient_noise
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.recursive_op import recursive_average
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.train.distributed_utils import DistributedOption
+from funasr.train.reporter import Reporter
+from funasr.train.reporter import SubReporter
+from funasr.utils.build_dataclass import build_dataclass
+
+if torch.distributed.is_available():
+    from torch.distributed import ReduceOp
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+    from torch.cuda.amp import GradScaler
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+    GradScaler = None
+
+try:
+    import fairscale
+except ImportError:
+    fairscale = None
+
+
+@dataclasses.dataclass
+class TrainerOptions:
+    ngpu: int
+    resume: bool
+    use_amp: bool
+    train_dtype: str
+    grad_noise: bool
+    accum_grad: int
+    grad_clip: float
+    grad_clip_type: float
+    log_interval: Optional[int]
+    # no_forward_run: bool
+    use_tensorboard: bool
+    # use_wandb: bool
+    output_dir: Union[Path, str]
+    max_epoch: int
+    max_update: int
+    seed: int
+    # sharded_ddp: bool
+    patience: Optional[int]
+    keep_nbest_models: Union[int, List[int]]
+    nbest_averaging_interval: int
+    early_stopping_criterion: Sequence[str]
+    best_model_criterion: Sequence[Sequence[str]]
+    val_scheduler_criterion: Sequence[str]
+    unused_parameters: bool
+    # wandb_model_log_interval: int
+    use_pai: bool
+    oss_bucket: Union[oss2.Bucket, None]
+
+
+class Trainer:
+    """Trainer
+
+    """
+
+    def __init__(self,
+                 args,
+                 model: FunASRModel,
+                 optimizers: Sequence[torch.optim.Optimizer],
+                 schedulers: Sequence[Optional[AbsScheduler]],
+                 train_dataloader: AbsIterFactory,
+                 valid_dataloader: AbsIterFactory,
+                 distributed_option: DistributedOption):
+        self.trainer_options = self.build_options(args)
+        self.model = model
+        self.optimizers = optimizers
+        self.schedulers = schedulers
+        self.train_dataloader = train_dataloader
+        self.valid_dataloader = valid_dataloader
+        self.distributed_option = distributed_option
+
+    def build_options(self, args: argparse.Namespace) -> TrainerOptions:
+        """Build options consumed by train(), eval()"""
+        assert check_argument_types()
+        return build_dataclass(TrainerOptions, args)
+
+    @classmethod
+    def add_arguments(cls, parser: argparse.ArgumentParser):
+        """Reserved for future development of another Trainer"""
+        pass
+
+    def resume(self,
+               checkpoint: Union[str, Path],
+               model: torch.nn.Module,
+               reporter: Reporter,
+               optimizers: Sequence[torch.optim.Optimizer],
+               schedulers: Sequence[Optional[AbsScheduler]],
+               scaler: Optional[GradScaler],
+               ngpu: int = 0,
+               ):
+        states = torch.load(
+            checkpoint,
+            map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
+        )
+        model.load_state_dict(states["model"])
+        reporter.load_state_dict(states["reporter"])
+        for optimizer, state in zip(optimizers, states["optimizers"]):
+            optimizer.load_state_dict(state)
+        for scheduler, state in zip(schedulers, states["schedulers"]):
+            if scheduler is not None:
+                scheduler.load_state_dict(state)
+        if scaler is not None:
+            if states["scaler"] is None:
+                logging.warning("scaler state is not found")
+            else:
+                scaler.load_state_dict(states["scaler"])
+
+        logging.info(f"The training was resumed using {checkpoint}")
+
+    def run(self) -> None:
+        """Perform training. This method performs the main process of training."""
+        assert check_argument_types()
+        # NOTE(kamo): Don't check the argument types more strictly here; trainer_options is validated when it is built.
+        model = self.model
+        optimizers = self.optimizers
+        schedulers = self.schedulers
+        train_dataloader = self.train_dataloader
+        valid_dataloader = self.valid_dataloader
+        trainer_options = self.trainer_options
+        distributed_option = self.distributed_option
+        assert is_dataclass(trainer_options), type(trainer_options)
+        assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
+
+        if isinstance(trainer_options.keep_nbest_models, int):
+            keep_nbest_models = [trainer_options.keep_nbest_models]
+        else:
+            if len(trainer_options.keep_nbest_models) == 0:
+                logging.warning("No keep_nbest_models is given. Change to [1]")
+                trainer_options.keep_nbest_models = [1]
+            keep_nbest_models = trainer_options.keep_nbest_models
+
+        output_dir = Path(trainer_options.output_dir)
+        reporter = Reporter()
+        if trainer_options.use_amp:
+            if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
+                raise RuntimeError(
+                    "Require torch>=1.6.0 for  Automatic Mixed Precision"
+                )
+            # if trainer_options.sharded_ddp:
+            #     if fairscale is None:
+            #         raise RuntimeError(
+            #             "Requiring fairscale. Do 'pip install fairscale'"
+            #         )
+            #     scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+            # else:
+            scaler = GradScaler()
+        else:
+            scaler = None
+
+        if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
+            self.resume(
+                checkpoint=output_dir / "checkpoint.pb",
+                model=model,
+                optimizers=optimizers,
+                schedulers=schedulers,
+                reporter=reporter,
+                scaler=scaler,
+                ngpu=trainer_options.ngpu,
+            )
+
+        start_epoch = reporter.get_epoch() + 1
+        if start_epoch == trainer_options.max_epoch + 1:
+            logging.warning(
+                f"The training has already reached at max_epoch: {start_epoch}"
+            )
+
+        if distributed_option.distributed:
+            dp_model = torch.nn.parallel.DistributedDataParallel(
+                model, find_unused_parameters=trainer_options.unused_parameters)
+        elif distributed_option.ngpu > 1:
+            dp_model = torch.nn.parallel.DataParallel(
+                model,
+                device_ids=list(range(distributed_option.ngpu)),
+            )
+        else:
+            # NOTE(kamo): DataParallel should also work with ngpu=1,
+            # but for debuggability it's better to keep this block.
+            dp_model = model
+
+        if trainer_options.use_tensorboard and (
+                not distributed_option.distributed or distributed_option.dist_rank == 0
+        ):
+            from torch.utils.tensorboard import SummaryWriter
+            if trainer_options.use_pai:
+                train_summary_writer = SummaryWriter(
+                    os.path.join(trainer_options.output_dir, "tensorboard/train")
+                )
+                valid_summary_writer = SummaryWriter(
+                    os.path.join(trainer_options.output_dir, "tensorboard/valid")
+                )
+            else:
+                train_summary_writer = SummaryWriter(
+                    str(output_dir / "tensorboard" / "train")
+                )
+                valid_summary_writer = SummaryWriter(
+                    str(output_dir / "tensorboard" / "valid")
+                )
+        else:
+            train_summary_writer = None
+
+        start_time = time.perf_counter()
+        for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
+            if iepoch != start_epoch:
+                logging.info(
+                    "{}/{}epoch started. Estimated time to finish: {}".format(
+                        iepoch,
+                        trainer_options.max_epoch,
+                        humanfriendly.format_timespan(
+                            (time.perf_counter() - start_time)
+                            / (iepoch - start_epoch)
+                            * (trainer_options.max_epoch - iepoch + 1)
+                        ),
+                    )
+                )
+            else:
+                logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
+            set_all_random_seed(trainer_options.seed + iepoch)
+
+            reporter.set_epoch(iepoch)
+            # 1. Train and validation for one-epoch
+            with reporter.observe("train") as sub_reporter:
+                all_steps_are_invalid, max_update_stop = self.train_one_epoch(
+                    model=dp_model,
+                    optimizers=optimizers,
+                    schedulers=schedulers,
+                    iterator=train_dataloader.build_iter(iepoch),
+                    reporter=sub_reporter,
+                    scaler=scaler,
+                    summary_writer=train_summary_writer,
+                    options=trainer_options,
+                    distributed_option=distributed_option,
+                )
+
+            with reporter.observe("valid") as sub_reporter:
+                self.validate_one_epoch(
+                    model=dp_model,
+                    iterator=valid_dataloader.build_iter(iepoch),
+                    reporter=sub_reporter,
+                    options=trainer_options,
+                    distributed_option=distributed_option,
+                )
+
+            # 2. LR Scheduler step
+            for scheduler in schedulers:
+                if isinstance(scheduler, AbsValEpochStepScheduler):
+                    scheduler.step(
+                        reporter.get_value(*trainer_options.val_scheduler_criterion)
+                    )
+                elif isinstance(scheduler, AbsEpochStepScheduler):
+                    scheduler.step()
+            # if trainer_options.sharded_ddp:
+            #     for optimizer in optimizers:
+            #         if isinstance(optimizer, fairscale.optim.oss.OSS):
+            #             optimizer.consolidate_state_dict()
+
+            if not distributed_option.distributed or distributed_option.dist_rank == 0:
+                # 3. Report the results
+                logging.info(reporter.log_message())
+                if train_summary_writer is not None:
+                    reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
+                    reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
+                # if trainer_options.use_wandb:
+                #     reporter.wandb_log()
+
+                # save tensorboard on oss
+                if trainer_options.use_pai and train_summary_writer is not None:
+                    def write_tensorboard_summary(summary_writer_path, oss_bucket):
+                        file_list = []
+                        for root, dirs, files in os.walk(summary_writer_path, topdown=False):
+                            for name in files:
+                                file_full_path = os.path.join(root, name)
+                                file_list.append(file_full_path)
+
+                        for file_full_path in file_list:
+                            with open(file_full_path, "rb") as f:
+                                oss_bucket.put_object(file_full_path, f)
+
+                    write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/train"),
+                                              trainer_options.oss_bucket)
+                    write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/valid"),
+                                              trainer_options.oss_bucket)
+
+                # 4. Save/Update the checkpoint
+                if trainer_options.use_pai:
+                    buffer = BytesIO()
+                    torch.save(
+                        {
+                            "model": model.state_dict(),
+                            "reporter": reporter.state_dict(),
+                            "optimizers": [o.state_dict() for o in optimizers],
+                            "schedulers": [
+                                s.state_dict() if s is not None else None
+                                for s in schedulers
+                            ],
+                            "scaler": scaler.state_dict() if scaler is not None else None,
+                            "ema_model": model.encoder.ema.model.state_dict()
+                            if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None,
+                        },
+                        buffer,
+                    )
+                    trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"),
+                                                          buffer.getvalue())
+                else:
+                    torch.save(
+                        {
+                            "model": model.state_dict(),
+                            "reporter": reporter.state_dict(),
+                            "optimizers": [o.state_dict() for o in optimizers],
+                            "schedulers": [
+                                s.state_dict() if s is not None else None
+                                for s in schedulers
+                            ],
+                            "scaler": scaler.state_dict() if scaler is not None else None,
+                        },
+                        output_dir / "checkpoint.pb",
+                    )
+
+                # 5. Save and log the model and update the link to the best model
+                if trainer_options.use_pai:
+                    buffer = BytesIO()
+                    torch.save(model.state_dict(), buffer)
+                    trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
+                                                                       f"{iepoch}epoch.pb"), buffer.getvalue())
+                else:
+                    torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
+
+                # Creates a sym link latest.pb -> {iepoch}epoch.pb
+                if trainer_options.use_pai:
+                    p = os.path.join(trainer_options.output_dir, "latest.pb")
+                    if trainer_options.oss_bucket.object_exists(p):
+                        trainer_options.oss_bucket.delete_object(p)
+                    trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
+                                                           os.path.join(trainer_options.output_dir,
+                                                                        f"{iepoch}epoch.pb"), p)
+                else:
+                    p = output_dir / "latest.pb"
+                    if p.is_symlink() or p.exists():
+                        p.unlink()
+                    p.symlink_to(f"{iepoch}epoch.pb")
+
+                _improved = []
+                for _phase, k, _mode in trainer_options.best_model_criterion:
+                    # e.g. _phase, k, _mode = "train", "loss", "min"
+                    if reporter.has(_phase, k):
+                        best_epoch = reporter.get_best_epoch(_phase, k, _mode)
+                        # Creates sym links if it's the best result
+                        if best_epoch == iepoch:
+                            if trainer_options.use_pai:
+                                p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
+                                if trainer_options.oss_bucket.object_exists(p):
+                                    trainer_options.oss_bucket.delete_object(p)
+                                trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
+                                                                       os.path.join(trainer_options.output_dir,
+                                                                                    f"{iepoch}epoch.pb"), p)
+                            else:
+                                p = output_dir / f"{_phase}.{k}.best.pb"
+                                if p.is_symlink() or p.exists():
+                                    p.unlink()
+                                p.symlink_to(f"{iepoch}epoch.pb")
+                            _improved.append(f"{_phase}.{k}")
+                if len(_improved) == 0:
+                    logging.info("There are no improvements in this epoch")
+                else:
+                    logging.info(
+                        "The best model has been updated: " + ", ".join(_improved)
+                    )
+
+                # log_model = (
+                #         trainer_options.wandb_model_log_interval > 0
+                #         and iepoch % trainer_options.wandb_model_log_interval == 0
+                # )
+                # if log_model and trainer_options.use_wandb:
+                #     import wandb
+                #
+                #     logging.info("Logging Model on this epoch :::::")
+                #     artifact = wandb.Artifact(
+                #         name=f"model_{wandb.run.id}",
+                #         type="model",
+                #         metadata={"improved": _improved},
+                #     )
+                #     artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+                #     aliases = [
+                #         f"epoch-{iepoch}",
+                #         "best" if best_epoch == iepoch else "",
+                #     ]
+                #     wandb.log_artifact(artifact, aliases=aliases)
+
+                # 6. Remove model files other than the n-best epochs and the latest epoch
+                _removed = []
+                # Get the union set of the n-best among multiple criterion
+                nbests = set().union(
+                    *[
+                        set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
+                        for ph, k, m in trainer_options.best_model_criterion
+                        if reporter.has(ph, k)
+                    ]
+                )
+
+                # Generate the n-best averaged model
+                if (
+                        trainer_options.nbest_averaging_interval > 0
+                        and iepoch % trainer_options.nbest_averaging_interval == 0
+                ):
+                    average_nbest_models(
+                        reporter=reporter,
+                        output_dir=output_dir,
+                        best_model_criterion=trainer_options.best_model_criterion,
+                        nbest=keep_nbest_models,
+                        suffix=f"till{iepoch}epoch",
+                        oss_bucket=trainer_options.oss_bucket,
+                        pai_output_dir=trainer_options.output_dir,
+                    )
+
+                for e in range(1, iepoch):
+                    if trainer_options.use_pai:
+                        p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
+                        if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
+                            trainer_options.oss_bucket.delete_object(p)
+                            _removed.append(str(p))
+                    else:
+                        p = output_dir / f"{e}epoch.pb"
+                        if p.exists() and e not in nbests:
+                            p.unlink()
+                            _removed.append(str(p))
+                if len(_removed) != 0:
+                    logging.info("The model files were removed: " + ", ".join(_removed))
+
+            # 7. If any updating haven't happened, stops the training
+            if all_steps_are_invalid:
+                logging.warning(
+                    f"The gradients at all steps are invalid in this epoch. "
+                    f"Something seems wrong. This training was stopped at {iepoch}epoch"
+                )
+                break
+
+            if max_update_stop:
+                logging.info(
+                    f"Stopping training due to "
+                    f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
+                )
+                break
+
+            # 8. Check early stopping
+            if trainer_options.patience is not None:
+                if reporter.check_early_stopping(
+                        trainer_options.patience, *trainer_options.early_stopping_criterion
+                ):
+                    break
+
+        else:
+            logging.info(
+                f"The training was finished at {trainer_options.max_epoch} epochs "
+            )
+
+        # Generate the final n-best averaged model
+        if not distributed_option.distributed or distributed_option.dist_rank == 0:
+            average_nbest_models(
+                reporter=reporter,
+                output_dir=output_dir,
+                best_model_criterion=trainer_options.best_model_criterion,
+                nbest=keep_nbest_models,
+                oss_bucket=trainer_options.oss_bucket,
+                pai_output_dir=trainer_options.output_dir,
+            )
+
+    def train_one_epoch(
+            self,
+            model: torch.nn.Module,
+            iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
+            optimizers: Sequence[torch.optim.Optimizer],
+            schedulers: Sequence[Optional[AbsScheduler]],
+            scaler: Optional[GradScaler],
+            reporter: SubReporter,
+            summary_writer,
+            options: TrainerOptions,
+            distributed_option: DistributedOption,
+    ) -> Tuple[bool, bool]:
+        assert check_argument_types()
+
+        grad_noise = options.grad_noise
+        accum_grad = options.accum_grad
+        grad_clip = options.grad_clip
+        grad_clip_type = options.grad_clip_type
+        log_interval = options.log_interval
+        # no_forward_run = options.no_forward_run
+        ngpu = options.ngpu
+        # use_wandb = options.use_wandb
+        distributed = distributed_option.distributed
+
+        if log_interval is None:
+            try:
+                log_interval = max(len(iterator) // 20, 10)
+            except TypeError:
+                log_interval = 100
+
+        model.train()
+        all_steps_are_invalid = True
+        max_update_stop = False
+        # [For distributed] Because iteration counts are not always equal between
+        # processes, send a stop flag to the other processes when the iterator is exhausted
+        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
+
+        start_time = time.perf_counter()
+        for iiter, (_, batch) in enumerate(
+                reporter.measure_iter_time(iterator, "iter_time"), 1
+        ):
+            assert isinstance(batch, dict), type(batch)
+
+            if distributed:
+                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
+
+            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
+            # if no_forward_run:
+            #     all_steps_are_invalid = False
+            #     continue
+
+            with autocast(scaler is not None):
+                with reporter.measure_time("forward_time"):
+                    retval = model(**batch)
+
+                    # Note(kamo):
+                    # Supporting two patterns for the returned value from the model
+                    #   a. dict type
+                    if isinstance(retval, dict):
+                        loss = retval["loss"]
+                        stats = retval["stats"]
+                        weight = retval["weight"]
+                        optim_idx = retval.get("optim_idx")
+                        if optim_idx is not None and not isinstance(optim_idx, int):
+                            if not isinstance(optim_idx, torch.Tensor):
+                                raise RuntimeError(
+                                    "optim_idx must be int or 1dim torch.Tensor, "
+                                    f"but got {type(optim_idx)}"
+                                )
+                            if optim_idx.dim() >= 2:
+                                raise RuntimeError(
+                                    "optim_idx must be int or 1dim torch.Tensor, "
+                                    f"but got {optim_idx.dim()}dim tensor"
+                                )
+                            if optim_idx.dim() == 1:
+                                for v in optim_idx:
+                                    if v != optim_idx[0]:
+                                        raise RuntimeError(
+                                            "optim_idx must be 1dim tensor "
+                                            "having same values for all entries"
+                                        )
+                                optim_idx = optim_idx[0].item()
+                            else:
+                                optim_idx = optim_idx.item()
+
+                    #   b. tuple or list type
+                    else:
+                        loss, stats, weight = retval
+                        optim_idx = None
+
+                stats = {k: v for k, v in stats.items() if v is not None}
+                if ngpu > 1 or distributed:
+                    # Apply weighted averaging for loss and stats
+                    loss = (loss * weight.type(loss.dtype)).sum()
+
+                    # if distributed, this method can also apply all_reduce()
+                    stats, weight = recursive_average(stats, weight, distributed)
+
+                    # Now weight is summation over all workers
+                    loss /= weight
+                if distributed:
+                    # NOTE(kamo): Multiply world_size because DistributedDataParallel
+                    # automatically normalizes the gradient by world_size.
+                    loss *= torch.distributed.get_world_size()
+
+                loss /= accum_grad
+
+            reporter.register(stats, weight)
+
+            with reporter.measure_time("backward_time"):
+                if scaler is not None:
+                    # Scales loss.  Calls backward() on scaled loss
+                    # to create scaled gradients.
+                    # Backward passes under autocast are not recommended.
+                    # Backward ops run in the same dtype autocast chose
+                    # for corresponding forward ops.
+                    scaler.scale(loss).backward()
+                else:
+                    loss.backward()
+
+            if iiter % accum_grad == 0:
+                if scaler is not None:
+                    # Unscales the gradients of optimizer's assigned params in-place
+                    for iopt, optimizer in enumerate(optimizers):
+                        if optim_idx is not None and iopt != optim_idx:
+                            continue
+                        scaler.unscale_(optimizer)
+
+                # gradient noise injection
+                if grad_noise:
+                    add_gradient_noise(
+                        model,
+                        reporter.get_total_count(),
+                        duration=100,
+                        eta=1.0,
+                        scale_factor=0.55,
+                    )
+
+                # compute the gradient norm to check if it is normal or not
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    model.parameters(),
+                    max_norm=grad_clip,
+                    norm_type=grad_clip_type,
+                )
+                # In PyTorch<=1.4, clip_grad_norm_ returns a float value
+                if not isinstance(grad_norm, torch.Tensor):
+                    grad_norm = torch.tensor(grad_norm)
+
+                if not torch.isfinite(grad_norm):
+                    logging.warning(
+                        f"The grad norm is {grad_norm}. Skipping updating the model."
+                    )
+
+                    # Must invoke scaler.update() if unscale_() is used in the iteration
+                    # to avoid the following error:
+                    #   RuntimeError: unscale_() has already been called
+                    #   on this optimizer since the last update().
+                    # Note that if the gradient has inf/nan values,
+                    # scaler.step skips optimizer.step().
+                    if scaler is not None:
+                        for iopt, optimizer in enumerate(optimizers):
+                            if optim_idx is not None and iopt != optim_idx:
+                                continue
+                            scaler.step(optimizer)
+                            scaler.update()
+
+                else:
+                    all_steps_are_invalid = False
+                    with reporter.measure_time("optim_step_time"):
+                        for iopt, (optimizer, scheduler) in enumerate(
+                                zip(optimizers, schedulers)
+                        ):
+                            if optim_idx is not None and iopt != optim_idx:
+                                continue
+                            if scaler is not None:
+                                # scaler.step() first unscales the gradients of
+                                # the optimizer's assigned params.
+                                scaler.step(optimizer)
+                                # Updates the scale for next iteration.
+                                scaler.update()
+                            else:
+                                optimizer.step()
+                            if isinstance(scheduler, AbsBatchStepScheduler):
+                                scheduler.step()
+                for iopt, optimizer in enumerate(optimizers):
+                    if optim_idx is not None and iopt != optim_idx:
+                        continue
+                    optimizer.zero_grad()
+
+                # Register lr and train/load time[sec/step],
+                # where step refers to accum_grad * mini-batch
+                reporter.register(
+                    dict(
+                        {
+                            f"optim{i}_lr{j}": pg["lr"]
+                            for i, optimizer in enumerate(optimizers)
+                            for j, pg in enumerate(optimizer.param_groups)
+                            if "lr" in pg
+                        },
+                        train_time=time.perf_counter() - start_time,
+                    ),
+                )
+                start_time = time.perf_counter()
+
+                # update num_updates
+                if distributed:
+                    if hasattr(model.module, "num_updates"):
+                        model.module.set_num_updates(model.module.get_num_updates() + 1)
+                        options.num_updates = model.module.get_num_updates()
+                        if model.module.get_num_updates() >= options.max_update:
+                            max_update_stop = True
+                else:
+                    if hasattr(model, "num_updates"):
+                        model.set_num_updates(model.get_num_updates() + 1)
+                        options.num_updates = model.get_num_updates()
+                        if model.get_num_updates() >= options.max_update:
+                            max_update_stop = True
+
+            # NOTE(kamo): Call log_message() after next()
+            reporter.next()
+            if iiter % log_interval == 0:
+                num_updates = options.num_updates if hasattr(options, "num_updates") else None
+                logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
+                if summary_writer is not None:
+                    reporter.tensorboard_add_scalar(summary_writer, -log_interval)
+                # if use_wandb:
+                #     reporter.wandb_log()
+
+            if max_update_stop:
+                break
+
+        else:
+            if distributed:
+                iterator_stop.fill_(1)
+                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+        return all_steps_are_invalid, max_update_stop
+
+    @torch.no_grad()
+    def validate_one_epoch(
+            self,
+            model: torch.nn.Module,
+            iterator: Iterable[Dict[str, torch.Tensor]],
+            reporter: SubReporter,
+            options: TrainerOptions,
+            distributed_option: DistributedOption,
+    ) -> None:
+        assert check_argument_types()
+        ngpu = options.ngpu
+        # no_forward_run = options.no_forward_run
+        distributed = distributed_option.distributed
+
+        model.eval()
+
+        # [For distributed] Because iteration counts are not always equal between
+        # processes, send a stop flag to the other processes when the iterator is exhausted
+        iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
+        for (_, batch) in iterator:
+            assert isinstance(batch, dict), type(batch)
+            if distributed:
+                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+                if iterator_stop > 0:
+                    break
+
+            batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
+            # if no_forward_run:
+            #     continue
+
+            retval = model(**batch)
+            if isinstance(retval, dict):
+                stats = retval["stats"]
+                weight = retval["weight"]
+            else:
+                _, stats, weight = retval
+            if ngpu > 1 or distributed:
+                # Apply weighted averaging for stats.
+                # if distributed, this method can also apply all_reduce()
+                stats, weight = recursive_average(stats, weight, distributed)
+
+            reporter.register(stats, weight)
+            reporter.next()
+
+        else:
+            if distributed:
+                iterator_stop.fill_(1)
+                torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+
+
+def build_trainer(
+        args,
+        model: FunASRModel,
+        optimizers: Sequence[torch.optim.Optimizer],
+        schedulers: Sequence[Optional[AbsScheduler]],
+        train_dataloader: AbsIterFactory,
+        valid_dataloader: AbsIterFactory,
+        distributed_option: DistributedOption
+):
+    trainer = Trainer(
+        args=args,
+        model=model,
+        optimizers=optimizers,
+        schedulers=schedulers,
+        train_dataloader=train_dataloader,
+        valid_dataloader=valid_dataloader,
+        distributed_option=distributed_option
+    )
+    return trainer
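+
+
+# A minimal wiring sketch (illustrative only): `args` must already carry the
+# trainer options parsed elsewhere in this patch, the dataloaders and
+# distributed_option come from the corresponding build_* helpers, and the
+# optimizer/scheduler choices below are assumptions, not the repo's defaults.
+#
+#     trainer = build_trainer(
+#         args=args,
+#         model=model,
+#         optimizers=[torch.optim.Adam(model.parameters(), lr=1e-3)],
+#         schedulers=[None],
+#         train_dataloader=train_iter_factory,
+#         valid_dataloader=valid_iter_factory,
+#         distributed_option=distributed_option,
+#     )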
diff --git a/funasr/build_utils/build_vad_model.py b/funasr/build_utils/build_vad_model.py
new file mode 100644
index 0000000..76eb09b
--- /dev/null
+++ b/funasr/build_utils/build_vad_model.py
@@ -0,0 +1,77 @@
+import torch
+
+from funasr.models.e2e_vad import E2EVadModel
+from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(
+        default=DefaultFrontend,
+        sliding_window=SlidingWindow,
+        s3prl=S3prlFrontend,
+        fused=FusedFrontends,
+        wav_frontend=WavFrontend,
+        wav_frontend_online=WavFrontendOnline,
+    ),
+    default="default",
+)
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        fsmn=FSMN,
+    ),
+    type_check=torch.nn.Module,
+    default="fsmn",
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        e2evad=E2EVadModel,
+    ),
+    default="e2evad",
+)
+
+class_choices_list = [
+    # --frontend and --frontend_conf
+    frontend_choices,
+    # --encoder and --encoder_conf
+    encoder_choices,
+    # --model and --model_conf
+    model_choices,
+]
+
+
+def build_vad_model(args):
+    # frontend
+    if args.input_size is None:
+        frontend_class = frontend_choices.get_class(args.frontend)
+        if args.frontend == 'wav_frontend':
+            frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+        else:
+            frontend = frontend_class(**args.frontend_conf)
+        input_size = frontend.output_size()
+    else:
+        args.frontend = None
+        args.frontend_conf = {}
+        frontend = None
+        input_size = args.input_size
+
+    # encoder
+    encoder_class = encoder_choices.get_class(args.encoder)
+    encoder = encoder_class(**args.encoder_conf)
+
+    model_class = model_choices.get_class(args.model)
+    model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
+
+    # initialize
+    if args.init is not None:
+        initialize(model, args.init)
+
+    return model
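+
+
+# A minimal usage sketch: the attribute names mirror exactly what
+# build_vad_model() reads above, while the concrete values are illustrative
+# assumptions rather than a validated VAD configuration.
+#
+#     from argparse import Namespace
+#     args = Namespace(
+#         input_size=None,
+#         frontend="wav_frontend", frontend_conf={}, cmvn_file="am.mvn",
+#         encoder="fsmn", encoder_conf={},
+#         model="e2evad", vad_post_conf={},
+#         init=None,
+#     )
+#     model = build_vad_model(args)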
diff --git a/funasr/datasets/iterable_dataset_modelscope.py b/funasr/datasets/iterable_dataset_modelscope.py
deleted file mode 100644
index 860492c..0000000
--- a/funasr/datasets/iterable_dataset_modelscope.py
+++ /dev/null
@@ -1,349 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-# Part of the implementation is borrowed from espnet/espnet.
-"""Iterable dataset module."""
-import copy
-from io import StringIO
-from pathlib import Path
-from typing import Callable, Collection, Dict, Iterator, Tuple, Union
-
-import kaldiio
-import numpy as np
-import soundfile
-import torch
-from funasr.datasets.dataset import ESPnetDataset
-from torch.utils.data.dataset import IterableDataset
-from typeguard import check_argument_types
-
-from funasr.utils import wav_utils
-
-
-def load_kaldi(input):
-    retval = kaldiio.load_mat(input)
-    if isinstance(retval, tuple):
-        assert len(retval) == 2, len(retval)
-        if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
-            # sound scp case
-            rate, array = retval
-        elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
-            # Extended ark format case
-            array, rate = retval
-        else:
-            raise RuntimeError(
-                f'Unexpected type: {type(retval[0])}, {type(retval[1])}')
-
-        # Multichannel wave fie
-        # array: (NSample, Channel) or (Nsample)
-
-    else:
-        # Normal ark case
-        assert isinstance(retval, np.ndarray), type(retval)
-        array = retval
-    return array
-
-
-DATA_TYPES = {
-    'sound':
-    lambda x: soundfile.read(x)[0],
-    'kaldi_ark':
-    load_kaldi,
-    'npy':
-    np.load,
-    'text_int':
-    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '),
-    'csv_int':
-    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','),
-    'text_float':
-    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '
-                         ),
-    'csv_float':
-    lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','
-                         ),
-    'text':
-    lambda x: x,
-}
-
-
-class IterableESPnetDatasetModelScope(IterableDataset):
-    """Pytorch Dataset class for ESPNet.
-
-    Examples:
-        >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
-        ...                                  ('token_int', 'output', 'text_int')],
-        ...                                )
-        >>> for uid, data in dataset:
-        ...     data
-        {'input': per_utt_array, 'output': per_utt_array}
-    """
-    def __init__(self,
-                 path_name_type_list: Collection[Tuple[any, str, str]],
-                 preprocess: Callable[[str, Dict[str, np.ndarray]],
-                                      Dict[str, np.ndarray]] = None,
-                 float_dtype: str = 'float32',
-                 int_dtype: str = 'long',
-                 key_file: str = None,
-                 sample_rate: Union[dict, int] = 16000):
-        assert check_argument_types()
-        if len(path_name_type_list) == 0:
-            raise ValueError(
-                '1 or more elements are required for "path_name_type_list"')
-
-        self.preprocess = preprocess
-
-        self.float_dtype = float_dtype
-        self.int_dtype = int_dtype
-        self.key_file = key_file
-        self.sample_rate = sample_rate
-
-        self.debug_info = {}
-        non_iterable_list = []
-        self.path_name_type_list = []
-
-        path_list = path_name_type_list[0]
-        name = path_name_type_list[1]
-        _type = path_name_type_list[2]
-        if name in self.debug_info:
-            raise RuntimeError(f'"{name}" is duplicated for data-key')
-        self.debug_info[name] = path_list, _type
-        #        for path, name, _type in path_name_type_list:
-        for path in path_list:
-            self.path_name_type_list.append((path, name, _type))
-
-        if len(non_iterable_list) != 0:
-            # Some types doesn't support iterable mode
-            self.non_iterable_dataset = ESPnetDataset(
-                path_name_type_list=non_iterable_list,
-                preprocess=preprocess,
-                float_dtype=float_dtype,
-                int_dtype=int_dtype,
-            )
-        else:
-            self.non_iterable_dataset = None
-
-        self.apply_utt2category = False
-
-    def has_name(self, name) -> bool:
-        return name in self.debug_info
-
-    def names(self) -> Tuple[str, ...]:
-        return tuple(self.debug_info)
-
-    def __repr__(self):
-        _mes = self.__class__.__name__
-        _mes += '('
-        for name, (path, _type) in self.debug_info.items():
-            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
-        _mes += f'\n  preprocess: {self.preprocess})'
-        return _mes
-
-    def __iter__(
-            self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
-        torch.set_printoptions(profile='default')
-        count = len(self.path_name_type_list)
-        for idx in range(count):
-            # 2. Load the entry from each line and create a dict
-            data = {}
-            # 2.a. Load data streamingly
-
-            # value:  /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
-            value = self.path_name_type_list[idx][0]['file']
-            uid = self.path_name_type_list[idx][0]['key']
-            # name:  speech
-            name = self.path_name_type_list[idx][1]
-            _type = self.path_name_type_list[idx][2]
-            func = DATA_TYPES[_type]
-            array = func(value)
-
-            # 2.b. audio resample
-            if _type == 'sound':
-                audio_sr: int = 16000
-                model_sr: int = 16000
-                if isinstance(self.sample_rate, int):
-                    model_sr = self.sample_rate
-                else:
-                    if 'audio_sr' in self.sample_rate:
-                        audio_sr = self.sample_rate['audio_sr']
-                    if 'model_sr' in self.sample_rate:
-                        model_sr = self.sample_rate['model_sr']
-                array = wav_utils.torch_resample(array, audio_sr, model_sr)
-
-            # array:  [ 1.25122070e-03  ... ]
-            data[name] = array
-
-            # 3. [Option] Apply preprocessing
-            #   e.g. espnet2.train.preprocessor:CommonPreprocessor
-            if self.preprocess is not None:
-                data = self.preprocess(uid, data)
-                # data:  {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
-
-            # 4. Force data-precision
-            for name in data:
-                # value is np.ndarray data
-                value = data[name]
-                if not isinstance(value, np.ndarray):
-                    raise RuntimeError(
-                        f'All values must be converted to np.ndarray object '
-                        f'by preprocessing, but "{name}" is still {type(value)}.'
-                    )
-
-                # Cast to desired type
-                if value.dtype.kind == 'f':
-                    value = value.astype(self.float_dtype)
-                elif value.dtype.kind == 'i':
-                    value = value.astype(self.int_dtype)
-                else:
-                    raise NotImplementedError(
-                        f'Not supported dtype: {value.dtype}')
-                data[name] = value
-
-            yield uid, data
-
-        if count == 0:
-            raise RuntimeError('No iteration')
-
-
-class IterableESPnetBytesModelScope(IterableDataset):
-    """Pytorch audio bytes class for ESPNet.
-
-    Examples:
-        >>> dataset = IterableESPnetBytes([('audio bytes', 'input', 'sound'),
-        ...                                ('token_int', 'output', 'text_int')],
-        ...                                )
-        >>> for uid, data in dataset:
-        ...     data
-        {'input': per_utt_array, 'output': per_utt_array}
-    """
-    def __init__(self,
-                 path_name_type_list: Collection[Tuple[any, str, str]],
-                 preprocess: Callable[[str, Dict[str, np.ndarray]],
-                                      Dict[str, np.ndarray]] = None,
-                 float_dtype: str = 'float32',
-                 int_dtype: str = 'long',
-                 key_file: str = None,
-                 sample_rate: Union[dict, int] = 16000):
-        assert check_argument_types()
-        if len(path_name_type_list) == 0:
-            raise ValueError(
-                '1 or more elements are required for "path_name_type_list"')
-
-        self.preprocess = preprocess
-
-        self.float_dtype = float_dtype
-        self.int_dtype = int_dtype
-        self.key_file = key_file
-        self.sample_rate = sample_rate
-
-        self.debug_info = {}
-        non_iterable_list = []
-        self.path_name_type_list = []
-
-        audio_data = path_name_type_list[0]
-        name = path_name_type_list[1]
-        _type = path_name_type_list[2]
-        if name in self.debug_info:
-            raise RuntimeError(f'"{name}" is duplicated for data-key')
-        self.debug_info[name] = audio_data, _type
-        self.path_name_type_list.append((audio_data, name, _type))
-
-        if len(non_iterable_list) != 0:
-            # Some types doesn't support iterable mode
-            self.non_iterable_dataset = ESPnetDataset(
-                path_name_type_list=non_iterable_list,
-                preprocess=preprocess,
-                float_dtype=float_dtype,
-                int_dtype=int_dtype,
-            )
-        else:
-            self.non_iterable_dataset = None
-
-        self.apply_utt2category = False
-
-        if float_dtype == 'float32':
-            self.np_dtype = np.float32
-
-    def has_name(self, name) -> bool:
-        return name in self.debug_info
-
-    def names(self) -> Tuple[str, ...]:
-        return tuple(self.debug_info)
-
-    def __repr__(self):
-        _mes = self.__class__.__name__
-        _mes += '('
-        for name, (path, _type) in self.debug_info.items():
-            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
-        _mes += f'\n  preprocess: {self.preprocess})'
-        return _mes
-
-    def __iter__(
-            self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
-
-        torch.set_printoptions(profile='default')
-        # 2. Load the entry from each line and create a dict
-        data = {}
-        # 2.a. Load data streamingly
-
-        value = self.path_name_type_list[0][0]
-        uid = 'pcm_data'
-        # name:  speech
-        name = self.path_name_type_list[0][1]
-        _type = self.path_name_type_list[0][2]
-        func = DATA_TYPES[_type]
-        # array:  [ 1.25122070e-03  ... ]
-        #        data[name] = np.frombuffer(value, dtype=self.np_dtype)
-
-        # 2.b. byte(PCM16) to float32
-        middle_data = np.frombuffer(value, dtype=np.int16)
-        middle_data = np.asarray(middle_data)
-        if middle_data.dtype.kind not in 'iu':
-            raise TypeError("'middle_data' must be an array of integers")
-        dtype = np.dtype('float32')
-        if dtype.kind != 'f':
-            raise TypeError("'dtype' must be a floating point type")
-
-        i = np.iinfo(middle_data.dtype)
-        abs_max = 2**(i.bits - 1)
-        offset = i.min + abs_max
-        array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
-                              dtype=self.np_dtype)
-
-        # 2.c. audio resample
-        if _type == 'sound':
-            audio_sr: int = 16000
-            model_sr: int = 16000
-            if isinstance(self.sample_rate, int):
-                model_sr = self.sample_rate
-            else:
-                if 'audio_sr' in self.sample_rate:
-                    audio_sr = self.sample_rate['audio_sr']
-                if 'model_sr' in self.sample_rate:
-                    model_sr = self.sample_rate['model_sr']
-            array = wav_utils.torch_resample(array, audio_sr, model_sr)
-
-        data[name] = array
-
-        # 3. [Option] Apply preprocessing
-        #   e.g. espnet2.train.preprocessor:CommonPreprocessor
-        if self.preprocess is not None:
-            data = self.preprocess(uid, data)
-            # data:  {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
-
-        # 4. Force data-precision
-        for name in data:
-            # value is np.ndarray data
-            value = data[name]
-            if not isinstance(value, np.ndarray):
-                raise RuntimeError(
-                    f'All values must be converted to np.ndarray object '
-                    f'by preprocessing, but "{name}" is still {type(value)}.')
-
-            # Cast to desired type
-            if value.dtype.kind == 'f':
-                value = value.astype(self.float_dtype)
-            elif value.dtype.kind == 'i':
-                value = value.astype(self.int_dtype)
-            else:
-                raise NotImplementedError(
-                    f'Not supported dtype: {value.dtype}')
-            data[name] = value
-
-        yield uid, data
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 156f608..318ae0b 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -64,28 +64,26 @@
         return self.sp.DecodePieces(list(tokens))
 
 
-class ArkDataLoader(AbsIterFactory):
-    def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None,
-                 bpemodel_file=None, mode="train"):
-        symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
-        if seg_dict_file is not None:
-            seg_dict = load_seg_dict(seg_dict_file)
-        else:
-            seg_dict = None
-        if punc_dict_file is not None:
-            punc_dict = read_symbol_table(punc_dict_file)
-        else:
-            punc_dict = None
-        self.dataset_conf = dataset_conf
-        self.frontend_conf = frontend_conf
+class LargeDataLoader(AbsIterFactory):
+    def __init__(self, args, mode="train"):
+        symbol_table, seg_dict, punc_dict, bpe_tokenizer = None, None, None, None
+        if hasattr(args, "token_list") and args.token_list is not None:
+            symbol_table = read_symbol_table(args.token_list)
+        if hasattr(args, "seg_dict_file") and args.seg_dict_file is not None:
+            seg_dict = load_seg_dict(args.seg_dict_file)
+        if hasattr(args, "punc_dict_file") and args.punc_dict_file is not None:
+            punc_dict = read_symbol_table(args.punc_dict_file)
+        if hasattr(args, "bpemodel_file") and args.bpemodel_file is not None:
+            bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel_file)
+        self.dataset_conf = args.dataset_conf
+        self.frontend_conf = args.frontend_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
         batch_mode = self.dataset_conf.get("batch_mode", "padding")
-        if bpemodel_file is not None:
-            bpe_tokenizer = SentencepiecesTokenizer(bpemodel_file)
-        else:
-            bpe_tokenizer = None
+        data_list = args.train_data_file if mode == "train" else args.valid_data_file
         self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
-                               self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)
+                               self.dataset_conf, self.frontend_conf,
+                               speed_perturb=getattr(args, "speed_perturb", None) if mode == "train" else None,
+                               mode=mode, batch_mode=batch_mode)
 
     def build_iter(self, epoch, shuffle=True):
         self.dataset.set_epoch(epoch)
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 8c224d8..5df61fd 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -1,20 +1,20 @@
+import logging
 import os
 import random
-import numpy
 from functools import partial
 
 import torch
-import torchaudio
 import torch.distributed as dist
+import torchaudio
 from kaldiio import ReadHelper
 from torch.utils.data import IterableDataset
 
 from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
 from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
 from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
+from funasr.datasets.large_datasets.utils.clipping import clipping
 from funasr.datasets.large_datasets.utils.filter import filter
 from funasr.datasets.large_datasets.utils.padding import padding
-from funasr.datasets.large_datasets.utils.clipping import clipping
 from funasr.datasets.large_datasets.utils.tokenize import tokenize
 
 
@@ -28,7 +28,8 @@
 
 
 class AudioDataset(IterableDataset):
-    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"):
+    def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None,
+                 mode="train"):
         self.scp_lists = scp_lists
         self.data_names = data_names
         self.data_types = data_types
@@ -40,6 +41,9 @@
         self.world_size = 1
         self.worker_id = 0
         self.num_workers = 1
+        self.speed_perturb = speed_perturb
+        if self.speed_perturb is not None:
+            logging.info("Using speed_perturb: {}".format(speed_perturb))
 
     def set_epoch(self, epoch):
         self.epoch = epoch
@@ -124,9 +128,15 @@
                             if sampling_rate != self.frontend_conf["fs"]:
                                 waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
                                                                           new_freq=self.frontend_conf["fs"])(waveform)
-                                sampling_rate = self.frontend_conf["fs"] 
+                                sampling_rate = self.frontend_conf["fs"]
                         waveform = waveform.numpy()
                         mat = waveform[0]
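+                        # Speed perturbation via the sox "speed" effect changes
+                        # tempo and pitch; the trailing "rate" effect resamples
+                        # back to the original sampling rate so downstream
+                        # feature extraction sees a consistent rate.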
+                        if self.speed_perturb is not None:
+                            speed = random.choice(self.speed_perturb)
+                            if speed != 1.0:
+                                mat, _ = torchaudio.sox_effects.apply_effects_tensor(
+                                    torch.tensor(mat).view(1, -1), sampling_rate, [['speed', str(speed)], ['rate', str(sampling_rate)]])
+                                mat = mat.view(-1).numpy()
                         sample_dict[data_name] = mat
                         sample_dict["sampling_rate"] = sampling_rate
                         if data_name == "speech":
@@ -168,6 +178,7 @@
             bpe_tokenizer,
             conf,
             frontend_conf,
+            speed_perturb=None,
             mode="train",
             batch_mode="padding"):
     scp_lists = read_lists(data_list_file)
@@ -196,7 +207,8 @@
                            data_names, 
                            data_types, 
                            frontend_conf=frontend_conf, 
-                           shuffle=shuffle, 
+                           shuffle=shuffle,
+                           speed_perturb=speed_perturb,
                            mode=mode, 
                            )
 
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
index f0f0c66..cf7d255 100644
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -48,7 +48,7 @@
     vad = -2
 
     if bpe_tokenizer is not None:
-        text = bpe_tokenizer.text2tokens("".join(text))
+        text = bpe_tokenizer.text2tokens(text)
 
     if seg_dict is not None:
         assert isinstance(seg_dict, dict)
diff --git a/funasr/datasets/small_datasets/collate_fn.py b/funasr/datasets/small_datasets/collate_fn.py
new file mode 100644
index 0000000..573f581
--- /dev/null
+++ b/funasr/datasets/small_datasets/collate_fn.py
@@ -0,0 +1,93 @@
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.modules.nets_utils import pad_list
+
+
+class CommonCollateFn:
+    """Functor class of common_collate_fn()"""
+
+    def __init__(
+            self,
+            float_pad_value: Union[float, int] = 0.0,
+            int_pad_value: int = -32768,
+            not_sequence: Collection[str] = (),
+            max_sample_size=None
+    ):
+        assert check_argument_types()
+        self.float_pad_value = float_pad_value
+        self.int_pad_value = int_pad_value
+        self.not_sequence = set(not_sequence)
+        self.max_sample_size = max_sample_size
+
+    def __repr__(self):
+        return (
+            f"{self.__class__}(float_pad_value={self.float_pad_value}, "
+            f"int_pad_value={self.float_pad_value})"
+        )
+
+    def __call__(
+            self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
+    ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+        return common_collate_fn(
+            data,
+            float_pad_value=self.float_pad_value,
+            int_pad_value=self.int_pad_value,
+            not_sequence=self.not_sequence,
+        )
+
+
+def common_collate_fn(
+        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+        float_pad_value: Union[float, int] = 0.0,
+        int_pad_value: int = -32768,
+        not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+    """Concatenate ndarray-list to an array and convert to torch.Tensor.
+    """
+    assert check_argument_types()
+    uttids = [u for u, _ in data]
+    data = [d for _, d in data]
+
+    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+    assert all(
+        not k.endswith("_lengths") for k in data[0]
+    ), f"*_lengths is reserved: {list(data[0])}"
+
+    output = {}
+    for key in data[0]:
+        if data[0][key].dtype.kind == "i":
+            pad_value = int_pad_value
+        else:
+            pad_value = float_pad_value
+
+        array_list = [d[key] for d in data]
+        tensor_list = [torch.from_numpy(a) for a in array_list]
+        tensor = pad_list(tensor_list, pad_value)
+        output[key] = tensor
+
+        if key not in not_sequence:
+            lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
+            output[key + "_lengths"] = lens
+
+    output = (uttids, output)
+    assert check_return_type(output)
+    return output
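+
+# Shape sketch (illustrative): collating two utterances whose "speech" arrays
+# are (3, D) and (5, D) yields output["speech"] of shape (2, 5, D), padded
+# with float_pad_value, and output["speech_lengths"] == tensor([3, 5]).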
+
+def crop_to_max_size(feature, target_size):
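+    # Randomly crop a feature to `target_size` frames along the first axis;
+    # inputs already no longer than the target are returned unchanged
+    # (presumably paired with CommonCollateFn's max_sample_size option).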
+    size = len(feature)
+    diff = size - target_size
+    if diff <= 0:
+        return feature
+
+    start = np.random.randint(0, diff + 1)
+    end = size - diff + start
+    return feature[start:end]
\ No newline at end of file
diff --git a/funasr/datasets/small_datasets/dataset.py b/funasr/datasets/small_datasets/dataset.py
new file mode 100644
index 0000000..e14e4f1
--- /dev/null
+++ b/funasr/datasets/small_datasets/dataset.py
@@ -0,0 +1,258 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+import collections
+import copy
+import logging
+import numbers
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import Mapping
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import kaldiio
+import numpy as np
+import torch
+from torch.utils.data.dataset import Dataset
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.npy_scp import NpyScpReader
+from funasr.fileio.sound_scp import SoundScpReader
+
+
+class AdapterForSoundScpReader(collections.abc.Mapping):
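+    """Mapping adapter that always yields a plain ndarray.
+
+    The wrapped loader (SoundScpReader or kaldiio) may return either a bare
+    ndarray or a (rate, array) / (array, rate) tuple; this adapter unwraps
+    the tuple, optionally casts to `dtype`, and checks that the sampling
+    rate stays consistent across utterances.
+    """
+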
+    def __init__(self, loader, dtype=None):
+        assert check_argument_types()
+        self.loader = loader
+        self.dtype = dtype
+        self.rate = None
+
+    def keys(self):
+        return self.loader.keys()
+
+    def __len__(self):
+        return len(self.loader)
+
+    def __iter__(self):
+        return iter(self.loader)
+
+    def __getitem__(self, key: str) -> np.ndarray:
+        retval = self.loader[key]
+
+        if isinstance(retval, tuple):
+            assert len(retval) == 2, len(retval)
+            if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
+                # sound scp case
+                rate, array = retval
+            elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
+                # Extended ark format case
+                array, rate = retval
+            else:
+                raise RuntimeError(
+                    f"Unexpected type: {type(retval[0])}, {type(retval[1])}"
+                )
+
+            if self.rate is not None and self.rate != rate:
+                raise RuntimeError(
+                    f"Sampling rates are mismatched: {self.rate} != {rate}"
+                )
+            self.rate = rate
+            # Multichannel wave file
+            # array: (Nsample, Channel) or (Nsample,)
+            if self.dtype is not None:
+                array = array.astype(self.dtype)
+
+        else:
+            # Normal ark case
+            assert isinstance(retval, np.ndarray), type(retval)
+            array = retval
+            if self.dtype is not None:
+                array = array.astype(self.dtype)
+
+        assert isinstance(array, np.ndarray), type(array)
+        return array
+
+
+def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
+    # The file is as follows:
+    #   utterance_id_A /some/where/a.wav
+    #   utterance_id_B /some/where/a.flac
+
+    # NOTE(kamo): SoundScpReader doesn't support pipe-fashion
+    # like Kaldi e.g. "cat a.wav |".
+    # NOTE(kamo): The audio signal is normalized to [-1,1] range.
+    loader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False)
+
+    # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
+    # but ndarray is desired, so Adapter class is inserted here
+    return AdapterForSoundScpReader(loader, float_dtype)
+
+
+def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
+    loader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd)
+    return AdapterForSoundScpReader(loader, float_dtype)
+
+
+class ESPnetDataset(Dataset):
+    """
+        PyTorch Dataset class for FunASR, adapted from ESPnet.
+    """
+
+    def __init__(
+            self,
+            path_name_type_list: Collection[Tuple[str, str, str]],
+            preprocess: Callable[
+                [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
+            ] = None,
+            float_dtype: str = "float32",
+            int_dtype: str = "long",
+            dest_sample_rate: int = 16000,
+            speed_perturb: Union[list, tuple] = None,
+            mode: str = "train",
+    ):
+        assert check_argument_types()
+        if len(path_name_type_list) == 0:
+            raise ValueError(
+                '1 or more elements are required for "path_name_type_list"'
+            )
+
+        path_name_type_list = copy.deepcopy(path_name_type_list)
+        self.preprocess = preprocess
+
+        self.float_dtype = float_dtype
+        self.int_dtype = int_dtype
+        self.dest_sample_rate = dest_sample_rate
+        self.speed_perturb = speed_perturb
+        self.mode = mode
+        if self.speed_perturb is not None:
+            logging.info("Using speed_perturb: {}".format(speed_perturb))
+
+        self.loader_dict = {}
+        self.debug_info = {}
+        for path, name, _type in path_name_type_list:
+            if name in self.loader_dict:
+                raise RuntimeError(f'"{name}" is duplicated for data-key')
+
+            loader = self._build_loader(path, _type)
+            self.loader_dict[name] = loader
+            self.debug_info[name] = path, _type
+            if len(self.loader_dict[name]) == 0:
+                raise RuntimeError(f"{path} has no samples")
+
+    def _build_loader(
+            self, path: str, loader_type: str
+    ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
+        """Helper function to instantiate Loader.
+
+        Args:
+            path:  The file path
+            loader_type:  loader_type. sound, npy, text, etc
+        """
+        if loader_type == "sound":
+            speed_perturb = self.speed_perturb if self.mode == "train" else None
+            loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False,
+                                    speed_perturb=speed_perturb)
+            return AdapterForSoundScpReader(loader, self.float_dtype)
+        elif loader_type == "kaldi_ark":
+            loader = kaldiio.load_scp(path)
+            return AdapterForSoundScpReader(loader, self.float_dtype)
+        elif loader_type == "npy":
+            return NpyScpReader(path)
+        elif loader_type == "text":
+            text_loader = {}
+            with open(path, "r", encoding="utf-8") as f:
+                for linenum, line in enumerate(f, 1):
+                    sps = line.rstrip().split(maxsplit=1)
+                    if len(sps) == 1:
+                        k, v = sps[0], ""
+                    else:
+                        k, v = sps
+                    if k in text_loader:
+                        raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+                    text_loader[k] = v
+            return text_loader
+        else:
+            raise RuntimeError(f"Not supported: loader_type={loader_type}")
+
+    def has_name(self, name) -> bool:
+        return name in self.loader_dict
+
+    def names(self) -> Tuple[str, ...]:
+        return tuple(self.loader_dict)
+
+    def __iter__(self):
+        return iter(next(iter(self.loader_dict.values())))
+
+    def __repr__(self):
+        _mes = self.__class__.__name__
+        _mes += "("
+        for name, (path, _type) in self.debug_info.items():
+            _mes += f'\n  {name}: {{"path": "{path}", "type": "{_type}"}}'
+        _mes += f"\n  preprocess: {self.preprocess})"
+        return _mes
+
+    def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
+        assert check_argument_types()
+
+        # Change integer-id to string-id
+        if isinstance(uid, int):
+            d = next(iter(self.loader_dict.values()))
+            uid = list(d)[uid]
+
+        data = {}
+        # 1. Load data from each loader
+        for name, loader in self.loader_dict.items():
+            try:
+                value = loader[uid]
+                if isinstance(value, (list, tuple)):
+                    value = np.array(value)
+                if not isinstance(
+                        value, (np.ndarray, torch.Tensor, str, numbers.Number)
+                ):
+                    raise TypeError(
+                        f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
+                    )
+            except Exception:
+                path, _type = self.debug_info[name]
+                logging.error(
+                    f"Error happened with path={path}, type={_type}, id={uid}"
+                )
+                raise
+
+            # torch.Tensor is converted to ndarray
+            if isinstance(value, torch.Tensor):
+                value = value.numpy()
+            elif isinstance(value, numbers.Number):
+                value = np.array([value])
+            data[name] = value
+
+        # 2. [Option] Apply preprocessing
+        #   e.g. funasr.train.preprocessor:CommonPreprocessor
+        if self.preprocess is not None:
+            data = self.preprocess(uid, data)
+
+        # 3. Force data-precision
+        for name in data:
+            value = data[name]
+            if not isinstance(value, np.ndarray):
+                raise RuntimeError(
+                    f"All values must be converted to np.ndarray object "
+                    f'by preprocessing, but "{name}" is still {type(value)}.'
+                )
+
+            # Cast to desired type
+            if value.dtype.kind == "f":
+                value = value.astype(self.float_dtype)
+            elif value.dtype.kind == "i":
+                value = value.astype(self.int_dtype)
+            else:
+                raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+            data[name] = value
+
+        retval = uid, data
+        assert check_return_type(retval)
+        return retval
diff --git a/funasr/datasets/small_datasets/length_batch_sampler.py b/funasr/datasets/small_datasets/length_batch_sampler.py
new file mode 100644
index 0000000..8ee8bdc
--- /dev/null
+++ b/funasr/datasets/small_datasets/length_batch_sampler.py
@@ -0,0 +1,147 @@
+from typing import Iterator
+from typing import List
+from typing import Dict
+from typing import Tuple
+from typing import Union
+
+from typeguard import check_argument_types
+
+from funasr.fileio.read_text import load_num_sequence_text
+from funasr.samplers.abs_sampler import AbsSampler
+
+
+class LengthBatchSampler(AbsSampler):
+    def __init__(
+        self,
+        batch_bins: int,
+        shape_files: Union[Tuple[str, ...], List[str], Dict],
+        min_batch_size: int = 1,
+        sort_in_batch: str = "descending",
+        sort_batch: str = "ascending",
+        drop_last: bool = False,
+        padding: bool = True,
+    ):
+        assert check_argument_types()
+        assert batch_bins > 0
+        if sort_batch != "ascending" and sort_batch != "descending":
+            raise ValueError(
+                f"sort_batch must be ascending or descending: {sort_batch}"
+            )
+        if sort_in_batch != "descending" and sort_in_batch != "ascending":
+            raise ValueError(
+                f"sort_in_batch must be ascending or descending: {sort_in_batch}"
+            )
+
+        self.batch_bins = batch_bins
+        self.shape_files = shape_files
+        self.sort_in_batch = sort_in_batch
+        self.sort_batch = sort_batch
+        self.drop_last = drop_last
+
+        # utt2shape: (Length, ...)
+        #    uttA 100,...
+        #    uttB 201,...
+        if isinstance(shape_files, dict):
+            utt2shapes = [shape_files]
+        else:
+            utt2shapes = [
+                load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
+            ]
+
+        first_utt2shape = utt2shapes[0]
+        for s, d in zip(shape_files, utt2shapes):
+            if set(d) != set(first_utt2shape):
+                raise RuntimeError(
+                    f"keys are mismatched between {s} != {shape_files[0]}"
+                )
+
+        # Sort samples in ascending order
+        # (shape order should be like (Length, Dim))
+        keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
+        if len(keys) == 0:
+            raise RuntimeError(f"0 lines found: {shape_files[0]}")
+
+        # Decide batch-sizes
+        batch_sizes = []
+        current_batch_keys = []
+        for key in keys:
+            current_batch_keys.append(key)
+            # shape: (Length, dim1, dim2, ...)
+            if padding:
+                # bins = bs x max_length
+                bins = sum(len(current_batch_keys) * sh[key][0] for sh in utt2shapes)
+            else:
+                # bins = sum of lengths
+                bins = sum(d[k][0] for k in current_batch_keys for d in utt2shapes)
+
+            if bins > batch_bins and len(current_batch_keys) >= min_batch_size:
+                batch_sizes.append(len(current_batch_keys))
+                current_batch_keys = []
+        else:
+            if len(current_batch_keys) != 0 and (
+                not self.drop_last or len(batch_sizes) == 0
+            ):
+                batch_sizes.append(len(current_batch_keys))
+
+        if len(batch_sizes) == 0:
+            # Maybe we can't reach here
+            raise RuntimeError("0 batches")
+
+        # If the last batch-size is smaller than minimum batch_size,
+        # the samples are redistributed to the other mini-batches
+        if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
+            for i in range(batch_sizes.pop(-1)):
+                batch_sizes[-(i % len(batch_sizes)) - 1] += 1
+
+        if not self.drop_last:
+            # Bug check
+            assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}"
+
+        # Set mini-batch
+        self.batch_list = []
+        iter_bs = iter(batch_sizes)
+        bs = next(iter_bs)
+        minibatch_keys = []
+        for key in keys:
+            minibatch_keys.append(key)
+            if len(minibatch_keys) == bs:
+                if sort_in_batch == "descending":
+                    minibatch_keys.reverse()
+                elif sort_in_batch == "ascending":
+                    # Keys are already sorted in ascending order
+                    pass
+                else:
+                    raise ValueError(
+                        "sort_in_batch must be ascending"
+                        f" or descending: {sort_in_batch}"
+                    )
+                self.batch_list.append(tuple(minibatch_keys))
+                minibatch_keys = []
+                try:
+                    bs = next(iter_bs)
+                except StopIteration:
+                    break
+
+        if sort_batch == "ascending":
+            pass
+        elif sort_batch == "descending":
+            self.batch_list.reverse()
+        else:
+            raise ValueError(
+                f"sort_batch must be ascending or descending: {sort_batch}"
+            )
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}("
+            f"N-batch={len(self)}, "
+            f"batch_bins={self.batch_bins}, "
+            f"sort_in_batch={self.sort_in_batch}, "
+            f"sort_batch={self.sort_batch})"
+        )
+
+    def __len__(self):
+        return len(self.batch_list)
+
+    def __iter__(self) -> Iterator[Tuple[str, ...]]:
+        return iter(self.batch_list)
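+
+
+# Behaviour sketch (illustrative numbers, not from the repo): with utterance
+# lengths {a: 100, b: 120, c: 400} and batch_bins=500, keys are scanned in
+# ascending length order with bins = batch_size * current_max_length; [a, b]
+# gives 2 * 120 = 240 <= 500, and adding c pushes bins to 3 * 400 = 1200 > 500,
+# so [a, b, c] is closed as one batch (the overflowing key stays in the batch
+# it closes).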
diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py
new file mode 100644
index 0000000..d80f48a
--- /dev/null
+++ b/funasr/datasets/small_datasets/preprocessor.py
@@ -0,0 +1,875 @@
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import scipy.signal
+import soundfile
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.cleaner import TextCleaner
+from funasr.text.token_id_converter import TokenIDConverter
+
+
+class AbsPreprocessor(ABC):
+    def __init__(self, train: bool):
+        self.train = train
+
+    @abstractmethod
+    def __call__(
+            self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        raise NotImplementedError
+
+
+def forward_segment(text, dic):
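+    # Greedy forward maximum matching: at each position take the longest
+    # entry of `dic` that starts there (falling back to the single
+    # character), then advance past it; e.g. with dic = {"abc"}, "abcd"
+    # segments to ["abc", "d"].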
+    word_list = []
+    i = 0
+    while i < len(text):
+        longest_word = text[i]
+        for j in range(i + 1, len(text) + 1):
+            word = text[i:j]
+            if word in dic:
+                if len(word) > len(longest_word):
+                    longest_word = word
+        word_list.append(longest_word)
+        i += len(longest_word)
+    return word_list
+
+
+def seg_tokenize(txt, seg_dict):
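+    # Map each word to its (space-joined) entry in seg_dict, emitting <unk>
+    # for out-of-vocabulary words; returns the flat token list.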
+    out_txt = ""
+    for word in txt:
+        if word in seg_dict:
+            out_txt += seg_dict[word] + " "
+        else:
+            out_txt += "<unk>" + " "
+    return out_txt.strip().split()
+
+
+def seg_tokenize_wo_pattern(txt, seg_dict):
+    out_txt = ""
+    for word in txt:
+        if word in seg_dict:
+            out_txt += seg_dict[word] + " "
+        else:
+            out_txt += "<unk>" + " "
+    return out_txt.strip().split()
+
+
+def framing(
+        x,
+        frame_length: int = 512,
+        frame_shift: int = 256,
+        centered: bool = True,
+        padded: bool = True,
+):
+    if x.size == 0:
+        raise ValueError("Input array size is zero")
+    if frame_length < 1:
+        raise ValueError("frame_length must be a positive integer")
+    if frame_length > x.shape[-1]:
+        raise ValueError("frame_length is greater than input length")
+    if 0 >= frame_shift:
+        raise ValueError("frame_shift must be greater than 0")
+
+    if centered:
+        pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [
+            (frame_length // 2, frame_length // 2)
+        ]
+        x = np.pad(x, pad_shape, mode="constant", constant_values=0)
+
+    if padded:
+        # Pad to an integer number of windowed segments,
+        # i.e. make x.shape[-1] = frame_length + (nseg - 1) * frame_shift
+        # with integer nseg
+        nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
+        pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)]
+        x = np.pad(x, pad_shape, mode="constant", constant_values=0)
+
+    # Created strided array of data segments
+    if frame_length == 1 and frame_length == frame_shift:
+        result = x[..., None]
+    else:
+        shape = x.shape[:-1] + (
+            (x.shape[-1] - frame_length) // frame_shift + 1,
+            frame_length,
+        )
+        strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1])
+        result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
+    return result
+
+
+def detect_non_silence(
+        x: np.ndarray,
+        threshold: float = 0.01,
+        frame_length: int = 1024,
+        frame_shift: int = 512,
+        window: str = "boxcar",
+) -> np.ndarray:
+    """Power based voice activity detection.
+
+    Args:
+        x: (Channel, Time)
+    >>> x = np.random.randn(1000)
+    >>> detect = detect_non_silence(x)
+    >>> assert x.shape == detect.shape
+    >>> assert detect.dtype == bool
+    """
+    if x.shape[-1] < frame_length:
+        return np.full(x.shape, fill_value=True, dtype=bool)
+
+    if x.dtype.kind == "i":
+        x = x.astype(np.float64)
+    # framed_w: (C, T, F)
+    framed_w = framing(
+        x,
+        frame_length=frame_length,
+        frame_shift=frame_shift,
+        centered=False,
+        padded=True,
+    )
+    framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype)
+    # power: (C, T)
+    power = (framed_w ** 2).mean(axis=-1)
+    # mean_power: (C, 1)
+    mean_power = np.mean(power, axis=-1, keepdims=True)
+    if np.all(mean_power == 0):
+        return np.full(x.shape, fill_value=True, dtype=bool)
+    # detect_frames: (C, T)
+    detect_frames = power / mean_power > threshold
+    # detects: (C, T, F)
+    detects = np.broadcast_to(
+        detect_frames[..., None], detect_frames.shape + (frame_shift,)
+    )
+    # detects: (C, TF)
+    detects = detects.reshape(*detect_frames.shape[:-1], -1)
+    # detects: (C, TF)
+    return np.pad(
+        detects,
+        [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
+        mode="edge",
+    )
+
+
+class CommonPreprocessor(AbsPreprocessor):
+    def __init__(
+            self,
+            train: bool,
+            token_type: str = None,
+            token_list: Union[Path, str, Iterable[str]] = None,
+            bpemodel: Union[Path, str, Iterable[str]] = None,
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: str = "text",
+            split_with_space: bool = False,
+            seg_dict_file: str = None,
+    ):
+        super().__init__(train)
+        self.train = train
+        self.speech_name = speech_name
+        self.text_name = text_name
+        self.speech_volume_normalize = speech_volume_normalize
+        self.rir_apply_prob = rir_apply_prob
+        self.noise_apply_prob = noise_apply_prob
+        self.split_with_space = split_with_space
+        self.seg_dict = None
+        if seg_dict_file is not None:
+            self.seg_dict = {}
+            with open(seg_dict_file) as f:
+                lines = f.readlines()
+            for line in lines:
+                s = line.strip().split()
+                key = s[0]
+                value = s[1:]
+                self.seg_dict[key] = " ".join(value)
+
+        if token_type is not None:
+            if token_list is None:
+                raise ValueError("token_list is required if token_type is not None")
+            self.text_cleaner = TextCleaner(text_cleaner)
+
+            self.tokenizer = build_tokenizer(
+                token_type=token_type,
+                bpemodel=bpemodel,
+                delimiter=delimiter,
+                space_symbol=space_symbol,
+                non_linguistic_symbols=non_linguistic_symbols,
+                g2p_type=g2p_type,
+            )
+            self.token_id_converter = TokenIDConverter(
+                token_list=token_list,
+                unk_symbol=unk_symbol,
+            )
+        else:
+            self.text_cleaner = None
+            self.tokenizer = None
+            self.token_id_converter = None
+
+        if train and rir_scp is not None:
+            self.rirs = []
+            with open(rir_scp, "r", encoding="utf-8") as f:
+                for line in f:
+                    sps = line.strip().split(None, 1)
+                    if len(sps) == 1:
+                        self.rirs.append(sps[0])
+                    else:
+                        self.rirs.append(sps[1])
+        else:
+            self.rirs = None
+
+        if train and noise_scp is not None:
+            self.noises = []
+            with open(noise_scp, "r", encoding="utf-8") as f:
+                for line in f:
+                    sps = line.strip().split(None, 1)
+                    if len(sps) == 1:
+                        self.noises.append(sps[0])
+                    else:
+                        self.noises.append(sps[1])
+            sps = noise_db_range.split("_")
+            if len(sps) == 1:
+                self.noise_db_low = self.noise_db_high = float(sps[0])
+            elif len(sps) == 2:
+                self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
+            else:
+                raise ValueError(
+                    "Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
+                )
+        else:
+            self.noises = None
+
+    def _speech_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, Union[str, np.ndarray]]:
+        assert check_argument_types()
+        if self.speech_name in data:
+            if self.train and (self.rirs is not None or self.noises is not None):
+                speech = data[self.speech_name]
+                nsamples = len(speech)
+
+                # speech: (Nmic, Time)
+                if speech.ndim == 1:
+                    speech = speech[None, :]
+                else:
+                    speech = speech.T
+                # Calc power on non-silence region
+                power = (speech[detect_non_silence(speech)] ** 2).mean()
+
+                # 1. Convolve RIR
+                if self.rirs is not None and self.rir_apply_prob >= np.random.random():
+                    rir_path = np.random.choice(self.rirs)
+                    if rir_path is not None:
+                        rir, _ = soundfile.read(
+                            rir_path, dtype=np.float64, always_2d=True
+                        )
+
+                        # rir: (Nmic, Time)
+                        rir = rir.T
+
+                        # speech: (Nmic, Time)
+                        # Note that this operation doesn't change the signal length
+                        speech = scipy.signal.convolve(speech, rir, mode="full")[
+                                 :, : speech.shape[1]
+                                 ]
+                        # Reverse mean power to the original power
+                        power2 = (speech[detect_non_silence(speech)] ** 2).mean()
+                        speech = np.sqrt(power / max(power2, 1e-10)) * speech
+
+                # 2. Add Noise
+                if (
+                        self.noises is not None
+                        and self.noise_apply_prob >= np.random.random()
+                ):
+                    noise_path = np.random.choice(self.noises)
+                    if noise_path is not None:
+                        noise_db = np.random.uniform(
+                            self.noise_db_low, self.noise_db_high
+                        )
+                        with soundfile.SoundFile(noise_path) as f:
+                            if f.frames == nsamples:
+                                noise = f.read(dtype=np.float64, always_2d=True)
+                            elif f.frames < nsamples:
+                                offset = np.random.randint(0, nsamples - f.frames)
+                                # noise: (Time, Nmic)
+                                noise = f.read(dtype=np.float64, always_2d=True)
+                                # Repeat noise
+                                noise = np.pad(
+                                    noise,
+                                    [(offset, nsamples - f.frames - offset), (0, 0)],
+                                    mode="wrap",
+                                )
+                            else:
+                                offset = np.random.randint(0, f.frames - nsamples)
+                                f.seek(offset)
+                                # noise: (Time, Nmic)
+                                noise = f.read(
+                                    nsamples, dtype=np.float64, always_2d=True
+                                )
+                                if len(noise) != nsamples:
+                                    raise RuntimeError(f"Something wrong: {noise_path}")
+                        # noise: (Nmic, Time)
+                        noise = noise.T
+
+                        noise_power = (noise ** 2).mean()
+                        scale = (
+                                10 ** (-noise_db / 20)
+                                * np.sqrt(power)
+                                / np.sqrt(max(noise_power, 1e-10))
+                        )
+                        speech = speech + scale * noise
+
+                speech = speech.T
+                ma = np.max(np.abs(speech))
+                if ma > 1.0:
+                    speech /= ma
+                data[self.speech_name] = speech
+
+            if self.speech_volume_normalize is not None:
+                speech = data[self.speech_name]
+                ma = np.max(np.abs(speech))
+                data[self.speech_name] = speech * self.speech_volume_normalize / ma
+        assert check_return_type(data)
+        return data
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        if self.text_name in data and self.tokenizer is not None:
+            text = data[self.text_name]
+            text = self.text_cleaner(text)
+            if self.split_with_space:
+                tokens = text.strip().split(" ")
+                if self.seg_dict is not None:
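+                    # Re-segment the joined text by longest match against
+                    # seg_dict, then expand each word into its dictionary
+                    # tokens (<unk> for OOV words).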
+                    tokens = forward_segment("".join(tokens), self.seg_dict)
+                    tokens = seg_tokenize(tokens, self.seg_dict)
+            else:
+                tokens = self.tokenizer.text2tokens(text)
+            text_ints = self.token_id_converter.tokens2ids(tokens)
+            data[self.text_name] = np.array(text_ints, dtype=np.int64)
+        assert check_return_type(data)
+        return data
+
+    def __call__(
+            self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        assert check_argument_types()
+
+        data = self._speech_process(data)
+        data = self._text_process(data)
+        return data
+
+
+## FIXME: largely duplicates CommonPreprocessor; only _text_process differs
+class LMPreprocessor(CommonPreprocessor):
+    def __init__(
+            self,
+            train: bool,
+            token_type: str = None,
+            token_list: Union[Path, str, Iterable[str]] = None,
+            bpemodel: Union[Path, str, Iterable[str]] = None,
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: str = "text",
+            split_with_space: bool = False,
+            seg_dict_file: str = None,
+    ):
+        super().__init__(train,
+                         token_type,
+                         token_list,
+                         bpemodel,
+                         text_cleaner,
+                         g2p_type,
+                         unk_symbol,
+                         space_symbol,
+                         non_linguistic_symbols,
+                         delimiter,
+                         rir_scp,
+                         rir_apply_prob,
+                         noise_scp,
+                         noise_apply_prob,
+                         noise_db_range,
+                         speech_volume_normalize,
+                         speech_name,
+                         text_name,
+                         split_with_space,
+                         seg_dict_file,
+                         )
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        if self.text_name in data and self.tokenizer is not None:
+            text = data[self.text_name]
+            text = self.text_cleaner(text)
+            if self.split_with_space:
+                tokens = text.strip().split(" ")
+                if self.seg_dict is not None:
+                    tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
+            else:
+                tokens = self.tokenizer.text2tokens(text)
+            text_ints = self.token_id_converter.tokens2ids(tokens)
+            data[self.text_name] = np.array(text_ints, dtype=np.int64)
+        assert check_return_type(data)
+        return data
+
+
+class CommonPreprocessor_multi(AbsPreprocessor):
+    def __init__(
+            self,
+            train: bool,
+            token_type: str = None,
+            token_list: Union[Path, str, Iterable[str]] = None,
+            bpemodel: Union[Path, str, Iterable[str]] = None,
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            speech_name: str = "speech",
+            text_name: List[str] = ["text"],
+    ):
+        super().__init__(train)
+        self.train = train
+        self.speech_name = speech_name
+        self.text_name = text_name
+
+        if token_type is not None:
+            if token_list is None:
+                raise ValueError("token_list is required if token_type is not None")
+            self.text_cleaner = TextCleaner(text_cleaner)
+
+            self.tokenizer = build_tokenizer(
+                token_type=token_type,
+                bpemodel=bpemodel,
+                delimiter=delimiter,
+                space_symbol=space_symbol,
+                non_linguistic_symbols=non_linguistic_symbols,
+                g2p_type=g2p_type,
+            )
+            self.token_id_converter = TokenIDConverter(
+                token_list=token_list,
+                unk_symbol=unk_symbol,
+            )
+        else:
+            self.text_cleaner = None
+            self.tokenizer = None
+            self.token_id_converter = None
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        for text_n in self.text_name:
+            if text_n in data and self.tokenizer is not None:
+                text = data[text_n]
+                text = self.text_cleaner(text)
+                tokens = self.tokenizer.text2tokens(text)
+                text_ints = self.token_id_converter.tokens2ids(tokens)
+                data[text_n] = np.array(text_ints, dtype=np.int64)
+        assert check_return_type(data)
+        return data
+
+    def __call__(
+            self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        assert check_argument_types()
+
+        if self.speech_name in data:
+            # Nothing now: candidates:
+            # - STFT
+            # - Fbank
+            # - CMVN
+            # - Data augmentation
+            pass
+
+        data = self._text_process(data)
+        return data
+
+
+class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
+    def __init__(
+            self,
+            train: bool,
+            token_type: List[str] = [None],
+            token_list: List[Union[Path, str, Iterable[str]]] = [None],
+            bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: List[str] = ["text"],
+    ):
+        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
+        super().__init__(
+            train=train,
+            token_type=token_type[0],
+            token_list=token_list[0],
+            bpemodel=bpemodel[0],
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name[0],
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+        )
+
+        assert (
+                len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+        ), "token_type, token_list, bpemodel and text_name must have the same length"
+        self.num_tokenizer = len(token_type)
+        self.tokenizer = []
+        self.token_id_converter = []
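+        # One tokenizer / id-converter per text field; entries stay None for
+        # fields that should not be tokenized.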
+
+        for i in range(self.num_tokenizer):
+            if token_type[i] is not None:
+                if token_list[i] is None:
+                    raise ValueError("token_list is required if token_type is not None")
+
+                self.tokenizer.append(
+                    build_tokenizer(
+                        token_type=token_type[i],
+                        bpemodel=bpemodel[i],
+                        delimiter=delimiter,
+                        space_symbol=space_symbol,
+                        non_linguistic_symbols=non_linguistic_symbols,
+                        g2p_type=g2p_type,
+                    )
+                )
+                self.token_id_converter.append(
+                    TokenIDConverter(
+                        token_list=token_list[i],
+                        unk_symbol=unk_symbol,
+                    )
+                )
+            else:
+                self.tokenizer.append(None)
+                self.token_id_converter.append(None)
+
+        self.text_cleaner = TextCleaner(text_cleaner)
+        self.text_name = text_name  # override the text_name from CommonPreprocessor
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        for i in range(self.num_tokenizer):
+            text_name = self.text_name[i]
+            if text_name in data and self.tokenizer[i] is not None:
+                text = data[text_name]
+                text = self.text_cleaner(text)
+                tokens = self.tokenizer[i].text2tokens(text)
+                text_ints = self.token_id_converter[i].tokens2ids(tokens)
+                data[text_name] = np.array(text_ints, dtype=np.int64)
+        assert check_return_type(data)
+        return data
+
+
+class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
+    def __init__(
+            self,
+            train: bool,
+            token_type: str = None,
+            token_list: Union[Path, str, Iterable[str]] = None,
+            bpemodel: Union[Path, str, Iterable[str]] = None,
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: str = "text",
+            split_text_name: str = "split_text",
+            split_with_space: bool = False,
+            seg_dict_file: str = None,
+    ):
+        super().__init__(
+            train=train,
+            # Force to use word.
+            token_type="word",
+            token_list=token_list,
+            bpemodel=bpemodel,
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name,
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+            split_with_space=split_with_space,
+            seg_dict_file=seg_dict_file,
+        )
+        # The data field name for split text.
+        self.split_text_name = split_text_name
+
+    @classmethod
+    def split_words(cls, text: str):
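+        # E.g. "hello你好world" -> ["hello", "你", "好", "world"]: runs of ASCII
+        # chars form one word, while each non-ASCII char becomes its own word.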
+        words = []
+        segs = text.split()
+        for seg in segs:
+            # There is no space in seg.
+            current_word = ""
+            for c in seg:
+                if len(c.encode()) == 1:
+                    # This is an ASCII char.
+                    current_word += c
+                else:
+                    # This is a non-ASCII (e.g., Chinese) char.
+                    if len(current_word) > 0:
+                        words.append(current_word)
+                        current_word = ""
+                    words.append(c)
+            if len(current_word) > 0:
+                words.append(current_word)
+        return words
+
+    def __call__(
+            self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
+    ) -> Dict[str, Union[list, np.ndarray]]:
+        assert check_argument_types()
+        # Split words.
+        if isinstance(data[self.text_name], str):
+            split_text = self.split_words(data[self.text_name])
+        else:
+            split_text = data[self.text_name]
+        data[self.text_name] = " ".join(split_text)
+        data = self._speech_process(data)
+        data = self._text_process(data)
+        data[self.split_text_name] = split_text
+        return data
+
+    def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
+        result = data[self.split_text_name]
+        del data[self.split_text_name]
+        return result
+
+
+class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
+    def __init__(
+            self,
+            train: bool,
+            token_type: List[str] = [None],
+            token_list: List[Union[Path, str, Iterable[str]]] = [None],
+            bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: List[str] = ["text"],
+            vad_name: str = "vad_indexes",
+    ):
+        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
+        super().__init__(
+            train=train,
+            token_type=token_type[0],
+            token_list=token_list[0],
+            bpemodel=bpemodel[0],
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name[0],
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+        )
+
+        assert (
+                len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+        ), "token_type, token_list, bpemodel and text_name must have the same length"
+        self.num_tokenizer = len(token_type)
+        self.tokenizer = []
+        self.token_id_converter = []
+
+        for i in range(self.num_tokenizer):
+            if token_type[i] is not None:
+                if token_list[i] is None:
+                    raise ValueError("token_list is required if token_type is not None")
+
+                self.tokenizer.append(
+                    build_tokenizer(
+                        token_type=token_type[i],
+                        bpemodel=bpemodel[i],
+                        delimiter=delimiter,
+                        space_symbol=space_symbol,
+                        non_linguistic_symbols=non_linguistic_symbols,
+                        g2p_type=g2p_type,
+                    )
+                )
+                self.token_id_converter.append(
+                    TokenIDConverter(
+                        token_list=token_list[i],
+                        unk_symbol=unk_symbol,
+                    )
+                )
+            else:
+                self.tokenizer.append(None)
+                self.token_id_converter.append(None)
+
+        self.text_cleaner = TextCleaner(text_cleaner)
+        self.text_name = text_name  # override the text_name from CommonPreprocessor
+        self.vad_name = vad_name
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        for i in range(self.num_tokenizer):
+            text_name = self.text_name[i]
+            if text_name in data and self.tokenizer[i] is not None:
+                text = data[text_name]
+                text = self.text_cleaner(text)
+                tokens = self.tokenizer[i].text2tokens(text)
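+                # Punctuation training text may end with a "vad:<index>" token;
+                # strip it into the separate vad field before id conversion.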
+                if "vad:" in tokens[-1]:
+                    vad = tokens[-1][4:]
+                    tokens = tokens[:-1]
+                    if len(vad) == 0:
+                        vad = -1
+                    else:
+                        vad = int(vad)
+                    data[self.vad_name] = np.array([vad], dtype=np.int64)
+                text_ints = self.token_id_converter[i].tokens2ids(tokens)
+                data[text_name] = np.array(text_ints, dtype=np.int64)
+        assert check_return_type(data)
+        return data
+
+
+def split_to_mini_sentence(words: list, word_limit: int = 20):
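+    # E.g. 45 words with word_limit=20 are split into chunks of 20, 20 and 5.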
+    assert word_limit > 1
+    if len(words) <= word_limit:
+        return [words]
+    sentences = []
+    length = len(words)
+    sentence_len = length // word_limit
+    for i in range(sentence_len):
+        sentences.append(words[i * word_limit:(i + 1) * word_limit])
+    if length % word_limit > 0:
+        sentences.append(words[sentence_len * word_limit:])
+    return sentences
+
+
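+# Map args.task_name to a task-specific preprocessor. Returns None when
+# preprocessing is disabled or the task (e.g. vad) needs none.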
+def build_preprocess(args, train):
+    if not args.use_preprocessor:
+        return None
+    if args.task_name in ["asr", "data2vec", "diar", "sv"]:
+        retval = CommonPreprocessor(
+            train=train,
+            token_type=args.token_type,
+            token_list=args.token_list,
+            bpemodel=args.bpemodel,
+            non_linguistic_symbols=args.non_linguistic_symbols if hasattr(args, "non_linguistic_symbols") else None,
+            text_cleaner=args.cleaner,
+            g2p_type=args.g2p,
+            split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+            seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+            rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+            rir_apply_prob=args.rir_apply_prob if hasattr(args, "rir_apply_prob") else 1.0,
+            noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+            noise_apply_prob=args.noise_apply_prob if hasattr(args, "noise_apply_prob") else 1.0,
+            noise_db_range=args.noise_db_range if hasattr(args, "noise_db_range") else "13_15",
+            speech_volume_normalize=args.speech_volume_normalize if hasattr(args, "speech_volume_normalize") else None,
+        )
+    elif args.task_name == "punc":
+        token_types = [args.token_type, args.token_type]
+        token_lists = [args.token_list, args.punc_list]
+        bpemodels = [args.bpemodel, args.bpemodel]
+        text_names = ["text", "punc"]
+        retval = PuncTrainTokenizerCommonPreprocessor(
+            train=train,
+            token_type=token_types,
+            token_list=token_lists,
+            bpemodel=bpemodels,
+            text_cleaner=args.cleaner,
+            g2p_type=args.g2p,
+            text_name=text_names,
+            non_linguistic_symbols=args.non_linguistic_symbols,
+        )
+    elif args.task_name == "lm":
+        retval = LMPreprocessor(
+            train=train,
+            token_type=args.token_type,
+            token_list=args.token_list,
+            bpemodel=args.bpemodel,
+            text_cleaner=args.cleaner,
+            g2p_type=args.g2p,
+            text_name="text",
+            non_linguistic_symbols=args.non_linguistic_symbols,
+            split_with_space=args.split_with_space,
+            seg_dict_file=args.seg_dict_file
+        )
+    elif args.task_name == "vad":
+        retval = None
+    else:
+        raise ValueError(f"Not supported task={args.task_name}")
+    return retval
diff --git a/funasr/datasets/small_datasets/sequence_iter_factory.py b/funasr/datasets/small_datasets/sequence_iter_factory.py
new file mode 100644
index 0000000..3ebcc5a
--- /dev/null
+++ b/funasr/datasets/small_datasets/sequence_iter_factory.py
@@ -0,0 +1,189 @@
+import logging
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
+from funasr.datasets.small_datasets.dataset import ESPnetDataset
+from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
+from funasr.datasets.small_datasets.preprocessor import build_preprocess
+from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.samplers.abs_sampler import AbsSampler
+
+
+class RawSampler(AbsSampler):
+    def __init__(self, batches):
+        self.batches = batches
+
+    def __len__(self):
+        return len(self.batches)
+
+    def __iter__(self):
+        return iter(self.batches)
+
+    def generate(self, seed):
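+        # The fixed batch list ignores the seed; return a copy so that callers
+        # may shuffle it in place.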
+        return list(self.batches)
+
+
+class SequenceIterFactory(AbsIterFactory):
+    """Build iterator for each epoch, modified from ESPnet
+
+    """
+
+    def __init__(self, args, mode="train"):
+
+        # preprocess
+        preprocess_fn = build_preprocess(args, train=mode == "train")
+
+        # collate
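+        # Text-only tasks pad token ids with 0; speech tasks pad features with
+        # 0.0 and token ids with -1 (the ignore index).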
+        if args.task_name in ["punc", "lm"]:
+            collate_fn = CommonCollateFn(int_pad_value=0)
+        else:
+            collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+        # dataset
+        dest_sample_rate = args.frontend_conf["fs"] if (
+                args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
+        if mode == "train":
+            data_path_and_name_and_type = args.train_data_path_and_name_and_type
+            shape_files = args.train_shape_file
+        elif mode == "valid":
+            data_path_and_name_and_type = args.valid_data_path_and_name_and_type
+            shape_files = args.valid_shape_file
+        else:
+            raise NotImplementedError(f"mode={mode}")
+        dataset = ESPnetDataset(
+            data_path_and_name_and_type,
+            preprocess=preprocess_fn,
+            dest_sample_rate=dest_sample_rate,
+            speed_perturb=args.speed_perturb if mode == "train" else None,
+        )
+
+        # sampler
+        dataset_conf = args.dataset_conf
+        batch_sampler = LengthBatchSampler(
+            batch_bins=dataset_conf["batch_conf"]["batch_size"] * args.ngpu,
+            shape_files=shape_files,
+            sort_in_batch=dataset_conf.get("sort_in_batch", "descending"),
+            sort_batch=dataset_conf.get("sort_batch", "ascending"),
+            drop_last=False,
+            padding=True,
+        )
+
+        batches = list(batch_sampler)
+        bs_list = [len(batch) for batch in batches]
+        logging.info(f"[{mode}] dataset:\n{dataset}")
+        logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
+        logging.info(
+            f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
+            f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
+        )
+
+        if args.scheduler == "tri_stage" and mode == "train":
+            args.max_update = len(bs_list) * args.max_epoch
+            logging.info("Max update: {}".format(args.max_update))
+
+        if args.distributed and mode == "train":
+            world_size = torch.distributed.get_world_size()
+            rank = torch.distributed.get_rank()
+            for batch in batches:
+                if len(batch) < world_size:
+                    raise RuntimeError(
+                        f"The batch-size must be equal or more than world_size: "
+                        f"{len(batch)} < {world_size}"
+                    )
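+            # Stride-split every mini-batch across ranks (rank r keeps
+            # batch[r::world_size]) so all ranks run the same number of steps.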
+            batches = [batch[rank::world_size] for batch in batches]
+
+        if not isinstance(batches, AbsSampler):
+            self.sampler = RawSampler(batches)
+        else:
+            self.sampler = batches
+
+        self.dataset = dataset
+        self.num_iters_per_epoch = None
+        self.shuffle = mode == "train"
+        self.seed = args.seed
+        self.num_workers = args.dataset_conf.get("num_workers", 8)
+        self.collate_fn = collate_fn
+        self.pin_memory = args.ngpu > 0
+
+    def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
+        if shuffle is None:
+            shuffle = self.shuffle
+
+        if self.num_iters_per_epoch is not None:
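+            # Re-slice the shuffled batch lists so that every epoch yields exactly
+            # num_iters_per_epoch batches, carrying the cursor across corpus passes.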
+            N = len(self.sampler)
+            # If the number of batches is larger than num_iters_per_epoch
+            if self.num_iters_per_epoch < N:
+                real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)
+
+                if offset >= self.num_iters_per_epoch:
+                    current_batches = self.sampler.generate(real_epoch + self.seed)
+                    if shuffle:
+                        np.random.RandomState(real_epoch + self.seed).shuffle(
+                            current_batches
+                        )
+                    batches = current_batches[
+                              offset - self.num_iters_per_epoch: offset
+                              ]
+                else:
+                    prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
+                    current_batches = self.sampler.generate(real_epoch + self.seed)
+                    if shuffle:
+                        np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
+                            prev_batches
+                        )
+                        np.random.RandomState(real_epoch + self.seed).shuffle(
+                            current_batches
+                        )
+                    batches = (
+                            prev_batches[offset - self.num_iters_per_epoch:]
+                            + current_batches[:offset]
+                    )
+
+            # If the number of batches is less than num_iters_per_epoch
+            else:
+                _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
+                _remain = self.num_iters_per_epoch
+                batches = []
+                current_batches = self.sampler.generate(_epoch + self.seed)
+                if shuffle:
+                    np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
+                while _remain > 0:
+
+                    _batches = current_batches[_cursor: _cursor + _remain]
+                    batches += _batches
+                    if _cursor + _remain >= N:
+                        _epoch += 1
+                        _cursor = 0
+                        current_batches = self.sampler.generate(_epoch + self.seed)
+                        if shuffle:
+                            np.random.RandomState(_epoch + self.seed).shuffle(
+                                current_batches
+                            )
+                    else:
+                        _cursor = _cursor + _remain
+                    _remain -= len(_batches)
+
+                assert len(batches) == self.num_iters_per_epoch
+
+        else:
+            batches = self.sampler.generate(epoch + self.seed)
+            if shuffle:
+                np.random.RandomState(epoch + self.seed).shuffle(batches)
+
+        # For backward compatibility for pytorch DataLoader
+        if self.collate_fn is not None:
+            kwargs = dict(collate_fn=self.collate_fn)
+        else:
+            kwargs = {}
+
+        return DataLoader(
+            dataset=self.dataset,
+            batch_sampler=batches,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            **kwargs,
+        )
diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index d757f7f..c752fe6 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -2,10 +2,14 @@
 from pathlib import Path
 from typing import Union
 
+import random
 import numpy as np
 import soundfile
 import librosa
 from typeguard import check_argument_types
+
+import torch
+import torchaudio
 
 from funasr.fileio.read_text import read_2column_text
 
@@ -32,6 +36,7 @@
         always_2d: bool = False,
         normalize: bool = False,
         dest_sample_rate: int = 16000,
+        speed_perturb: Union[list, tuple] = None,
     ):
         assert check_argument_types()
         self.fname = fname
@@ -40,6 +45,7 @@
         self.normalize = normalize
         self.data = read_2column_text(fname)
         self.dest_sample_rate = dest_sample_rate
+        self.speed_perturb = speed_perturb
 
     def __getitem__(self, key):
         wav = self.data[key]
@@ -53,8 +59,17 @@
                 wav, sr=self.dest_sample_rate, mono=self.always_2d, dtype=self.dtype
             )
 
+        if self.speed_perturb is not None:
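+            # The sox 'speed' effect changes playback speed (and pitch); the
+            # following 'rate' effect resamples back to the original sample rate.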
+            speed = random.choice(self.speed_perturb)
+            if speed != 1.0:
+                array, _ = torchaudio.sox_effects.apply_effects_tensor(
+                    torch.tensor(array).view(1, -1), rate,
+                    [['speed', str(speed)], ['rate', str(rate)]])
+                array = array.view(-1).numpy()
+
         if array.ndim==2:
             array=array.transpose((1, 0))
+
         return rate, array
 
     def get_path(self, key):
diff --git a/funasr/layers/abs_normalize.py b/funasr/layers/abs_normalize.py
index f2be748..4e617d0 100644
--- a/funasr/layers/abs_normalize.py
+++ b/funasr/layers/abs_normalize.py
@@ -11,4 +11,4 @@
         self, input: torch.Tensor, input_lengths: torch.Tensor = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # return output, output_lengths
-        raise NotImplementedError
+        raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/layers/global_mvn.py b/funasr/layers/global_mvn.py
index 5515cdd..8e43582 100644
--- a/funasr/layers/global_mvn.py
+++ b/funasr/layers/global_mvn.py
@@ -13,9 +13,7 @@
 
 class GlobalMVN(AbsNormalize, InversibleInterface):
     """Apply global mean and variance normalization
-
     TODO(kamo): Make this class portable somehow
-
     Args:
         stats_file: npy file
         norm_means: Apply mean normalization
@@ -66,7 +64,6 @@
         self, x: torch.Tensor, ilens: torch.Tensor = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward function
-
         Args:
             x: (B, L, ...)
             ilens: (B,)
@@ -118,4 +115,4 @@
         if norm_means:
             x += self.mean
             x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
-        return x, ilens
+        return x, ilens
\ No newline at end of file
diff --git a/funasr/layers/inversible_interface.py b/funasr/layers/inversible_interface.py
index a1a5939..657ec68 100644
--- a/funasr/layers/inversible_interface.py
+++ b/funasr/layers/inversible_interface.py
@@ -11,4 +11,4 @@
         self, input: torch.Tensor, input_lengths: torch.Tensor = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # return output, output_lengths
-        raise NotImplementedError
+        raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/lm/abs_model.py b/funasr/lm/abs_model.py
index 1f3c8a7..560879e 100644
--- a/funasr/lm/abs_model.py
+++ b/funasr/lm/abs_model.py
@@ -15,7 +15,7 @@
 
 from funasr.modules.nets_utils import make_pad_mask
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 
 class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
     """The abstract LM class
@@ -39,7 +39,7 @@
         raise NotImplementedError
 
 
-class LanguageModel(AbsESPnetModel):
+class LanguageModel(FunASRModel):
     def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
         assert check_argument_types()
         super().__init__()
diff --git a/funasr/main_funcs/calculate_all_attentions.py b/funasr/main_funcs/calculate_all_attentions.py
index 8f238c6..c3bf015 100644
--- a/funasr/main_funcs/calculate_all_attentions.py
+++ b/funasr/main_funcs/calculate_all_attentions.py
@@ -21,12 +21,12 @@
 from funasr.modules.attention import MultiHeadedAttention
 
 
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 
 
 @torch.no_grad()
 def calculate_all_attentions(
-    model: AbsESPnetModel, batch: Dict[str, torch.Tensor]
+    model: FunASRModel, batch: Dict[str, torch.Tensor]
 ) -> Dict[str, List[torch.Tensor]]:
     """Derive the outputs from the all attention layers
 
diff --git a/funasr/main_funcs/collect_stats.py b/funasr/main_funcs/collect_stats.py
index bacda8f..584b85a 100644
--- a/funasr/main_funcs/collect_stats.py
+++ b/funasr/main_funcs/collect_stats.py
@@ -17,12 +17,12 @@
 from funasr.fileio.npy_scp import NpyScpWriter
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 
 
 @torch.no_grad()
 def collect_stats(
-    model: AbsESPnetModel,
+    model: FunASRModel,
     train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
     valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
     output_dir: Path,
diff --git a/funasr/models/base_model.py b/funasr/models/base_model.py
new file mode 100644
index 0000000..80b3bbd
--- /dev/null
+++ b/funasr/models/base_model.py
@@ -0,0 +1,17 @@
+import torch
+
+
+class FunASRModel(torch.nn.Module):
+    """The common model class
+
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.num_updates = 0
+
+    def set_num_updates(self, num_updates):
+        self.num_updates = num_updates
+
+    def get_num_updates(self):
+        return self.num_updates
diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
index fcd6bd2..e5bd640 100644
--- a/funasr/models/data2vec.py
+++ b/funasr/models/data2vec.py
@@ -18,7 +18,7 @@
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -29,7 +29,7 @@
         yield
 
 
-class Data2VecPretrainModel(AbsESPnetModel):
+class Data2VecPretrainModel(FunASRModel):
     """Data2Vec Pretrain model"""
 
     def __init__(
@@ -57,7 +57,6 @@
             speech_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -106,7 +105,6 @@
             speech_lengths: torch.Tensor,
     ):
         """Frontend + Encoder.
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py
index f64ea3d..e6e6a52 100644
--- a/funasr/models/e2e_asr.py
+++ b/funasr/models/e2e_asr.py
@@ -28,7 +28,7 @@
 from funasr.modules.e2e_asr_common import ErrorCalculator
 from funasr.modules.nets_utils import th_accuracy
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -39,7 +39,7 @@
         yield
 
 
-class ESPnetASRModel(AbsESPnetModel):
+class ASRModel(FunASRModel):
     """CTC-attention hybrid Encoder-Decoder model"""
 
     def __init__(
@@ -49,9 +49,7 @@
             frontend: Optional[AbsFrontend],
             specaug: Optional[AbsSpecAug],
             normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
             encoder: AbsEncoder,
-            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             ctc_weight: float = 0.5,
@@ -64,6 +62,8 @@
             sym_space: str = "<space>",
             sym_blank: str = "<blank>",
             extract_feats_in_collect_stats: bool = True,
+            preencoder: Optional[AbsPreEncoder] = None,
+            postencoder: Optional[AbsPostEncoder] = None,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -133,7 +133,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -249,7 +248,6 @@
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -331,9 +329,7 @@
             ys_pad_lens: torch.Tensor,
     ) -> torch.Tensor:
         """Compute negative log likelihood(nll) from transformer-decoder
-
         Normally, this function is called in batchify_nll.
-
         Args:
             encoder_out: (Batch, Length, Dim)
             encoder_out_lens: (Batch,)
@@ -370,7 +366,6 @@
             batch_size: int = 100,
     ):
         """Compute negative log likelihood(nll) from transformer-decoder
-
         To avoid OOM, this fuction seperate the input into batches.
         Then call nll for each batch and combine and return results.
         Args:
diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
index f22f12a..fbf0d11 100644
--- a/funasr/models/e2e_asr_mfcca.py
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -23,7 +23,7 @@
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -35,7 +35,8 @@
 import pdb
 import random
 import math
-class MFCCA(AbsESPnetModel):
+
+class MFCCA(FunASRModel):
     """
     Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
     MFCCA:Multi-Frame Cross-Channel attention for multi-speaker ASR in Multi-party meeting scenario
@@ -43,26 +44,26 @@
     """
 
     def __init__(
-        self,
-        vocab_size: int,
-        token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
-        preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
-        decoder: AbsDecoder,
-        ctc: CTC,
-        rnnt_decoder: None,
-        ctc_weight: float = 0.5,
-        ignore_id: int = -1,
-        lsm_weight: float = 0.0,
-        mask_ratio: float = 0.0,
-        length_normalized_loss: bool = False,
-        report_cer: bool = True,
-        report_wer: bool = True,
-        sym_space: str = "<space>",
-        sym_blank: str = "<blank>",
+            self,
+            vocab_size: int,
+            token_list: Union[Tuple[str, ...], List[str]],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            encoder: AbsEncoder,
+            decoder: AbsDecoder,
+            ctc: CTC,
+            rnnt_decoder: None,
+            ctc_weight: float = 0.5,
+            ignore_id: int = -1,
+            lsm_weight: float = 0.0,
+            mask_ratio: float = 0.0,
+            length_normalized_loss: bool = False,
+            report_cer: bool = True,
+            report_wer: bool = True,
+            sym_space: str = "<space>",
+            sym_blank: str = "<blank>",
+            preencoder: Optional[AbsPreEncoder] = None,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -76,10 +77,9 @@
         self.ignore_id = ignore_id
         self.ctc_weight = ctc_weight
         self.token_list = token_list.copy()
-        
+
         self.mask_ratio = mask_ratio
 
-        
         self.frontend = frontend
         self.specaug = specaug
         self.normalize = normalize
@@ -113,14 +113,13 @@
             self.error_calculator = None
 
     def forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -130,22 +129,22 @@
         assert text_lengths.dim() == 1, text_lengths.shape
         # Check that batch_size is unified
         assert (
-            speech.shape[0]
-            == speech_lengths.shape[0]
-            == text.shape[0]
-            == text_lengths.shape[0]
+                speech.shape[0]
+                == speech_lengths.shape[0]
+                == text.shape[0]
+                == text_lengths.shape[0]
         ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
-        #pdb.set_trace()
-        if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
+        # pdb.set_trace()
+        if (speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0):
             rate_num = random.random()
-            #rate_num = 0.1
-            if(rate_num<=self.mask_ratio):
-                retain_channel = math.ceil(random.random() *8)
-                if(retain_channel>1):
-                    speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
+            # rate_num = 0.1
+            if (rate_num <= self.mask_ratio):
+                retain_channel = math.ceil(random.random() * 8)
+                if (retain_channel > 1):
+                    speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
                 else:
-                    speech = speech[:,:,torch.randperm(8)[0]]
-        #pdb.set_trace()
+                    speech = speech[:, :, torch.randperm(8)[0]]
+        # pdb.set_trace()
         batch_size = speech.shape[0]
         # for data-parallel
         text = text[:, : text_lengths.max()]
@@ -195,20 +194,19 @@
         return loss, stats, weight
 
     def collect_feats(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
     ) -> Dict[str, torch.Tensor]:
         feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -227,14 +225,14 @@
         # Pre-encoder, e.g. used for raw input data
         if self.preencoder is not None:
             feats, feats_lengths = self.preencoder(feats, feats_lengths)
-        #pdb.set_trace()
+        # pdb.set_trace()
         encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
             speech.size(0),
         )
-        if(encoder_out.dim()==4):
+        if (encoder_out.dim() == 4):
             assert encoder_out.size(2) <= encoder_out_lens.max(), (
                 encoder_out.size(),
                 encoder_out_lens.max(),
@@ -248,7 +246,7 @@
         return encoder_out, encoder_out_lens
 
     def _extract_feats(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert speech_lengths.dim() == 1, speech_lengths.shape
         # for data-parallel
@@ -266,11 +264,11 @@
         return feats, feats_lengths, channel_size
 
     def _calc_att_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
     ):
         ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
         ys_in_lens = ys_pad_lens + 1
@@ -298,14 +296,14 @@
         return loss_att, acc_att, cer_att, wer_att
 
     def _calc_ctc_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
     ):
         # Calc CTC loss
-        if(encoder_out.dim()==4):
+        if (encoder_out.dim() == 4):
             encoder_out = encoder_out.mean(1)
         loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
 
@@ -317,10 +315,10 @@
         return loss_ctc, cer_ctc
 
     def _calc_rnnt_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
     ):
-        raise NotImplementedError
+        raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index d02783f..00e08b1 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -29,9 +29,8 @@
 from funasr.modules.nets_utils import make_pad_mask, pad_list
 from funasr.modules.nets_utils import th_accuracy
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.models.predictor.cif import CifPredictorV3
-
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -42,7 +41,7 @@
         yield
 
 
-class Paraformer(AbsESPnetModel):
+class Paraformer(FunASRModel):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
@@ -56,9 +55,7 @@
             frontend: Optional[AbsFrontend],
             specaug: Optional[AbsSpecAug],
             normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
             encoder: AbsEncoder,
-            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             ctc_weight: float = 0.5,
@@ -79,6 +76,8 @@
             predictor_bias: int = 0,
             sampling_ratio: float = 0.2,
             share_embedding: bool = False,
+            preencoder: Optional[AbsPreEncoder] = None,
+            postencoder: Optional[AbsPostEncoder] = None,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -153,7 +152,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -270,7 +268,6 @@
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -368,9 +365,7 @@
             ys_pad_lens: torch.Tensor,
     ) -> torch.Tensor:
         """Compute negative log likelihood(nll) from transformer-decoder
-
         Normally, this function is called in batchify_nll.
-
         Args:
                 encoder_out: (Batch, Length, Dim)
                 encoder_out_lens: (Batch,)
@@ -407,7 +402,6 @@
             batch_size: int = 100,
     ):
         """Compute negative log likelihood(nll) from transformer-decoder
-
         To avoid OOM, this fuction seperate the input into batches.
         Then call nll for each batch and combine and return results.
         Args:
@@ -664,7 +658,10 @@
             self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -738,9 +735,7 @@
             frontend: Optional[AbsFrontend],
             specaug: Optional[AbsSpecAug],
             normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
             encoder: AbsEncoder,
-            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             ctc_weight: float = 0.5,
@@ -763,6 +758,8 @@
             embeds_id: int = 2,
             embeds_loss_weight: float = 0.0,
             embed_dims: int = 768,
+            preencoder: Optional[AbsPreEncoder] = None,
+            postencoder: Optional[AbsPostEncoder] = None,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -894,7 +891,6 @@
             embed_lengths: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -913,9 +909,9 @@
         self.step_cur += 1
         # for data-parallel
         text = text[:, : text_lengths.max()]
-        speech = speech[:, :speech_lengths.max(), :]
+        speech = speech[:, :speech_lengths.max()]
         if embed is not None:
-            embed = embed[:, :embed_lengths.max(), :]
+            embed = embed[:, :embed_lengths.max()]
 
         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
@@ -1003,74 +999,73 @@
 
 
 class BiCifParaformer(Paraformer):
-
     """
     Paraformer model with an extra cif predictor
     to conduct accurate timestamp prediction
     """
 
     def __init__(
-        self,
-        vocab_size: int,
-        token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[AbsFrontend],
-        specaug: Optional[AbsSpecAug],
-        normalize: Optional[AbsNormalize],
-        preencoder: Optional[AbsPreEncoder],
-        encoder: AbsEncoder,
-        postencoder: Optional[AbsPostEncoder],
-        decoder: AbsDecoder,
-        ctc: CTC,
-        ctc_weight: float = 0.5,
-        interctc_weight: float = 0.0,
-        ignore_id: int = -1,
-        blank_id: int = 0,
-        sos: int = 1,
-        eos: int = 2,
-        lsm_weight: float = 0.0,
-        length_normalized_loss: bool = False,
-        report_cer: bool = True,
-        report_wer: bool = True,
-        sym_space: str = "<space>",
-        sym_blank: str = "<blank>",
-        extract_feats_in_collect_stats: bool = True,
-        predictor = None,
-        predictor_weight: float = 0.0,
-        predictor_bias: int = 0,
-        sampling_ratio: float = 0.2,
+            self,
+            vocab_size: int,
+            token_list: Union[Tuple[str, ...], List[str]],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            encoder: AbsEncoder,
+            decoder: AbsDecoder,
+            ctc: CTC,
+            ctc_weight: float = 0.5,
+            interctc_weight: float = 0.0,
+            ignore_id: int = -1,
+            blank_id: int = 0,
+            sos: int = 1,
+            eos: int = 2,
+            lsm_weight: float = 0.0,
+            length_normalized_loss: bool = False,
+            report_cer: bool = True,
+            report_wer: bool = True,
+            sym_space: str = "<space>",
+            sym_blank: str = "<blank>",
+            extract_feats_in_collect_stats: bool = True,
+            predictor=None,
+            predictor_weight: float = 0.0,
+            predictor_bias: int = 0,
+            sampling_ratio: float = 0.2,
+            preencoder: Optional[AbsPreEncoder] = None,
+            postencoder: Optional[AbsPostEncoder] = None,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
         super().__init__(
-        vocab_size=vocab_size,
-        token_list=token_list,
-        frontend=frontend,
-        specaug=specaug,
-        normalize=normalize,
-        preencoder=preencoder,
-        encoder=encoder,
-        postencoder=postencoder,
-        decoder=decoder,
-        ctc=ctc,
-        ctc_weight=ctc_weight,
-        interctc_weight=interctc_weight,
-        ignore_id=ignore_id,
-        blank_id=blank_id,
-        sos=sos,
-        eos=eos,
-        lsm_weight=lsm_weight,
-        length_normalized_loss=length_normalized_loss,
-        report_cer=report_cer,
-        report_wer=report_wer,
-        sym_space=sym_space,
-        sym_blank=sym_blank,
-        extract_feats_in_collect_stats=extract_feats_in_collect_stats,
-        predictor=predictor,
-        predictor_weight=predictor_weight,
-        predictor_bias=predictor_bias,
-        sampling_ratio=sampling_ratio,
+            vocab_size=vocab_size,
+            token_list=token_list,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+            postencoder=postencoder,
+            decoder=decoder,
+            ctc=ctc,
+            ctc_weight=ctc_weight,
+            interctc_weight=interctc_weight,
+            ignore_id=ignore_id,
+            blank_id=blank_id,
+            sos=sos,
+            eos=eos,
+            lsm_weight=lsm_weight,
+            length_normalized_loss=length_normalized_loss,
+            report_cer=report_cer,
+            report_wer=report_wer,
+            sym_space=sym_space,
+            sym_blank=sym_blank,
+            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+            predictor=predictor,
+            predictor_weight=predictor_weight,
+            predictor_bias=predictor_bias,
+            sampling_ratio=sampling_ratio,
         )
         assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
 
@@ -1145,21 +1140,23 @@
             cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
 
         return loss_att, acc_att, cer_att, wer_att, loss_pre
-    
+
     def calc_predictor(self, encoder_out, encoder_out_lens):
 
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, None, encoder_out_mask,
-                                                                                  ignore_id=self.ignore_id)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
+                                                                                                          None,
+                                                                                                          encoder_out_mask,
+                                                                                                          ignore_id=self.ignore_id)
         return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-    
+
     def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
         ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
-                                                                                               encoder_out_mask,
-                                                                                               token_num)
+                                                                                            encoder_out_mask,
+                                                                                            token_num)
         return ds_alphas, ds_cif_peak, us_alphas, us_peaks
 
     def forward(
@@ -1170,7 +1167,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -1253,7 +1249,8 @@
         elif self.ctc_weight == 1.0:
             loss = loss_ctc
         else:
-            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+            loss = (self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+                    + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5)
 
         # Collect Attn branch stats
         stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1282,9 +1279,7 @@
             frontend: Optional[AbsFrontend],
             specaug: Optional[AbsSpecAug],
             normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
             encoder: AbsEncoder,
-            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             ctc_weight: float = 0.5,
@@ -1314,6 +1309,8 @@
             bias_encoder_type: str = 'lstm',
             label_bracket: bool = False,
             use_decoder_embedding: bool = False,
+            preencoder: Optional[AbsPreEncoder] = None,
+            postencoder: Optional[AbsPostEncoder] = None,
     ):
         assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -1377,7 +1374,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -1761,4 +1757,4 @@
                     "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
                                                                                   var_dict_tf[name_tf].shape))
 
-        return var_dict_torch_update
+        return var_dict_torch_update
\ No newline at end of file
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index 097b23a..da7c674 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -16,7 +16,7 @@
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
 from funasr.modules.eend_ola.utils.power import generate_mapping_dict
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     pass
@@ -34,7 +34,7 @@
     return att
 
 
-class DiarEENDOLAModel(AbsESPnetModel):
+class DiarEENDOLAModel(FunASRModel):
     """EEND-OLA diarization model"""
 
     def __init__(
@@ -91,7 +91,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index 3f7011d..9c3fb92 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -22,7 +22,7 @@
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
 from funasr.utils.misc import int2vec
 
@@ -35,7 +35,7 @@
         yield
 
 
-class DiarSondModel(AbsESPnetModel):
+class DiarSondModel(FunASRModel):
     """
     Author: Speech Lab, Alibaba Group, China
     SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
@@ -115,7 +115,6 @@
         binary_labels_lengths: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
-
         Args:
             speech: (Batch, samples) or (Batch, frames, input_size)
             speech_lengths: (Batch,) default None for chunk iterator,
@@ -391,7 +390,6 @@
         self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch,)
@@ -491,4 +489,4 @@
             speaker_miss,
             speaker_falarm,
             speaker_error,
-        )
+        )
\ No newline at end of file
diff --git a/funasr/models/e2e_sv.py b/funasr/models/e2e_sv.py
index 5b21277..bd5178e 100644
--- a/funasr/models/e2e_sv.py
+++ b/funasr/models/e2e_sv.py
@@ -29,7 +29,7 @@
 from funasr.modules.e2e_asr_common import ErrorCalculator
 from funasr.modules.nets_utils import th_accuracy
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -40,7 +40,7 @@
         yield
 
 
-class ESPnetSVModel(AbsESPnetModel):
+class ESPnetSVModel(FunASRModel):
     """CTC-attention hybrid Encoder-Decoder model"""
 
     def __init__(
@@ -80,7 +80,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -221,7 +220,6 @@
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
             speech: (Batch, Length, ...)
             speech_lengths: (Batch, )
@@ -271,4 +269,4 @@
         else:
             # No frontend and no feature extract
             feats, feats_lengths = speech, speech_lengths
-        return feats, feats_lengths
+        return feats, feats_lengths
\ No newline at end of file
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index d1367ab..33948f9 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -17,9 +17,8 @@
 from funasr.modules.add_sos_eos import add_sos_eos
 from funasr.modules.nets_utils import make_pad_mask, pad_list
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.models.predictor.cif import CifPredictorV3
-
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -30,7 +29,7 @@
         yield
 
 
-class TimestampPredictor(AbsESPnetModel):
+class TimestampPredictor(FunASRModel):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     """
@@ -56,7 +55,7 @@
         self.predictor_bias = predictor_bias
         self.criterion_pre = mae_loss()
         self.token_list = token_list
-    
+
     def forward(
             self,
             speech: torch.Tensor,
@@ -65,7 +64,6 @@
             text_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -113,7 +111,6 @@
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
                 speech: (Batch, Length, ...)
                 speech_lengths: (Batch, )
@@ -128,7 +125,7 @@
         encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
 
         return encoder_out, encoder_out_lens
-    
+
     def _extract_feats(
             self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -151,8 +148,8 @@
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
         ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
-                                                                                               encoder_out_mask,
-                                                                                               token_num)
+                                                                                            encoder_out_mask,
+                                                                                            token_num)
         return ds_alphas, ds_cif_peak, us_alphas, us_peaks
 
     def collect_feats(
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index ca76244..d08ea37 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -25,7 +25,7 @@
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
 from funasr.models.predictor.cif import mae_loss
 
@@ -38,7 +38,7 @@
         yield
 
 
-class UniASR(AbsESPnetModel):
+class UniASR(FunASRModel):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     """
@@ -179,7 +179,6 @@
         decoding_ind: int = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Decoder + Calc loss
-
         Args:
                         speech: (Batch, Length, ...)
                         speech_lengths: (Batch, )
@@ -469,7 +468,6 @@
         self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
                         speech: (Batch, Length, ...)
                         speech_lengths: (Batch, )
@@ -533,7 +531,6 @@
         ind: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Frontend + Encoder. Note that this method is used by asr_inference.py
-
         Args:
                         speech: (Batch, Length, ...)
                         speech_lengths: (Batch, )
@@ -627,9 +624,7 @@
         ys_pad_lens: torch.Tensor,
     ) -> torch.Tensor:
         """Compute negative log likelihood(nll) from transformer-decoder
-
         Normally, this function is called in batchify_nll.
-
         Args:
                         encoder_out: (Batch, Length, Dim)
                         encoder_out_lens: (Batch,)
@@ -666,7 +661,6 @@
         batch_size: int = 100,
     ):
         """Compute negative log likelihood(nll) from transformer-decoder
-
         To avoid OOM, this function separates the input into batches.
         Then it calls nll for each batch and combines and returns the results.
         Args:
@@ -1072,4 +1066,3 @@
             ys_hat = self.ctc2.argmax(encoder_out).data
             cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
         return loss_ctc, cer_ctc
-
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index d72c635..82d8422 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -469,7 +469,7 @@
                         - 1)) / self.vad_opts.noise_frame_num_used_for_snr
 
         return frame_state
-     
+
     def forward(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                 is_final: bool = False
                 ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
@@ -499,11 +499,11 @@
         return segments, in_cache
 
     def forward_online(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
-                is_final: bool = False, max_end_sil: int = 800
-                ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+                       is_final: bool = False, max_end_sil: int = 800
+                       ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
         self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
         self.waveform = waveform  # compute decibel for each frame
-        
+
         self.ComputeScores(feats, in_cache)
         self.ComputeDecibel()
         if not is_final:
diff --git a/funasr/models/encoder/abs_encoder.py b/funasr/models/encoder/abs_encoder.py
index 1fb7c97..034bc1f 100644
--- a/funasr/models/encoder/abs_encoder.py
+++ b/funasr/models/encoder/abs_encoder.py
@@ -18,4 +18,4 @@
         ilens: torch.Tensor,
         prev_states: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        raise NotImplementedError
+        raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 434f2a4..aa3b67e 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -15,13 +15,13 @@
 from typeguard import check_argument_types
 
 from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.attention import (
     MultiHeadedAttention,  # noqa: H301
     RelPositionMultiHeadedAttention,  # noqa: H301
     RelPositionMultiHeadedAttentionChunk,
     LegacyRelPositionMultiHeadedAttention,  # noqa: H301
 )
+from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.embedding import (
     PositionalEncoding,  # noqa: H301
     ScaledPositionalEncoding,  # noqa: H301
diff --git a/funasr/models/encoder/data2vec_encoder.py b/funasr/models/encoder/data2vec_encoder.py
index fd1796c..64c2144 100644
--- a/funasr/models/encoder/data2vec_encoder.py
+++ b/funasr/models/encoder/data2vec_encoder.py
@@ -574,4 +574,4 @@
             )
 
     def output_size(self) -> int:
-        return self.encoder_embed_dim
+        return self.encoder_embed_dim
\ No newline at end of file
diff --git a/funasr/models/encoder/mfcca_encoder.py b/funasr/models/encoder/mfcca_encoder.py
index 83d0b0e..95ccf07 100644
--- a/funasr/models/encoder/mfcca_encoder.py
+++ b/funasr/models/encoder/mfcca_encoder.py
@@ -38,13 +38,12 @@
 import pdb
 import math
 
+
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
-
     Args:
         channels (int): The number of channels of conv layers.
         kernel_size (int): Kernel size of conv layers.
-
     """
 
     def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
@@ -83,13 +82,10 @@
 
     def forward(self, x):
         """Compute convolution module.
-
         Args:
             x (torch.Tensor): Input tensor (#batch, time, channels).
-
         Returns:
             torch.Tensor: Output tensor (#batch, time, channels).
-
         """
         # exchange the temporal dimension and the feature dimension
         x = x.transpose(1, 2)
@@ -107,10 +103,8 @@
         return x.transpose(1, 2)
 
 
-
 class MFCCAEncoder(AbsEncoder):
     """Conformer encoder module.
-
     Args:
         input_size (int): Input dimension.
         output_size (int): Dimension of attention.
@@ -140,33 +134,32 @@
         zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
         cnn_module_kernel (int): Kernel size of convolution module.
         padding_idx (int): Padding idx for input_layer=embed.
-
     """
 
     def __init__(
-        self,
-        input_size: int,
-        output_size: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        input_layer: str = "conv2d",
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        positionwise_layer_type: str = "linear",
-        positionwise_conv_kernel_size: int = 3,
-        macaron_style: bool = False,
-        rel_pos_type: str = "legacy",
-        pos_enc_layer_type: str = "rel_pos",
-        selfattention_layer_type: str = "rel_selfattn",
-        activation_type: str = "swish",
-        use_cnn_module: bool = True,
-        zero_triu: bool = False,
-        cnn_module_kernel: int = 31,
-        padding_idx: int = -1,
+            self,
+            input_size: int,
+            output_size: int = 256,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            attention_dropout_rate: float = 0.0,
+            input_layer: str = "conv2d",
+            normalize_before: bool = True,
+            concat_after: bool = False,
+            positionwise_layer_type: str = "linear",
+            positionwise_conv_kernel_size: int = 3,
+            macaron_style: bool = False,
+            rel_pos_type: str = "legacy",
+            pos_enc_layer_type: str = "rel_pos",
+            selfattention_layer_type: str = "rel_selfattn",
+            activation_type: str = "swish",
+            use_cnn_module: bool = True,
+            zero_triu: bool = False,
+            cnn_module_kernel: int = 31,
+            padding_idx: int = -1,
     ):
         assert check_argument_types()
         super().__init__()
@@ -199,7 +192,7 @@
             )
         else:
             raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
-        
+
         if input_layer == "linear":
             self.embed = torch.nn.Sequential(
                 torch.nn.Linear(input_size, output_size),
@@ -283,7 +276,7 @@
             assert pos_enc_layer_type == "legacy_rel_pos"
             encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
             encoder_selfattn_layer_args = (
-               attention_heads,
+                attention_heads,
                 output_size,
                 attention_dropout_rate,
             )
@@ -326,42 +319,39 @@
         )
         if self.normalize_before:
             self.after_norm = LayerNorm(output_size)
-        self.conv1 = torch.nn.Conv2d(8, 16, [5,7], stride=[1,1], padding=(2,3))
+        self.conv1 = torch.nn.Conv2d(8, 16, [5, 7], stride=[1, 1], padding=(2, 3))
 
-        self.conv2 = torch.nn.Conv2d(16, 32, [5,7], stride=[1,1], padding=(2,3))
+        self.conv2 = torch.nn.Conv2d(16, 32, [5, 7], stride=[1, 1], padding=(2, 3))
 
-        self.conv3 = torch.nn.Conv2d(32, 16, [5,7], stride=[1,1], padding=(2,3))
+        self.conv3 = torch.nn.Conv2d(32, 16, [5, 7], stride=[1, 1], padding=(2, 3))
 
-        self.conv4 = torch.nn.Conv2d(16, 1, [5,7], stride=[1,1], padding=(2,3))
+        self.conv4 = torch.nn.Conv2d(16, 1, [5, 7], stride=[1, 1], padding=(2, 3))
 
     def output_size(self) -> int:
         return self._output_size
 
     def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        channel_size: torch.Tensor,
-        prev_states: torch.Tensor = None,
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+            channel_size: torch.Tensor,
+            prev_states: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """Calculate forward propagation.
-
         Args:
             xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
             ilens (torch.Tensor): Input length (#batch).
             prev_states (torch.Tensor): Not to be used now.
-
         Returns:
             torch.Tensor: Output tensor (#batch, L, output_size).
             torch.Tensor: Output length (#batch).
             torch.Tensor: Not to be used now.
-
         """
         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
         if (
-            isinstance(self.embed, Conv2dSubsampling)
-            or isinstance(self.embed, Conv2dSubsampling6)
-            or isinstance(self.embed, Conv2dSubsampling8)
+                isinstance(self.embed, Conv2dSubsampling)
+                or isinstance(self.embed, Conv2dSubsampling6)
+                or isinstance(self.embed, Conv2dSubsampling8)
         ):
             short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
             if short_status:
@@ -380,48 +370,46 @@
 
         t_leng = xs_pad.size(1)
         d_dim = xs_pad.size(2)
-        xs_pad = xs_pad.reshape(-1,channel_size,t_leng,d_dim)
-        #pdb.set_trace()
-        if(channel_size<8):
-            repeat_num = math.ceil(8/channel_size)
-            xs_pad = xs_pad.repeat(1,repeat_num,1,1)[:,0:8,:,:]
+        xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
+        # pdb.set_trace()
+        if channel_size < 8:
+            repeat_num = math.ceil(8 / channel_size)
+            xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
         xs_pad = self.conv1(xs_pad)
         xs_pad = self.conv2(xs_pad)
         xs_pad = self.conv3(xs_pad)
         xs_pad = self.conv4(xs_pad)
-        xs_pad = xs_pad.squeeze().reshape(-1,t_leng,d_dim)
+        xs_pad = xs_pad.squeeze().reshape(-1, t_leng, d_dim)
         mask_tmp = masks.size(1)
-        masks = masks.reshape(-1,channel_size,mask_tmp,t_leng)[:,0,:,:]
+        masks = masks.reshape(-1, channel_size, mask_tmp, t_leng)[:, 0, :, :]
 
         if self.normalize_before:
             xs_pad = self.after_norm(xs_pad)
 
         olens = masks.squeeze(1).sum(1)
         return xs_pad, olens, None
+
     def forward_hidden(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        prev_states: torch.Tensor = None,
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+            prev_states: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """Calculate forward propagation.
-
         Args:
             xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
             ilens (torch.Tensor): Input length (#batch).
             prev_states (torch.Tensor): Not to be used now.
-
         Returns:
             torch.Tensor: Output tensor (#batch, L, output_size).
             torch.Tensor: Output length (#batch).
             torch.Tensor: Not to be used now.
-
         """
         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
         if (
-            isinstance(self.embed, Conv2dSubsampling)
-            or isinstance(self.embed, Conv2dSubsampling6)
-            or isinstance(self.embed, Conv2dSubsampling8)
+                isinstance(self.embed, Conv2dSubsampling)
+                or isinstance(self.embed, Conv2dSubsampling6)
+                or isinstance(self.embed, Conv2dSubsampling8)
         ):
             short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
             if short_status:
@@ -447,4 +435,4 @@
             self.hidden_feature = self.after_norm(hidden_feature)
 
         olens = masks.squeeze(1).sum(1)
-        return xs_pad, olens, None
+        return xs_pad, olens, None
\ No newline at end of file
diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py
index 93695c8..8445feb 100644
--- a/funasr/models/encoder/resnet34_encoder.py
+++ b/funasr/models/encoder/resnet34_encoder.py
@@ -850,4 +850,4 @@
                 else:
                     logging.warning("{} is missed from tf checkpoint".format(name))
 
-        return var_dict_torch_update
+        return var_dict_torch_update
\ No newline at end of file
diff --git a/funasr/models/encoder/rnn_encoder.py b/funasr/models/encoder/rnn_encoder.py
index 7a3b053..59730da 100644
--- a/funasr/models/encoder/rnn_encoder.py
+++ b/funasr/models/encoder/rnn_encoder.py
@@ -1,3 +1,4 @@
+
 from typing import Optional
 from typing import Sequence
 from typing import Tuple
@@ -14,7 +15,6 @@
 
 class RNNEncoder(AbsEncoder):
     """RNNEncoder class.
-
     Args:
         input_size: The number of expected features in the input
         output_size: The number of output features
@@ -23,7 +23,6 @@
         use_projection: Use projection layer or not
         num_layers: Number of recurrent layers
         dropout: dropout probability
-
     """
 
     def __init__(
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a68011..e071e57 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -27,9 +27,10 @@
 from funasr.modules.subsampling import Conv2dSubsampling8
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
+from funasr.modules.mask import subsequent_mask, vad_mask
+
 from funasr.models.ctc import CTC
 from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.modules.mask import subsequent_mask, vad_mask
 
 class EncoderLayerSANM(nn.Module):
     def __init__(
diff --git a/funasr/models/frontend/abs_frontend.py b/funasr/models/frontend/abs_frontend.py
index 538236f..6049a01 100644
--- a/funasr/models/frontend/abs_frontend.py
+++ b/funasr/models/frontend/abs_frontend.py
@@ -14,4 +14,4 @@
     def forward(
         self, input: torch.Tensor, input_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError
+        raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 2e1b0c4..71f8ec4 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -18,7 +18,6 @@
 
 class DefaultFrontend(AbsFrontend):
     """Conventional frontend structure for ASR.
-
     Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
     """
 
@@ -142,7 +141,6 @@
 
 class MultiChannelFrontend(AbsFrontend):
     """Conventional frontend structure for ASR.
-
     Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
     """
 
@@ -260,4 +258,4 @@
         # Change torch.Tensor to ComplexTensor
         # input_stft: (..., F, 2) -> (..., F)
         input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
-        return input_stft, feats_lens
+        return input_stft, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/fused.py b/funasr/models/frontend/fused.py
index 8b5e56e..857486d 100644
--- a/funasr/models/frontend/fused.py
+++ b/funasr/models/frontend/fused.py
@@ -143,4 +143,4 @@
         else:
             raise NotImplementedError
 
-        return input_feats, feats_lens
+        return input_feats, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/s3prl.py b/funasr/models/frontend/s3prl.py
index f2b6107..b03d2c9 100644
--- a/funasr/models/frontend/s3prl.py
+++ b/funasr/models/frontend/s3prl.py
@@ -100,7 +100,6 @@
 
     def _tile_representations(self, feature):
         """Tile up the representations by `tile_factor`.
-
         Input - sequence of representations
                 shape: (batch_size, seq_len, feature_dim)
         Output - sequence of tiled representations
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 1dbf490..35fab57 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -500,4 +500,4 @@
         feats_pad = pad_sequence(feats,
                                  batch_first=True,
                                  padding_value=0.0)
-        return feats_pad, feats_lens
+        return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/wav_frontend_kaldifeat.py b/funasr/models/frontend/wav_frontend_kaldifeat.py
index b91ac63..5372de3 100644
--- a/funasr/models/frontend/wav_frontend_kaldifeat.py
+++ b/funasr/models/frontend/wav_frontend_kaldifeat.py
@@ -1,15 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from espnet/espnet.
 
-from typing import Tuple
-
 import numpy as np
 import torch
-import torchaudio.compliance.kaldi as kaldi
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from typeguard import check_argument_types
-from torch.nn.utils.rnn import pad_sequence
-# import kaldifeat
+
 
 def load_cmvn(cmvn_file):
     with open(cmvn_file, 'r', encoding='utf-8') as f:
@@ -33,9 +27,9 @@
     means = np.array(means_list).astype(np.float64)
     vars = np.array(vars_list).astype(np.float64)
     cmvn = np.array([means, vars])
-    cmvn = torch.as_tensor(cmvn) 
-    return cmvn 
-          
+    cmvn = torch.as_tensor(cmvn)
+    return cmvn
+
 
 def apply_cmvn(inputs, cmvn_file):  # noqa
     """
@@ -73,108 +67,3 @@
             LFR_inputs.append(frame)
     LFR_outputs = torch.vstack(LFR_inputs)
     return LFR_outputs.type(torch.float32)
-
-
-# class WavFrontend_kaldifeat(AbsFrontend):
-#     """Conventional frontend structure for ASR.
-#     """
-#
-#     def __init__(
-#         self,
-#         cmvn_file: str = None,
-#         fs: int = 16000,
-#         window: str = 'hamming',
-#         n_mels: int = 80,
-#         frame_length: int = 25,
-#         frame_shift: int = 10,
-#         lfr_m: int = 1,
-#         lfr_n: int = 1,
-#         dither: float = 1.0,
-#         snip_edges: bool = True,
-#         upsacle_samples: bool = True,
-#         device: str = 'cpu',
-#         **kwargs,
-#     ):
-#         super().__init__()
-#
-#         opts = kaldifeat.FbankOptions()
-#         opts.device = device
-#         opts.frame_opts.samp_freq = fs
-#         opts.frame_opts.dither = dither
-#         opts.frame_opts.window_type = window
-#         opts.frame_opts.frame_shift_ms = float(frame_shift)
-#         opts.frame_opts.frame_length_ms = float(frame_length)
-#         opts.mel_opts.num_bins = n_mels
-#         opts.energy_floor = 0
-#         opts.frame_opts.snip_edges = snip_edges
-#         opts.mel_opts.debug_mel = False
-#         self.opts = opts
-#         self.fbank_fn = None
-#         self.fbank_beg_idx = 0
-#         self.reset_fbank_status()
-#
-#         self.lfr_m = lfr_m
-#         self.lfr_n = lfr_n
-#         self.cmvn_file = cmvn_file
-#         self.upsacle_samples = upsacle_samples
-#
-#     def output_size(self) -> int:
-#         return self.n_mels * self.lfr_m
-#
-#     def forward_fbank(
-#         self,
-#         input: torch.Tensor,
-#         input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-#         batch_size = input.size(0)
-#         feats = []
-#         feats_lens = []
-#         for i in range(batch_size):
-#             waveform_length = input_lengths[i]
-#             waveform = input[i][:waveform_length]
-#             waveform = waveform * (1 << 15)
-#
-#             self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
-#             frames = self.fbank_fn.num_frames_ready
-#             frames_cur = frames - self.fbank_beg_idx
-#             mat = torch.empty([frames_cur, self.opts.mel_opts.num_bins], dtype=torch.float32).to(
-#                 device=self.opts.device)
-#             for i in range(self.fbank_beg_idx, frames):
-#                 mat[i, :] = self.fbank_fn.get_frame(i)
-#             self.fbank_beg_idx += frames_cur
-#
-#             feat_length = mat.size(0)
-#             feats.append(mat)
-#             feats_lens.append(feat_length)
-#
-#         feats_lens = torch.as_tensor(feats_lens)
-#         feats_pad = pad_sequence(feats,
-#                                  batch_first=True,
-#                                  padding_value=0.0)
-#         return feats_pad, feats_lens
-#
-#     def reset_fbank_status(self):
-#         self.fbank_fn = kaldifeat.OnlineFbank(self.opts)
-#         self.fbank_beg_idx = 0
-#
-#     def forward_lfr_cmvn(
-#         self,
-#         input: torch.Tensor,
-#         input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-#         batch_size = input.size(0)
-#         feats = []
-#         feats_lens = []
-#         for i in range(batch_size):
-#             mat = input[i, :input_lengths[i], :]
-#             if self.lfr_m != 1 or self.lfr_n != 1:
-#                 mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
-#             if self.cmvn_file is not None:
-#                 mat = apply_cmvn(mat, self.cmvn_file)
-#             feat_length = mat.size(0)
-#             feats.append(mat)
-#             feats_lens.append(feat_length)
-#
-#         feats_lens = torch.as_tensor(feats_lens)
-#         feats_pad = pad_sequence(feats,
-#                                  batch_first=True,
-#                                  padding_value=0.0)
-#         return feats_pad, feats_lens
diff --git a/funasr/models/frontend/windowing.py b/funasr/models/frontend/windowing.py
index 7c4c568..a526758 100644
--- a/funasr/models/frontend/windowing.py
+++ b/funasr/models/frontend/windowing.py
@@ -12,12 +12,10 @@
 
 class SlidingWindow(AbsFrontend):
     """Sliding Window.
-
     Provides a sliding window over a batched continuous raw audio tensor.
     Optionally, provides padding (Currently not implemented).
     Combine this module with a pre-encoder compatible with raw audio data,
     for example Sinc convolutions.
-
     Known issues:
     Output length is calculated incorrectly if the audio is shorter than win_length.
     WARNING: trailing values are discarded - padding not implemented yet.
@@ -33,7 +31,6 @@
         fs=None,
     ):
         """Initialize.
-
         Args:
             win_length: Length of frame.
             hop_length: Relative starting point of next frame.
@@ -53,11 +50,9 @@
         self, input: torch.Tensor, input_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Apply a sliding window on the input.
-
         Args:
             input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
             input_lengths: Input lengths within batch.
-
         Returns:
             Tensor: Output with dimensions (B, T, C, D), with D=win_length.
             Tensor: Output lengths within batch.
@@ -78,4 +73,4 @@
 
     def output_size(self) -> int:
         """Return output length of feature dimension D, i.e. the window length."""
-        return self.win_length
+        return self.win_length
\ No newline at end of file
diff --git a/funasr/models/specaug/abs_specaug.py b/funasr/models/specaug/abs_specaug.py
index 3cbac41..da6637e 100644
--- a/funasr/models/specaug/abs_specaug.py
+++ b/funasr/models/specaug/abs_specaug.py
@@ -6,9 +6,7 @@
 
 class AbsSpecAug(torch.nn.Module):
     """Abstract class for the augmentation of spectrogram
-
     The process-flow:
-
     Frontend  -> SpecAug -> Normalization -> Encoder -> Decoder
     """
 
diff --git a/funasr/models/target_delay_transformer.py b/funasr/models/target_delay_transformer.py
index e893c65..19e5c7c 100644
--- a/funasr/models/target_delay_transformer.py
+++ b/funasr/models/target_delay_transformer.py
@@ -6,13 +6,10 @@
 import torch.nn as nn
 
 from funasr.modules.embedding import SinusoidalPositionEncoder
-#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
 from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
-#from funasr.modules.mask import subsequent_n_mask
-from funasr.train.abs_model import AbsPunctuation
 
 
-class TargetDelayTransformer(AbsPunctuation):
+class TargetDelayTransformer(torch.nn.Module):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
diff --git a/funasr/models/vad_realtime_transformer.py b/funasr/models/vad_realtime_transformer.py
index fe298ce..e2d13f9 100644
--- a/funasr/models/vad_realtime_transformer.py
+++ b/funasr/models/vad_realtime_transformer.py
@@ -7,10 +7,9 @@
 
 from funasr.modules.embedding import SinusoidalPositionEncoder
 from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.train.abs_model import AbsPunctuation
 
 
-class VadRealtimeTransformer(AbsPunctuation):
+class VadRealtimeTransformer(torch.nn.Module):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 55a5d79..361ff89 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -30,6 +30,7 @@
 import torch.nn
 import torch.optim
 import yaml
+from funasr.models.base_model import FunASRModel
 from torch.utils.data import DataLoader
 from typeguard import check_argument_types
 from typeguard import check_return_type
@@ -44,19 +45,18 @@
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
 from funasr.iterators.sequence_iter_factory import SequenceIterFactory
 from funasr.main_funcs.collect_stats import collect_stats
-from funasr.optimizers.sgd import SGD
 from funasr.optimizers.fairseq_adam import FairseqAdam
+from funasr.optimizers.sgd import SGD
 from funasr.samplers.build_batch_sampler import BATCH_TYPES
 from funasr.samplers.build_batch_sampler import build_batch_sampler
 from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
 from funasr.schedulers.noam_lr import NoamLR
-from funasr.schedulers.warmup_lr import WarmupLR
 from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.distributed_utils import DistributedOption
 from funasr.train.trainer import Trainer
@@ -230,8 +230,8 @@
         >>> cls.check_task_requirements()
         If your model is defined as follows,
 
-        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
-        >>> class Model(AbsESPnetModel):
+        >>> from funasr.models.base_model import FunASRModel
+        >>> class Model(FunASRModel):
         ...     def forward(self, input, output, opt=None):  pass
 
         then "required_data_names" should be as
@@ -251,8 +251,8 @@
         >>> cls.check_task_requirements()
         If your model is defined as follows,
 
-        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
-        >>> class Model(AbsESPnetModel):
+        >>> from funasr.models.base_model import FunASRModel
+        >>> class Model(FunASRModel):
         ...     def forward(self, input, output, opt=None):  pass
 
         then "optional_data_names" should be as
@@ -263,7 +263,7 @@
 
     @classmethod
     @abstractmethod
-    def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
+    def build_model(cls, args: argparse.Namespace) -> FunASRModel:
         raise NotImplementedError
 
     @classmethod
@@ -1172,7 +1172,8 @@
                     args.batch_bins = args.batch_bins * args.ngpu
 
         # filter samples if wav.scp and text are mismatch
-        if (args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
+        if (args.train_shape_file is None and args.dataset_type == "small") or (
+                args.train_data_file is None and args.dataset_type == "large"):
             if not args.simple_ddp or distributed_option.dist_rank == 0:
                 filter_wav_text(args.data_dir, args.train_set)
                 filter_wav_text(args.data_dir, args.dev_set)
@@ -1181,8 +1182,10 @@
 
         if args.train_shape_file is None and args.dataset_type == "small":
             if not args.simple_ddp or distributed_option.dist_rank == 0:
-                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
-                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
+                calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min,
+                           args.speech_length_max)
+                calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min,
+                           args.speech_length_max)
             if args.simple_ddp:
                 dist.barrier()
             args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
@@ -1244,9 +1247,9 @@
 
         # 2. Build model
         model = cls.build_model(args=args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model = model.to(
             dtype=getattr(torch, args.train_dtype),
@@ -1374,15 +1377,21 @@
             if args.dataset_type == "large":
                 from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
                 train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
-                                                   frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
+                                                   frontend_conf=args.frontend_conf if hasattr(args,
+                                                                                               "frontend_conf") else None,
+                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
+                                                                                               "seg_dict_file") else None,
+                                                   punc_dict_file=args.punc_list if hasattr(args,
+                                                                                            "punc_list") else None,
                                                    bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
                                                    mode="train")
-                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, 
-                                                   frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
+                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
+                                                   frontend_conf=args.frontend_conf if hasattr(args,
+                                                                                               "frontend_conf") else None,
+                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
+                                                                                               "seg_dict_file") else None,
+                                                   punc_dict_file=args.punc_list if hasattr(args,
+                                                                                            "punc_list") else None,
                                                    bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
                                                    mode="eval")
             elif args.dataset_type == "small":
@@ -1929,7 +1938,7 @@
             model_file: Union[Path, str] = None,
             cmvn_file: Union[Path, str] = None,
             device: str = "cpu",
-    ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
+    ) -> Tuple[FunASRModel, argparse.Namespace]:
         """Build model from the files.
 
         This method is used for inference or fine-tuning.
@@ -1956,9 +1965,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         if model_file is not None:
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 43ea5ab..b01aa6f 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -38,9 +38,9 @@
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.e2e_asr import ASRModel
 from funasr.models.decoder.rnnt_decoder import RNNTDecoder
 from funasr.models.joint_net.joint_network import JointNetwork
-from funasr.models.e2e_asr import ESPnetASRModel
 from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
 from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
 from funasr.models.e2e_tp import TimestampPredictor
@@ -76,7 +76,7 @@
 from funasr.tasks.abs_task import AbsTask
 from funasr.text.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
 from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -122,7 +122,7 @@
 model_choices = ClassChoices(
     "model",
     classes=dict(
-        asr=ESPnetASRModel,
+        asr=ASRModel,
         uniasr=UniASR,
         paraformer=Paraformer,
         paraformer_online=ParaformerOnline,
@@ -133,7 +133,7 @@
         mfcca=MFCCA,
         timestamp_prediction=TimestampPredictor,
     ),
-    type_check=AbsESPnetModel,
+    type_check=FunASRModel,
     default="asr",
 )
 preencoder_choices = ClassChoices(
@@ -457,7 +457,7 @@
                 token_type=args.token_type,
                 token_list=args.token_list,
                 bpemodel=args.bpemodel,
-                non_linguistic_symbols=args.non_linguistic_symbols,
+                non_linguistic_symbols=args.non_linguistic_symbols if hasattr(args, "non_linguistic_symbols") else None,
                 text_cleaner=args.cleaner,
                 g2p_type=args.g2p,
                 split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
@@ -827,9 +827,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         model_dict = dict()
@@ -1074,9 +1074,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         model_dict = dict()
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 45e4ce7..2625fec 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -58,7 +58,7 @@
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
 from funasr.tasks.abs_task import AbsTask
 from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
 from funasr.utils.types import float_or_none
@@ -114,7 +114,7 @@
         sond=DiarSondModel,
         eend_ola=DiarEENDOLAModel,
     ),
-    type_check=AbsESPnetModel,
+    type_check=FunASRModel,
     default="sond",
 )
 encoder_choices = ClassChoices(
@@ -544,9 +544,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         model_dict = dict()
@@ -902,9 +902,9 @@
             args = yaml.safe_load(f)
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         if model_file is not None:
             if device == "cuda":
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 80d66d5..1e48655 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -206,6 +206,3 @@
         # 3. Initialize
         if args.init is not None:
             initialize(model, args.init)
-
-        assert check_return_type(model)
-        return model
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index 0170f28..a63bbe4 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -14,7 +14,6 @@
 
 from funasr.datasets.collate_fn import CommonCollateFn
 from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
-from funasr.train.abs_model import AbsPunctuation
 from funasr.train.abs_model import PunctuationModel
 from funasr.models.target_delay_transformer import TargetDelayTransformer
 from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
@@ -31,7 +30,6 @@
 punc_choices = ClassChoices(
     "punctuation",
     classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
-    type_check=AbsPunctuation,
     default="target_delay",
 )
 
diff --git a/funasr/tasks/sv.py b/funasr/tasks/sv.py
index 9710447..e4815da 100644
--- a/funasr/tasks/sv.py
+++ b/funasr/tasks/sv.py
@@ -25,7 +25,7 @@
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.layers.global_mvn import GlobalMVN
 from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.e2e_asr import ESPnetASRModel
+from funasr.models.e2e_asr import ASRModel
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.encoder.rnn_encoder import RNNEncoder
@@ -49,7 +49,7 @@
 from funasr.models.specaug.specaug import SpecAug
 from funasr.tasks.abs_task import AbsTask
 from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
 from funasr.utils.types import float_or_none
@@ -94,7 +94,7 @@
     classes=dict(
         espnet=ESPnetSVModel,
     ),
-    type_check=AbsESPnetModel,
+    type_check=FunASRModel,
     default="espnet",
 )
 preencoder_choices = ClassChoices(
@@ -488,9 +488,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         model_dict = dict()
diff --git a/funasr/tasks/vad.py b/funasr/tasks/vad.py
index d07acf1..ec95596 100644
--- a/funasr/tasks/vad.py
+++ b/funasr/tasks/vad.py
@@ -1,77 +1,42 @@
 import argparse
 import logging
+import os
+from pathlib import Path
 from typing import Callable
 from typing import Collection
 from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple
-import os
-from pathlib import Path
-from typing import Tuple
 from typing import Union
-import yaml
+
 import numpy as np
 import torch
+import yaml
 from typeguard import check_argument_types
 from typeguard import check_return_type
 
 from funasr.datasets.collate_fn import CommonCollateFn
-from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.rnn_decoder import RNNDecoder
-from funasr.models.decoder.transformer_decoder import (
-    DynamicConvolution2DTransformerDecoder,  # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
-from funasr.models.decoder.transformer_decoder import (
-    LightweightConvolution2DTransformerDecoder,  # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import (
-    LightweightConvolutionTransformerDecoder,  # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import TransformerDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.transformer_encoder import TransformerEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.postencoder.hugging_face_transformers_postencoder import (
-    HuggingFaceTransformersPostEncoder,  # noqa: H301
-)
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.preencoder.linear import LinearProjection
-from funasr.models.preencoder.sinc import LightweightSincConvs
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.specaug.specaug import SpecAug
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.layers.global_mvn import GlobalMVN
 from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
-from funasr.train.abs_espnet_model import AbsESPnetModel
-from funasr.train.class_choices import ClassChoices
-from funasr.train.trainer import Trainer
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import float_or_none
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2
-from funasr.modules.subsampling import Conv1dSubsampling
 from funasr.models.e2e_vad import E2EVadModel
 from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.tasks.abs_task import AbsTask
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str_or_none
 
 frontend_choices = ClassChoices(
     name="frontend",
@@ -292,7 +257,7 @@
             model_class = model_choices.get_class(args.model)
         except AttributeError:
             model_class = model_choices.get_class("e2evad")
-        
+
         # 1. frontend
         if args.input_size is None:
             # Extract features in the model
@@ -308,7 +273,7 @@
             args.frontend_conf = {}
             frontend = None
             input_size = args.input_size
-        
+
         model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
 
         return model
@@ -344,7 +309,7 @@
 
         with config_file.open("r", encoding="utf-8") as f:
             args = yaml.safe_load(f)
-        #if cmvn_file is not None:
+        # if cmvn_file is not None:
         args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
diff --git a/funasr/train/abs_espnet_model.py b/funasr/train/abs_espnet_model.py
deleted file mode 100644
index cc6a5a2..0000000
--- a/funasr/train/abs_espnet_model.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
-from abc import ABC
-from abc import abstractmethod
-from typing import Dict
-from typing import Tuple
-
-import torch
-
-
-class AbsESPnetModel(torch.nn.Module, ABC):
-    """The common abstract class among each tasks
-
-    "ESPnetModel" is referred to a class which inherits torch.nn.Module,
-    and makes the dnn-models forward as its member field,
-    a.k.a delegate pattern,
-    and defines "loss", "stats", and "weight" for the task.
-
-    If you intend to implement new task in ESPNet,
-    the model must inherit this class.
-    In other words, the "mediator" objects between
-    our training system and the your task class are
-    just only these three values, loss, stats, and weight.
-
-    Example:
-        >>> from funasr.tasks.abs_task import AbsTask
-        >>> class YourESPnetModel(AbsESPnetModel):
-        ...     def forward(self, input, input_lengths):
-        ...         ...
-        ...         return loss, stats, weight
-        >>> class YourTask(AbsTask):
-        ...     @classmethod
-        ...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.num_updates = 0
-
-    @abstractmethod
-    def forward(
-        self, **batch: torch.Tensor
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
-        raise NotImplementedError
-
-    def set_num_updates(self, num_updates):
-        self.num_updates = num_updates
-
-    def get_num_updates(self):
-        return self.num_updates
diff --git a/funasr/train/abs_model.py b/funasr/train/abs_model.py
index 1c7ff3d..026140b 100644
--- a/funasr/train/abs_model.py
+++ b/funasr/train/abs_model.py
@@ -12,34 +12,12 @@
 
 from funasr.modules.nets_utils import make_pad_mask
 from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
-
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.models.base_model import FunASRModel
 
 
-class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
-    """The abstract class
-
-    To share the loss calculation way among different models,
-    We uses delegate pattern here:
-    The instance of this class should be passed to "LanguageModel"
-
-    This "model" is one of mediator objects for "Task" class.
-
-    """
-
-    @abstractmethod
-    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def with_vad(self) -> bool:
-        raise NotImplementedError
-
-
-class PunctuationModel(AbsESPnetModel):
+class PunctuationModel(FunASRModel):
     
-    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
+    def __init__(self, punc_model: torch.nn.Module, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
         assert check_argument_types()
         super().__init__()
         self.punc_model = punc_model
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index a40f031..4052448 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -39,7 +39,7 @@
 from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.recursive_op import recursive_average
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.train.distributed_utils import DistributedOption
 from funasr.train.reporter import Reporter
 from funasr.train.reporter import SubReporter
@@ -166,7 +166,7 @@
     @classmethod
     def run(
         cls,
-        model: AbsESPnetModel,
+        model: FunASRModel,
         optimizers: Sequence[torch.optim.Optimizer],
         schedulers: Sequence[Optional[AbsScheduler]],
         train_iter_factory: AbsIterFactory,
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
new file mode 100644
index 0000000..3f55170
--- /dev/null
+++ b/funasr/utils/prepare_data.py
@@ -0,0 +1,209 @@
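+# Data preparation helpers used before training: filter out wav/text pairs
+# that do not match, pre-compute feature shapes for the small-dataset loader,
+# and build chunked data lists for the large-dataset loader.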
+import logging
+import os
+import shutil
+from multiprocessing import Pool
+
+import numpy as np
+import torch.distributed as dist
+import torchaudio
+
+
+def filter_wav_text(data_dir, dataset):
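+    # Keep only utterances present in both wav.scp and text; the original
+    # files are preserved as wav.scp.bak and text.bak.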
+    wav_file = os.path.join(data_dir, dataset, "wav.scp")
+    text_file = os.path.join(data_dir, dataset, "text")
+    with open(wav_file) as f_wav, open(text_file) as f_text:
+        wav_lines = f_wav.readlines()
+        text_lines = f_text.readlines()
+    os.rename(wav_file, "{}.bak".format(wav_file))
+    os.rename(text_file, "{}.bak".format(text_file))
+    wav_dict = {}
+    for line in wav_lines:
+        parts = line.strip().split()
+        if len(parts) < 2:
+            continue
+        wav_dict[parts[0]] = parts[1]
+    text_dict = {}
+    for line in text_lines:
+        parts = line.strip().split()
+        if len(parts) < 2:
+            continue
+        text_dict[parts[0]] = " ".join(parts[1:])
+    filter_count = 0
+    with open(wav_file, "w") as f_wav, open(text_file, "w") as f_text:
+        for sample_name, wav_path in wav_dict.items():
+            if sample_name in text_dict.keys():
+                f_wav.write(sample_name + " " + wav_path + "\n")
+                f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
+            else:
+                filter_count += 1
+    logging.info("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".
+                 format(filter_count, len(wav_lines), dataset))
+
+
+def wav2num_frame(wav_path, frontend_conf):
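+    # Estimate the output frame count from the waveform length: samples
+    # divided by the sample rate gives seconds, frame_shift is assumed to be
+    # in milliseconds, and lfr_n is the low-frame-rate downsampling factor.
+    # The feature dimension is n_mels stacked lfr_m times.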
+    waveform, sampling_rate = torchaudio.load(wav_path)
+    n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
+    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
+    return n_frames, feature_dim
+
+
+def calc_shape_core(root_path, args, idx):
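+    # Worker for one split: read wav.scp.{idx}, compute per-utterance shapes,
+    # and write speech_shape.{idx}, skipping utterances outside the optional
+    # speech_length_min/speech_length_max range (in frames).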
+    wav_scp_file = os.path.join(root_path, "wav.scp.{}".format(idx))
+    shape_file = os.path.join(root_path, "speech_shape.{}".format(idx))
+    with open(wav_scp_file) as f:
+        lines = f.readlines()
+    frontend_conf = args.frontend_conf
+    dataset_conf = args.dataset_conf
+    # dataset_conf is a plain dict (see prepare_data below), so read the
+    # optional length limits with .get() rather than attribute access.
+    speech_length_min = dataset_conf.get("speech_length_min", -1)
+    speech_length_max = dataset_conf.get("speech_length_max", -1)
+    with open(shape_file, "w") as f:
+        for line in lines:
+            sample_name, wav_path = line.strip().split()
+            n_frames, feature_dim = wav2num_frame(wav_path, frontend_conf)
+            write_flag = True
+            if n_frames > 0 and speech_length_min > 0:
+                write_flag = n_frames >= speech_length_min
+            if n_frames > 0 and speech_length_max > 0:
+                # Combine with the min check so a failed lower bound is not
+                # overwritten by a passing upper bound.
+                write_flag = write_flag and n_frames <= speech_length_max
+            if write_flag:
+                f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
+                f.flush()
+
+
+def calc_shape(args, dataset, nj=64):
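+    # Split wav.scp into nj chunks, compute shapes in a process pool, then
+    # merge the per-job speech_shape.{i} files into a single speech_shape.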
+    shape_path = os.path.join(args.data_dir, dataset, "speech_shape")
+    if os.path.exists(shape_path):
+        logging.info('Shape file for small dataset already exists.')
+        return
+
+    split_shape_path = os.path.join(args.data_dir, dataset, "shape_files")
+    if os.path.exists(split_shape_path):
+        shutil.rmtree(split_shape_path)
+    os.mkdir(split_shape_path)
+
+    # split
+    wav_scp_file = os.path.join(args.data_dir, dataset, "wav.scp")
+    with open(wav_scp_file) as f:
+        lines = f.readlines()
+        num_lines = len(lines)
+        num_job_lines = num_lines // nj
+    start = 0
+    for i in range(nj):
+        end = start + num_job_lines
+        file = os.path.join(split_shape_path, "wav.scp.{}".format(str(i + 1)))
+        with open(file, "w") as f:
+            if i == nj - 1:
+                f.writelines(lines[start:])
+            else:
+                f.writelines(lines[start:end])
+        start = end
+
+    p = Pool(nj)
+    results = [p.apply_async(calc_shape_core, args=(split_shape_path, args, str(i + 1)))
+               for i in range(nj)]
+    logging.info("Generating shape files, please wait a few minutes...")
+    p.close()
+    p.join()
+    # Surface worker exceptions here rather than failing later with a
+    # missing speech_shape.{i} file in the combine step.
+    for res in results:
+        res.get()
+
+    # combine
+    with open(shape_path, "w") as f:
+        for i in range(nj):
+            job_file = os.path.join(split_shape_path, "speech_shape.{}".format(str(i + 1)))
+            with open(job_file) as job_f:
+                lines = job_f.readlines()
+                f.writelines(lines)
+    logging.info('Generating shape files done.')
+
+
+def generate_data_list(data_dir, dataset, nj=64):
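+    # Split wav.scp and text into nj chunks under split/<i>/ and record each
+    # chunk's (wav.scp, text) path pair in data.list.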
+    list_file = os.path.join(data_dir, dataset, "data.list")
+    if os.path.exists(list_file):
+        logging.info('Data list for large dataset already exists.')
+        return
+    split_path = os.path.join(data_dir, dataset, "split")
+    if os.path.exists(split_path):
+        shutil.rmtree(split_path)
+    os.mkdir(split_path)
+
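+    # wav.scp and text are read in parallel below, so they must be
+    # line-aligned; filter_wav_text rewrites both files in the same key
+    # order, which guarantees this when it has run first.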
+    with open(os.path.join(data_dir, dataset, "wav.scp")) as f_wav:
+        wav_lines = f_wav.readlines()
+    with open(os.path.join(data_dir, dataset, "text")) as f_text:
+        text_lines = f_text.readlines()
+    num_lines = len(wav_lines)
+    num_job_lines = num_lines // nj
+    start = 0
+    for i in range(nj):
+        end = start + num_job_lines
+        split_path_nj = os.path.join(split_path, str(i + 1))
+        os.mkdir(split_path_nj)
+        wav_file = os.path.join(split_path_nj, "wav.scp")
+        text_file = os.path.join(split_path_nj, "text")
+        with open(wav_file, "w") as fw, open(text_file, "w") as ft:
+            if i == nj - 1:
+                fw.writelines(wav_lines[start:])
+                ft.writelines(text_lines[start:])
+            else:
+                fw.writelines(wav_lines[start:end])
+                ft.writelines(text_lines[start:end])
+        start = end
+
+    with open(list_file, "w") as f_data:
+        for i in range(nj):
+            wav_path = os.path.join(split_path, str(i + 1), "wav.scp")
+            text_path = os.path.join(split_path, str(i + 1), "text")
+            f_data.write(wav_path + " " + text_path + "\n")
+
+
+def prepare_data(args, distributed_option):
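+    # Rank 0 (or the single process) filters the data and builds the shape /
+    # list files once; every rank then fills in the paths the dataloaders
+    # expect, and a barrier keeps the other ranks from racing ahead.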
+    distributed = distributed_option.distributed
+    if not distributed or distributed_option.dist_rank == 0:
+        filter_wav_text(args.data_dir, args.train_set)
+        filter_wav_text(args.data_dir, args.valid_set)
+
+        if args.dataset_type == "small":
+            calc_shape(args, args.train_set)
+            calc_shape(args, args.valid_set)
+
+        if args.dataset_type == "large":
+            generate_data_list(args.data_dir, args.train_set)
+            generate_data_list(args.data_dir, args.valid_set)
+
+    if args.dataset_type == "small":
+        args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
+        args.valid_shape_file = [os.path.join(args.data_dir, args.valid_set, "speech_shape")]
+        data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
+        data_types = args.dataset_conf.get("data_types", "sound,text").split(",")
+        args.train_data_path_and_name_and_type = [
+            ["{}/{}/wav.scp".format(args.data_dir, args.train_set), data_names[0], data_types[0]],
+            ["{}/{}/text".format(args.data_dir, args.train_set), data_names[1], data_types[1]]
+        ]
+        args.valid_data_path_and_name_and_type = [
+            ["{}/{}/wav.scp".format(args.data_dir, args.valid_set), data_names[0], data_types[0]],
+            ["{}/{}/text".format(args.data_dir, args.valid_set), data_names[1], data_types[1]]
+        ]
+        if args.embed_path is not None:
+            args.train_data_path_and_name_and_type.append(
+                [os.path.join(args.embed_path, "embeds", args.train_set, "embeds.scp"), "embed", "kaldi_ark"])
+            args.valid_data_path_and_name_and_type.append(
+                [os.path.join(args.embed_path, "embeds", args.valid_set, "embeds.scp"), "embed", "kaldi_ark"])
+    else:
+        args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
+        args.valid_data_file = os.path.join(args.data_dir, args.valid_set, "data.list")
+        if args.embed_path is not None:
+            if not distributed or distributed_option.dist_rank == 0:
+                for d in [args.train_set, args.valid_set]:
+                    file = os.path.join(args.data_dir, d, "data.list")
+                    with open(file) as f:
+                        lines = f.readlines()
+                    out_file = os.path.join(args.data_dir, d, "data_with_embed.list")
+                    with open(out_file, "w") as out_f:
+                        for line in lines:
+                            parts = line.strip().split()
+                            idx = parts[0].split("/")[-2]
+                            embed_file = os.path.join(args.embed_path, "embeds", d, "ark",
+                                                      "embeds.{}.ark".format(idx))
+                            out_f.write(parts[0] + " " + parts[1] + " " + embed_file + "\n")
+            args.train_data_file = os.path.join(args.data_dir, args.train_set, "data_with_embed.list")
+            args.valid_data_file = os.path.join(args.data_dir, args.valid_set, "data_with_embed.list")
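+    # Make the other ranks wait until rank 0 has finished writing the files.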
+    if distributed:
+        dist.barrier()
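+
+# Minimal usage sketch (illustrative, not a fixed API): `args` is the parsed
+# training namespace, assumed to carry data_dir, train_set, valid_set,
+# dataset_type, dataset_conf, frontend_conf and embed_path, and
+# `distributed_option` is the DistributedOption built by the trainer.
+#
+#     prepare_data(args, distributed_option)
+#     # small datasets: args.train_shape_file / args.valid_shape_file are set
+#     # large datasets: args.train_data_file / args.valid_data_file are set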

--
Gitblit v1.9.1