From de0ecb446fa429d210397949694d8d9ad6d66112 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 09 二月 2023 10:57:25 +0800
Subject: [PATCH] Merge pull request #79 from alibaba-damo-academy/dev_wjm

---
 egs/aishell/data2vec_transformer_finetune/run.sh                                                    |    0 
 funasr/datasets/collate_fn.py                                                                       |   52 +
 funasr/optimizers/fairseq_adam.py                                                                   |  148 ++++
 funasr/datasets/large_datasets/utils/filter.py                                                      |   24 
 egs/aishell/data2vec_paraformer_finetune/utils                                                      |    1 
 egs/aishell/data2vec_transformer_finetune/local/aishell_data_prep.sh                                |    0 
 funasr/datasets/large_datasets/utils/clipping.py                                                    |   40 +
 funasr/datasets/large_datasets/build_dataloader.py                                                  |    5 
 egs/aishell2/data2vec_pretrain/path.sh                                                              |    6 
 egs/aishell/data2vec_paraformer_finetune/conf/decode_asr_transformer_noctc_1best.yaml               |    6 
 egs/aishell/data2vec_transformer_finetune/path.sh                                                   |    0 
 funasr/bin/asr_inference_paraformer.py                                                              |    2 
 egs/aishell/data2vec_transformer_finetune/conf/decode_asr_transformer.yaml                          |    0 
 egs/aishell2/data2vec_pretrain/local/prepare_data.sh                                                |   53 +
 egs/aishell2/data2vec_pretrain/run.sh                                                               |  172 +++++
 funasr/models/data2vec.py                                                                           |  160 +++++
 egs/aishell/data2vec_paraformer_finetune/local/aishell_data_prep.sh                                 |    0 
 egs/aishell/data2vec_transformer_finetune/utils                                                     |    0 
 funasr/datasets/large_datasets/datapipes/batch.py                                                   |  219 ++++--
 egs/aishell2/data2vec_pretrain/utils                                                                |    1 
 funasr/bin/data2vec_train.py                                                                        |   45 +
 egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh                                     |    0 
 egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml                                 |   79 ++
 funasr/tasks/data2vec.py                                                                            |  376 +++++++++++
 egs/aishell/data2vec_paraformer_finetune/path.sh                                                    |    0 
 funasr/tasks/abs_task.py                                                                            |    4 
 egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml |  104 +++
 funasr/datasets/large_datasets/dataset.py                                                           |   18 
 egs/aishell/data2vec_paraformer_finetune/run.sh                                                     |  252 +++++++
 egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml           |    0 
 funasr/schedulers/tri_stage_scheduler.py                                                            |  108 +++
 egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh                                      |    0 
 32 files changed, 1,781 insertions(+), 94 deletions(-)

diff --git a/egs/aishell/data2vec_paraformer_finetune/conf/decode_asr_transformer_noctc_1best.yaml b/egs/aishell/data2vec_paraformer_finetune/conf/decode_asr_transformer_noctc_1best.yaml
new file mode 100644
index 0000000..5436b12
--- /dev/null
+++ b/egs/aishell/data2vec_paraformer_finetune/conf/decode_asr_transformer_noctc_1best.yaml
@@ -0,0 +1,6 @@
+beam_size: 1
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.0
+lm_weight: 0.15
diff --git a/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml b/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
new file mode 100644
index 0000000..f9a2cdb
--- /dev/null
+++ b/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
@@ -0,0 +1,104 @@
+# network architecture
+# encoder related
+encoder: data2vec_encoder
+encoder_conf:
+    extractor_mode: layer_norm
+    encoder_layerdrop: 0.1
+    dropout_input: 0.0
+    dropout_features: 0.0
+    feature_grad_mult: 0.0
+    encoder_embed_dim: 768
+
+    mask_prob: 0.65
+    mask_length: 10
+
+    loss_beta: 0
+    loss_scale: null
+
+    instance_norm_target_layer: true
+    average_top_k_layers: 8
+
+    pos_conv_depth: 5
+    conv_pos: 95
+
+    ema_decay: 0.999
+    ema_end_decay: 0.9999
+    ema_anneal_end_step: 30000
+    ema_transformer_only: true
+    ema_layers_only: true
+
+    require_same_masks: true
+    mask_dropout: 0
+
+
+# decoder related
+decoder: paraformer_decoder_san
+decoder_conf:
+    attention_heads: 12
+    linear_units: 3072
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+model: paraformer
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1
+    length_normalized_loss: false
+    predictor_weight: 1.0
+    sampling_ratio: 0.4
+
+# minibatch related
+batch_type: length
+batch_bins: 25000
+num_workers: 16
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 50
+val_scheduler_criterion:
+    - valid
+    - acc
+best_model_criterion:
+-   - valid
+    - acc
+    - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+   lr: 0.00002
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+specaug: specaug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
+
+predictor: cif_predictor
+predictor_conf:
+  idim: 768
+  threshold: 1.0
+  l_order: 1
+  r_order: 1
+
+
+log_interval: 50
+unused_parameters: true
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell/data2vec_finetune/local/aishell_data_prep.sh b/egs/aishell/data2vec_paraformer_finetune/local/aishell_data_prep.sh
similarity index 100%
rename from egs/aishell/data2vec_finetune/local/aishell_data_prep.sh
rename to egs/aishell/data2vec_paraformer_finetune/local/aishell_data_prep.sh
diff --git a/egs/aishell/data2vec_finetune/local/prepare_data.sh b/egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh
similarity index 100%
rename from egs/aishell/data2vec_finetune/local/prepare_data.sh
rename to egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh
diff --git a/egs/aishell/data2vec_finetune/path.sh b/egs/aishell/data2vec_paraformer_finetune/path.sh
similarity index 100%
rename from egs/aishell/data2vec_finetune/path.sh
rename to egs/aishell/data2vec_paraformer_finetune/path.sh
diff --git a/egs/aishell/data2vec_paraformer_finetune/run.sh b/egs/aishell/data2vec_paraformer_finetune/run.sh
new file mode 100755
index 0000000..cada164
--- /dev/null
+++ b/egs/aishell/data2vec_paraformer_finetune/run.sh
@@ -0,0 +1,252 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=2
+count=1
+gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="../DATA" #feature output dictionary, for large data
+exp_dir="."
+lang=zh
+dumpdir=dump/fbank
+feats_type=fbank
+token_type=char
+scp=feats.scp
+type=kaldi_ark
+stage=0
+stop_stage=4
+
+# feature configuration
+feats_dim=80
+sample_frequency=16000
+nj=32
+speed_perturb="0.9,1.0,1.1"
+
+# data
+data_aishell=
+
+# exp tag
+tag=""
+
+model_name=damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch
+init_param="$HOME/.cache/modelscope/hub/$model_name/basemodel.pb"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev
+test_sets="dev test"
+
+asr_config=conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_asr_transformer_noctc_1best.yaml
+inference_asr_model=valid.acc.ave_10best.pth
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+    inference_nj=$[${ngpu}*${njob}]
+    _ngpu=1
+else
+    inference_nj=$njob
+    _ngpu=0
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    echo "stage 0: Data preparation"
+    # Data preparation
+    local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+    for x in train dev test; do
+        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
+            > ${feats_dir}/data/${x}/text
+        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
+    done
+fi
+
+feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
+feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
+feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "stage 1: Feature Generation"
+    # compute fbank features
+    fbankdir=${feats_dir}/fbank
+    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
+        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
+    utils/fix_data_feat.sh ${fbankdir}/train
+    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
+        ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
+    utils/fix_data_feat.sh ${fbankdir}/dev
+    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
+        ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
+    utils/fix_data_feat.sh ${fbankdir}/test
+     
+    # compute global cmvn
+    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
+        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
+
+    # apply cmvn 
+    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+        ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
+    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+        ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
+    utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+        ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
+    
+    cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
+    cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
+    cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
+
+    utils/fix_data_feat.sh ${feat_train_dir}
+    utils/fix_data_feat.sh ${feat_dev_dir}
+    utils/fix_data_feat.sh ${feat_test_dir}
+
+    #generate ark list 
+    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
+    utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+fi
+
+token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "stage 2: Dictionary Preparation"
+    mkdir -p ${feats_dir}/data/${lang}_token_list/char/
+   
+    echo "make a dictionary"
+    echo "<blank>" > ${token_list}
+    echo "<s>" >> ${token_list}
+    echo "</s>" >> ${token_list}
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+    num_token=$(cat ${token_list} | wc -l)
+    echo "<unk>" >> ${token_list}
+    vocab_size=$(cat ${token_list} | wc -l)
+    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
+    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
+    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train 
+    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
+    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
+    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
+fi
+
+# Training Stage
+world_size=$gpu_num  # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "stage 3: Training"
+    python utils/download_model.py  --model_name ${model_name}  # download pretrained model on ModelScope
+    mkdir -p ${exp_dir}/exp/${model_dir}
+    mkdir -p ${exp_dir}/exp/${model_dir}/log
+    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+    if [ -f $INIT_FILE ];then
+        rm -f $INIT_FILE
+    fi 
+    init_method=file://$(readlink -f $INIT_FILE)
+    echo "$0: init method is $init_method"
+    for ((i = 0; i < $gpu_num; ++i)); do
+        {
+            rank=$i
+            local_rank=$i
+            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+            asr_train_paraformer.py \
+                --gpu_id $gpu_id \
+                --use_preprocessor true \
+                --token_type char \
+                --token_list $token_list \
+                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
+                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
+                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
+                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
+                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
+                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
+                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
+                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char  \
+                --init_param ${init_param} \
+                --resume true \
+                --output_dir ${exp_dir}/exp/${model_dir} \
+                --config $asr_config \
+                --input_size $feats_dim \
+                --ngpu $gpu_num \
+                --num_worker_count $count \
+                --multiprocessing_distributed true \
+                --dist_init_method $init_method \
+                --dist_world_size $world_size \
+                --dist_rank $rank \
+                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+        } &
+        done
+        wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    echo "stage 4: Inference"
+    for dset in ${test_sets}; do
+        asr_exp=${exp_dir}/exp/${model_dir}
+        inference_tag="$(basename "${inference_config}" .yaml)"
+        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+        _logdir="${_dir}/logdir"
+        if [ -d ${_dir} ]; then
+            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+            exit 0
+        fi
+        mkdir -p "${_logdir}"
+        _data="${feats_dir}/${dumpdir}/${dset}"
+        key_file=${_data}/${scp}
+        num_scp_file="$(<${key_file} wc -l)"
+        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+        split_scps=
+        for n in $(seq "${_nj}"); do
+            split_scps+=" ${_logdir}/keys.${n}.scp"
+        done
+        # shellcheck disable=SC2086
+        utils/split_scp.pl "${key_file}" ${split_scps}
+        _opts=
+        if [ -n "${inference_config}" ]; then
+            _opts+="--config ${inference_config} "
+        fi
+        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+            python -m funasr.bin.asr_inference_launch \
+                --batch_size 1 \
+                --ngpu "${_ngpu}" \
+                --njob ${njob} \
+                --gpuid_list ${gpuid_list} \
+                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --key_file "${_logdir}"/keys.JOB.scp \
+                --asr_train_config "${asr_exp}"/config.yaml \
+                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+                --output_dir "${_logdir}"/output.JOB \
+                --mode paraformer \
+                ${_opts}
+
+        for f in token token_int score text; do
+            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+                for i in $(seq "${_nj}"); do
+                    cat "${_logdir}/output.${i}/1best_recog/${f}"
+                done | sort -k1 >"${_dir}/${f}"
+            fi
+        done
+        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+        python utils/proce_text.py ${_data}/text ${_data}/text.proc
+        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+        cat ${_dir}/text.cer.txt
+    done
+fi
diff --git a/egs/aishell/data2vec_paraformer_finetune/utils b/egs/aishell/data2vec_paraformer_finetune/utils
new file mode 120000
index 0000000..fe070dd
--- /dev/null
+++ b/egs/aishell/data2vec_paraformer_finetune/utils
@@ -0,0 +1 @@
+../../aishell/transformer/utils
\ No newline at end of file
diff --git a/egs/aishell/data2vec_finetune/conf/decode_asr_transformer.yaml b/egs/aishell/data2vec_transformer_finetune/conf/decode_asr_transformer.yaml
similarity index 100%
rename from egs/aishell/data2vec_finetune/conf/decode_asr_transformer.yaml
rename to egs/aishell/data2vec_transformer_finetune/conf/decode_asr_transformer.yaml
diff --git a/egs/aishell/data2vec_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml b/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
similarity index 100%
rename from egs/aishell/data2vec_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
rename to egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
diff --git a/egs/aishell/data2vec_finetune/local/aishell_data_prep.sh b/egs/aishell/data2vec_transformer_finetune/local/aishell_data_prep.sh
similarity index 100%
copy from egs/aishell/data2vec_finetune/local/aishell_data_prep.sh
copy to egs/aishell/data2vec_transformer_finetune/local/aishell_data_prep.sh
diff --git a/egs/aishell/data2vec_finetune/local/prepare_data.sh b/egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh
similarity index 100%
copy from egs/aishell/data2vec_finetune/local/prepare_data.sh
copy to egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh
diff --git a/egs/aishell/data2vec_finetune/path.sh b/egs/aishell/data2vec_transformer_finetune/path.sh
similarity index 100%
copy from egs/aishell/data2vec_finetune/path.sh
copy to egs/aishell/data2vec_transformer_finetune/path.sh
diff --git a/egs/aishell/data2vec_finetune/run.sh b/egs/aishell/data2vec_transformer_finetune/run.sh
similarity index 100%
rename from egs/aishell/data2vec_finetune/run.sh
rename to egs/aishell/data2vec_transformer_finetune/run.sh
diff --git a/egs/aishell/data2vec_finetune/utils b/egs/aishell/data2vec_transformer_finetune/utils
similarity index 100%
rename from egs/aishell/data2vec_finetune/utils
rename to egs/aishell/data2vec_transformer_finetune/utils
diff --git a/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
new file mode 100644
index 0000000..4052774
--- /dev/null
+++ b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
@@ -0,0 +1,79 @@
+# network architecture
+# encoder related
+encoder: data2vec_encoder
+encoder_conf:
+  extractor_mode: layer_norm
+  encoder_layerdrop: 0.05
+  dropout_input: 0.0
+  dropout_features: 0.0
+  feature_grad_mult: 1.0
+  encoder_embed_dim: 768
+
+  mask_prob: 0.65
+  mask_length: 10
+
+  loss_beta: 0
+  loss_scale: null
+
+  instance_norm_target_layer: true
+  average_top_k_layers: 8
+
+  pos_conv_depth: 5
+  conv_pos: 95
+
+  ema_decay: 0.999
+  ema_end_decay: 0.9999
+  ema_anneal_end_step: 30000
+  ema_transformer_only: true
+  ema_layers_only: true
+
+  require_same_masks: true
+  mask_dropout: 0
+
+log_interval: 50
+normalize: None
+
+# minibatch related
+batch_type: length
+batch_bins: 64000
+num_workers: 16
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+patience: none
+max_epoch: 600
+val_scheduler_criterion:
+    - valid
+    - acc
+best_model_criterion:
+-   - valid
+    - loss
+    - min
+keep_nbest_models: 50
+unused_parameters: true
+
+optim: fairseq_adam
+optim_conf:
+    lr: 0.0005
+    adam_betas: [0.9,0.98]
+    adam_eps: 1.0e-06
+    weight_decay: 0.01
+
+scheduler: tri_stage
+scheduler_conf:
+    phase_ratio: [0.03,0.9,0.07]
+
+# for dataset
+dataset_conf:
+    batch_mode: clipping
+    data_names: speech,none
+    data_types: kaldi_ark,none
+    shuffle: true
+    shuffle_conf:
+        shuffle_size: 12800
+        sort_size: 12800
+    batch_conf:
+        batch_type: token
+        batch_size: 64000
+    num_workers: 8
\ No newline at end of file
diff --git a/egs/aishell2/data2vec_pretrain/local/prepare_data.sh b/egs/aishell2/data2vec_pretrain/local/prepare_data.sh
new file mode 100755
index 0000000..ce6ee19
--- /dev/null
+++ b/egs/aishell2/data2vec_pretrain/local/prepare_data.sh
@@ -0,0 +1,53 @@
+#!/usr/bin/env bash
+# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
+#           2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
+# Apache 2.0
+
+# transform raw AISHELL-2 data to kaldi format
+
+. ./path.sh || exit 1;
+
+tmp=
+dir=
+
+if [ $# != 3 ]; then
+  echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
+  echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
+  exit 1;
+fi
+
+corpus=$1
+tmp=$2
+dir=$3
+
+echo "prepare_data.sh: Preparing data in $corpus"
+
+mkdir -p $tmp
+mkdir -p $dir
+
+# corpus check
+if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
+  echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
+  exit 1;
+fi
+
+# validate utt-key list, IC0803W0380 is a bad utterance
+awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
+awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
+tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+
+# wav.scp
+awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
+tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+
+# text
+tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+
+# copy prepared resources from tmp_dir to target dir
+mkdir -p $dir
+for f in wav.scp text; do
+  cp $tmp/$f $dir/$f || exit 1;
+done
+
+echo "local/prepare_data.sh succeeded"
+exit 0;
diff --git a/egs/aishell2/data2vec_pretrain/path.sh b/egs/aishell2/data2vec_pretrain/path.sh
new file mode 100755
index 0000000..ea3c0be
--- /dev/null
+++ b/egs/aishell2/data2vec_pretrain/path.sh
@@ -0,0 +1,6 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=../../../:$PYTHONPATH
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/egs/aishell2/data2vec_pretrain/run.sh b/egs/aishell2/data2vec_pretrain/run.sh
new file mode 100755
index 0000000..eceb183
--- /dev/null
+++ b/egs/aishell2/data2vec_pretrain/run.sh
@@ -0,0 +1,172 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+gpu_num=8
+count=1
+
+train_cmd=tools/run.pl
+
+# general configuration
+feats_dir="../DATA" #feature output dictionary
+exp_dir="."
+lang=zh
+dumpdir=dump/fbank
+feats_type=fbank
+token_type=char
+dataset_type=large
+stage=0
+stop_stage=4
+
+# feature configuration
+feats_dim=80
+sample_frequency=16000
+nj=100
+speed_perturb="0.9,1.0,1.1"
+
+# data
+tr_dir=
+dev_tst_dir=
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev_ios
+
+asr_config=conf/train_pretrain_transformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    echo "stage 0: Data preparation"
+    # For training set
+    local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
+    # # For dev and test set
+    for x in Android iOS Mic; do
+        local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
+        local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
+    done 
+    # Normalize text to capital letters
+    for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
+        mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+        paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
+             | tr 'A-Z' 'a-z' | tr -d " ") \
+            > ${feats_dir}/data/${x}/text
+        tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
+    done
+fi
+
+feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
+feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "stage 1: Feature Generation"
+    # compute fbank features
+    fbankdir=${feats_dir}/fbank
+    steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
+        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
+    tools/fix_data_feat.sh ${fbankdir}/train
+    for x in android ios mic; do
+        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
+            ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
+        tools/fix_data_feat.sh ${fbankdir}/dev_${x}
+        steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
+            ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
+        tools/fix_data_feat.sh ${fbankdir}/test_${x}
+    done
+    
+    # compute global cmvn
+    steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
+        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
+
+    # apply cmvn 
+    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+        ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
+    steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+        ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
+    for x in android ios mic; do
+        steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+            ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
+    done
+    
+    cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
+    tools/fix_data_feat.sh ${feat_train_dir}
+    cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
+    tools/fix_data_feat.sh ${feat_dev_dir}
+    for x in android ios mic; do
+        cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
+        tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
+    done
+fi
+
+token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "stage 2: Dictionary Preparation"
+    mkdir -p ${feats_dir}/data/${lang}_token_list/char/
+   
+    echo "make a dictionary"
+    echo "<blank>" > ${token_list}
+    echo "<s>" >> ${token_list}
+    echo "</s>" >> ${token_list}
+    tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+    num_token=$(cat ${token_list} | wc -l)
+    echo "<unk>" >> ${token_list}
+    vocab_size=$(cat ${token_list} | wc -l)
+    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
+    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
+    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
+    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
+    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set} 
+    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
+fi
+
+# Training Stage
+world_size=$gpu_num  # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "stage 3: Training"
+    mkdir -p ${exp_dir}/exp/${model_dir}
+    mkdir -p ${exp_dir}/exp/${model_dir}/log
+    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+    if [ -f $INIT_FILE ];then
+        rm -f $INIT_FILE
+    fi
+    init_method=file://$(readlink -f $INIT_FILE)
+    echo "$0: init method is $init_method"
+    for ((i = 0; i < $gpu_num; ++i)); do
+        {
+            rank=$i
+            local_rank=$i
+            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+            data2vec_train.py \
+                --gpu_id $gpu_id \
+                --use_preprocessor true \
+                --dataset_type $dataset_type \
+                --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
+                --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
+                --resume true \
+                --output_dir ${exp_dir}/exp/${model_dir} \
+                --config $asr_config \
+                --input_size $feats_dim \
+                --ngpu $gpu_num \
+                --num_worker_count $count \
+                --multiprocessing_distributed true \
+                --dist_init_method $init_method \
+                --dist_world_size $world_size \
+                --dist_rank $rank \
+                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+        } &
+      done
+      wait
+fi
\ No newline at end of file
diff --git a/egs/aishell2/data2vec_pretrain/utils b/egs/aishell2/data2vec_pretrain/utils
new file mode 120000
index 0000000..fe070dd
--- /dev/null
+++ b/egs/aishell2/data2vec_pretrain/utils
@@ -0,0 +1 @@
+../../aishell/transformer/utils
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 5d7d6ea..3769b6c 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -181,7 +181,7 @@
         self.nbest = nbest
         self.frontend = frontend
         self.encoder_downsampling_factor = 1
-        if asr_train_args.encoder_conf["input_layer"] == "conv2d":
+        if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
             self.encoder_downsampling_factor = 4
 
     @torch.no_grad()
diff --git a/funasr/bin/data2vec_train.py b/funasr/bin/data2vec_train.py
new file mode 100755
index 0000000..b9dbdff
--- /dev/null
+++ b/funasr/bin/data2vec_train.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+
+import os
+
+from funasr.tasks.data2vec import Data2VecTask
+
+
+def parse_args():
+    parser = Data2VecTask.get_parser()
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args=None, cmd=None):
+    # for data2vec Training
+    Data2VecTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # setup local gpu_id
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+    assert args.num_worker_count == 1
+
+    # re-compute batch size: when dataset type is small
+    if args.dataset_type == "small":
+        if args.batch_size is not None:
+            args.batch_size = args.batch_size * args.ngpu
+        if args.batch_bins is not None:
+            args.batch_bins = args.batch_bins * args.ngpu
+
+    main(args=args)
diff --git a/funasr/datasets/collate_fn.py b/funasr/datasets/collate_fn.py
index d52032f..d34d610 100644
--- a/funasr/datasets/collate_fn.py
+++ b/funasr/datasets/collate_fn.py
@@ -80,4 +80,56 @@
 
     output = (uttids, output)
     assert check_return_type(output)
+    return output
+
+def crop_to_max_size(feature, target_size):
+    size = len(feature)
+    diff = size - target_size
+    if diff <= 0:
+        return feature
+
+    start = np.random.randint(0, diff + 1)
+    end = size - diff + start
+    return feature[start:end]
+
+
+def clipping_collate_fn(
+        data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+        max_sample_size=None,
+        not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+    # mainly for pre-training
+    assert check_argument_types()
+    uttids = [u for u, _ in data]
+    data = [d for _, d in data]
+
+    assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+    assert all(
+        not k.endswith("_lengths") for k in data[0]
+    ), f"*_lengths is reserved: {list(data[0])}"
+
+    output = {}
+    for key in data[0]:
+        array_list = [d[key] for d in data]
+        tensor_list = [torch.from_numpy(a) for a in array_list]
+        sizes = [len(s) for s in tensor_list]
+        if max_sample_size is None:
+            target_size = min(sizes)
+        else:
+            target_size = min(min(sizes), max_sample_size)
+        tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
+        for i, (source, size) in enumerate(zip(tensor_list, sizes)):
+            diff = size - target_size
+            if diff == 0:
+                tensor[i] = source
+            else:
+                tensor[i] = crop_to_max_size(source, target_size)
+        output[key] = tensor
+
+        if key not in not_sequence:
+            lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
+            output[key + "_lengths"] = lens
+
+    output = (uttids, output)
+    assert check_return_type(output)
     return output
\ No newline at end of file
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 146723d..8f7fd0b 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -35,15 +35,16 @@
 
 class ArkDataLoader(AbsIterFactory):
     def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"):
-        symbol_table = read_symbol_table(dict_file)
+        symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
         if seg_dict_file is not None:
             seg_dict = load_seg_dict(seg_dict_file)
         else:
             seg_dict = None
         self.dataset_conf = dataset_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
+        batch_mode = self.dataset_conf.get("batch_mode", "padding")
         self.dataset = Dataset(data_list, symbol_table, seg_dict,
-                               self.dataset_conf, mode=mode)
+                               self.dataset_conf, mode=mode, batch_mode=batch_mode)
 
     def build_iter(self, epoch, shuffle=True):
         self.dataset.set_epoch(epoch)
diff --git a/funasr/datasets/large_datasets/datapipes/batch.py b/funasr/datasets/large_datasets/datapipes/batch.py
index 9c85d5e..8ec43e9 100644
--- a/funasr/datasets/large_datasets/datapipes/batch.py
+++ b/funasr/datasets/large_datasets/datapipes/batch.py
@@ -24,7 +24,8 @@
             batch_size=8000,
             len_fn=_default_len_fn,
             buffer_size=10240,
-            sort_size=500
+            sort_size=500,
+            batch_mode="padding",
     ):
         assert batch_size > 0, "Batch size is required to be larger than 0!"
         assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
@@ -35,6 +36,7 @@
         self.batch_size = batch_size
         self.buffer_size = buffer_size
         self.sort_size = sort_size
+        self.batch_mode = batch_mode
 
     def set_epoch(self, epoch):
         self.epoch = epoch
@@ -44,55 +46,137 @@
         batch = []
         bucket = []
         max_lengths = 0
+        min_lengths = 999999
         batch_lengths = 0
 
-        if self.buffer_size == -1:
-            for d in self.datapipe:
-                if d[0] > self.batch_size:
-                    continue
-                buffer.append(d)
-            buffer.sort()
-            for sample in buffer:
-                length, _, token = sample
-                if length > max_lengths:
-                    max_lengths = length
-                batch_lengths = max_lengths * (len(batch) + 1)
-                if batch_lengths > self.batch_size:
-                    bucket.append(batch)
-                    batch = []
-                    max_lengths = length
-                batch.append(token)
-            random.shuffle(bucket)
-            if bucket:
-                for batch_sample in bucket:
-                    yield batch_sample
-            if batch:
-                yield batch
-
-        elif self.buffer_size == 0:
-            for d in self.datapipe:
-                if d[0] > self.batch_size:
-                    continue
-                length, _, token = d
-                if length > self.batch_size:
-                    continue
-                if length > max_lengths:
-                    max_lengths = length
-                batch_lengths = max_lengths * (len(batch) + 1)
-                if batch_lengths > self.batch_size:
-                    yield batch
-                    batch = []
-                    max_lengths = length
-                batch.append(token)
-            if batch:
-                yield batch
-
-        else:
+        if self.batch_mode == "clipping":
+            assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
             for d in self.datapipe:
                 if d[0] > self.batch_size:
                     continue
                 buffer.append(d)
                 if len(buffer) == self.buffer_size:
+                    random.shuffle(buffer)
+                    for sample in buffer:
+                        bucket.append(sample)
+                        if len(bucket) == self.sort_size:
+                            bucket.sort()
+                            for x in bucket:
+                                length, _, token = x
+                                if length < min_lengths:
+                                    min_lengths = length
+                                batch_lengths = min_lengths * (len(batch) + 1)
+                                if batch_lengths > self.batch_size:
+                                    yield batch
+                                    batch = []
+                                    min_lengths = length
+                                batch.append(token)
+                            bucket = []
+                    buffer = []
+
+            if buffer:
+                random.shuffle(buffer)
+                for sample in buffer:
+                    bucket.append(sample)
+                    if len(bucket) == self.sort_size:
+                        bucket.sort()
+                        for x in bucket:
+                            length, _, token = x
+                            if length < min_lengths:
+                                min_lengths = length
+                            batch_lengths = min_lengths * (len(batch) + 1)
+                            if batch_lengths > self.batch_size:
+                                yield batch
+                                batch = []
+                                min_lengths = length
+                            batch.append(token)
+                        bucket = []
+                buffer = []
+
+            if bucket:
+                bucket.sort()
+                for x in bucket:
+                    length, _, token = x
+                    if length < min_lengths:
+                        min_lengths = length
+                    batch_lengths = min_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        yield batch
+                        batch = []
+                        min_lengths = length
+                    batch.append(token)
+                bucket = []
+
+            if batch:
+                yield batch
+
+        else:
+            if self.buffer_size == -1:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    buffer.append(d)
+                buffer.sort()
+                for sample in buffer:
+                    length, _, token = sample
+                    if length > max_lengths:
+                        max_lengths = length
+                    batch_lengths = max_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        bucket.append(batch)
+                        batch = []
+                        max_lengths = length
+                    batch.append(token)
+                random.shuffle(bucket)
+                if bucket:
+                    for batch_sample in bucket:
+                        yield batch_sample
+                if batch:
+                    yield batch
+
+            elif self.buffer_size == 0:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    length, _, token = d
+                    if length > self.batch_size:
+                        continue
+                    if length > max_lengths:
+                        max_lengths = length
+                    batch_lengths = max_lengths * (len(batch) + 1)
+                    if batch_lengths > self.batch_size:
+                        yield batch
+                        batch = []
+                        max_lengths = length
+                    batch.append(token)
+                if batch:
+                    yield batch
+
+            else:
+                for d in self.datapipe:
+                    if d[0] > self.batch_size:
+                        continue
+                    buffer.append(d)
+                    if len(buffer) == self.buffer_size:
+                        random.shuffle(buffer)
+                        for sample in buffer:
+                            bucket.append(sample)
+                            if len(bucket) == self.sort_size:
+                                bucket.sort()
+                                for x in bucket:
+                                    length, _, token = x
+                                    if length > max_lengths:
+                                        max_lengths = length
+                                    batch_lengths = max_lengths * (len(batch) + 1)
+                                    if batch_lengths > self.batch_size:
+                                        yield batch
+                                        batch = []
+                                        max_lengths = length
+                                    batch.append(token)
+                                bucket = []
+                        buffer = []
+
+                if buffer:
                     random.shuffle(buffer)
                     for sample in buffer:
                         bucket.append(sample)
@@ -111,38 +195,19 @@
                             bucket = []
                     buffer = []
 
-            if buffer:
-                random.shuffle(buffer)
-                for sample in buffer:
-                    bucket.append(sample)
-                    if len(bucket) == self.sort_size:
-                        bucket.sort()
-                        for x in bucket:
-                            length, _, token = x
-                            if length > max_lengths:
-                                max_lengths = length
-                            batch_lengths = max_lengths * (len(batch) + 1)
-                            if batch_lengths > self.batch_size:
-                                yield batch
-                                batch = []
-                                max_lengths = length
-                            batch.append(token)
-                        bucket = []
-                buffer = []
+                if bucket:
+                    bucket.sort()
+                    for x in bucket:
+                        length, _, token = x
+                        if length > max_lengths:
+                            max_lengths = length
+                        batch_lengths = max_lengths * (len(batch) + 1)
+                        if batch_lengths > self.batch_size:
+                            yield batch
+                            batch = []
+                            max_lengths = length
+                        batch.append(token)
+                    bucket = []
 
-            if bucket:
-                bucket.sort()
-                for x in bucket:
-                    length, _, token = x
-                    if length > max_lengths:
-                        max_lengths = length
-                    batch_lengths = max_lengths * (len(batch) + 1)
-                    if batch_lengths > self.batch_size:
-                        yield batch
-                        batch = []
-                        max_lengths = length
-                    batch.append(token)
-                bucket = []
-
-            if batch:
-                yield batch
+                if batch:
+                    yield batch
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 41d34ab..55b0678 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -13,6 +13,7 @@
 from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
 from funasr.datasets.large_datasets.utils.filter import filter
 from funasr.datasets.large_datasets.utils.padding import padding
+from funasr.datasets.large_datasets.utils.clipping import clipping
 from funasr.datasets.large_datasets.utils.tokenize import tokenize
 
 
@@ -101,6 +102,8 @@
                 elif data_type == "text" or data_type == "sound":
                     text_reader = open(data_file, "r")
                     reader_list.append(text_reader)
+                elif data_type == "none":
+                    continue
                 else:
                     raise TypeError("Data type {} is not supported".format(data_type))
 
@@ -143,7 +146,8 @@
             dict,
             seg_dict,
             conf,
-            mode="train"):
+            mode="train",
+            batch_mode="padding"):
     scp_lists = read_lists(data_list_file)
     shuffle = conf.get('shuffle', True)
     data_names = conf.get("data_names", "speech,text")
@@ -154,9 +158,10 @@
     filter_fn = partial(filter, **filter_conf)
     dataset = FilterIterDataPipe(dataset, fn=filter_fn)
 
-    vocab = {'vocab': dict, 'seg_dict': seg_dict}
-    tokenize_fn = partial(tokenize, **vocab)
-    dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
+    if "text" in data_names:
+        vocab = {'vocab': dict, 'seg_dict': seg_dict}
+        tokenize_fn = partial(tokenize, **vocab)
+        dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
 
     if shuffle:
         buffer_conf = conf.get('shuffle_conf', {})
@@ -180,8 +185,9 @@
                                              batch_size=batch_size,
                                              len_fn=len_fn,
                                              buffer_size=buffer_size,
-                                             sort_size=sort_size)
+                                             sort_size=sort_size,
+                                             batch_mode=batch_mode)
 
-    dataset = MapperIterDataPipe(dataset, fn=padding)
+    dataset = MapperIterDataPipe(dataset, fn=padding if batch_mode == "padding" else clipping)
 
     return dataset
diff --git a/funasr/datasets/large_datasets/utils/clipping.py b/funasr/datasets/large_datasets/utils/clipping.py
new file mode 100644
index 0000000..f5c2940
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/clipping.py
@@ -0,0 +1,40 @@
+import numpy as np
+import torch
+
+from funasr.datasets.collate_fn import crop_to_max_size
+
+
+def clipping(data):
+    assert isinstance(data, list)
+    assert "key" in data[0]
+
+    keys = [x["key"] for x in data]
+
+    batch = {}
+    data_names = data[0].keys()
+    for data_name in data_names:
+        if data_name == "key":
+            continue
+        else:
+            if data[0][data_name].dtype.kind == "i":
+                tensor_type = torch.int64
+            else:
+                tensor_type = torch.float32
+
+            tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
+            tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
+
+            length_clip = min(tensor_lengths)
+            tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
+            for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
+                diff = length - length_clip
+                assert diff >= 0
+                if diff == 0:
+                    tensor_clip[i] = tensor
+                else:
+                    tensor_clip[i] = crop_to_max_size(tensor, length_clip)
+
+            batch[data_name] = tensor_clip
+            batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
+
+    return keys, batch
diff --git a/funasr/datasets/large_datasets/utils/filter.py b/funasr/datasets/large_datasets/utils/filter.py
index 91ba4be..1260a47 100644
--- a/funasr/datasets/large_datasets/utils/filter.py
+++ b/funasr/datasets/large_datasets/utils/filter.py
@@ -6,13 +6,21 @@
            speech_length_max=15000,
            token_length_min=0,
            token_length_max=200):
-    assert "speech" in data
-    assert "text" in data
+    assert "speech" in data or "text" in data
 
-    if "sampling_rate" in data:
-        speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+    if "speech" in data and "text" in data:
+        if "sampling_rate" in data:
+            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+        else:
+            speech_length = data["speech"].shape[0]
+        num_tokens = len(data['text'])
+        return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+    elif "speech" in data:
+        if "sampling_rate" in data:
+            speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+        else:
+            speech_length = data["speech"].shape[0]
+        return speech_length_min < speech_length < speech_length_max
     else:
-        speech_length = data["speech"].shape[0]
-    num_tokens = len(data['text'])
-
-    return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+        num_tokens = len(data['text'])
+        return token_length_min < num_tokens < token_length_max
diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
new file mode 100644
index 0000000..fcd6bd2
--- /dev/null
+++ b/funasr/models/data2vec.py
@@ -0,0 +1,160 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+class Data2VecPretrainModel(AbsESPnetModel):
+    """Data2Vec Pretrain model"""
+
+    def __init__(
+            self,
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            preencoder: Optional[AbsPreEncoder],
+            encoder: AbsEncoder,
+    ):
+        assert check_argument_types()
+
+        super().__init__()
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        self.preencoder = preencoder
+        self.encoder = encoder
+        self.num_updates = 0
+
+    def forward(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Calc loss
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        """
+        # Check that batch_size is unified
+        assert (
+                speech.shape[0]
+                == speech_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape)
+
+        self.encoder.set_num_updates(self.num_updates)
+
+        # 1. Encoder
+        encoder_out = self.encode(speech, speech_lengths)
+
+        losses = encoder_out["losses"]
+        loss = sum(losses.values())
+        sample_size = encoder_out["sample_size"]
+        loss = loss.sum() / sample_size
+
+        target_var = float(encoder_out["target_var"])
+        pred_var = float(encoder_out["pred_var"])
+        ema_decay = float(encoder_out["ema_decay"])
+
+        stats = dict(
+            loss=torch.clone(loss.detach()),
+            target_var=target_var,
+            pred_var=pred_var,
+            ema_decay=ema_decay,
+        )
+
+        loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
+        return loss, stats, weight
+
+    def collect_feats(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
+        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+    ):
+        """Frontend + Encoder.
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+        # 4. Forward encoder
+        if min(speech_lengths) == max(speech_lengths):  # for clipping, set speech_lengths as None
+            speech_lengths = None
+        encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
+
+        return encoder_out
+
+    def _extract_feats(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            # Frontend
+            #  e.g. STFT and Feature extract
+            #       data_loader may send time-domain signal in this case
+            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            # No frontend and no feature extract
+            feats, feats_lengths = speech, speech_lengths
+        return feats, feats_lengths
+
+    def set_num_updates(self, num_updates):
+        self.num_updates = num_updates
+
+    def get_num_updates(self):
+        return self.num_updates
diff --git a/funasr/optimizers/fairseq_adam.py b/funasr/optimizers/fairseq_adam.py
new file mode 100644
index 0000000..9bdd0f8
--- /dev/null
+++ b/funasr/optimizers/fairseq_adam.py
@@ -0,0 +1,148 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import math
+
+import torch
+import torch.optim
+
+
+class FairseqAdam(torch.optim.Optimizer):
+    r"""Implements Adam algorithm.
+
+    This implementation is modified from torch.optim.Adam based on:
+    `Fixed Weight Decay Regularization in Adam`
+    (see https://arxiv.org/abs/1711.05101)
+
+    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+
+    .. _Adam\: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+    .. _On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(
+            self,
+            params,
+            lr=1e-3,
+            adam_betas=(0.9, 0.999),
+            adam_eps=1e-8,
+            weight_decay=0,
+            amsgrad=False,
+    ):
+        defaults = dict(
+            lr=lr, betas=adam_betas, eps=adam_eps, weight_decay=weight_decay, amsgrad=amsgrad
+        )
+        super(FairseqAdam, self).__init__(params, defaults)
+        self.optimizer_lr = lr
+
+    @property
+    def supports_memory_efficient_fp16(self):
+        return True
+
+    @property
+    def supports_flat_params(self):
+        return True
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.dtype in {torch.float16, torch.bfloat16}:
+                    grad = grad.float()
+                if grad.is_sparse:
+                    raise RuntimeError(
+                        "Adam does not support sparse gradients, please consider SparseAdam instead"
+                    )
+                amsgrad = group.get("amsgrad", False)
+
+                p_data_fp32 = p.data
+                if p.data.dtype in {torch.float16, torch.bfloat16}:
+                    p_data_fp32 = p_data_fp32.float()
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state["step"] = 0
+                    # Exponential moving average of gradient values
+                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
+                    # Exponential moving average of squared gradient values
+                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
+                else:
+                    state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
+                    state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
+                    if amsgrad:
+                        state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
+                            p_data_fp32
+                        )
+
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                if amsgrad:
+                    max_exp_avg_sq = state["max_exp_avg_sq"]
+                beta1, beta2 = group["betas"]
+
+                state["step"] += 1
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                if amsgrad:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = max_exp_avg_sq.sqrt().add_(group["eps"])
+                else:
+                    denom = exp_avg_sq.sqrt().add_(group["eps"])
+
+                bias_correction1 = 1 - beta1 ** state["step"]
+                bias_correction2 = 1 - beta2 ** state["step"]
+                step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+
+                if group["weight_decay"] != 0:
+                    p_data_fp32.add_(
+                        p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
+                    )
+
+                p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
+
+                if p.data.dtype in {torch.float16, torch.bfloat16}:
+                    p.data.copy_(p_data_fp32)
+
+        return loss
+
+    def set_lr(self, lr):
+        """Set the learning rate."""
+        for param_group in self.param_groups:
+            param_group["lr"] = lr
diff --git a/funasr/schedulers/tri_stage_scheduler.py b/funasr/schedulers/tri_stage_scheduler.py
new file mode 100644
index 0000000..8dc71b4
--- /dev/null
+++ b/funasr/schedulers/tri_stage_scheduler.py
@@ -0,0 +1,108 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Optional, List
+
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+from typeguard import check_argument_types
+
+from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
+
+
+class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
+    def __init__(
+            self,
+            optimizer: torch.optim.Optimizer,
+            last_epoch: int = -1,
+            phase_ratio: Optional[List[float]] = None,
+            init_lr_scale: float = 0.01,
+            final_lr_scale: float = 0.01,
+    ):
+        assert check_argument_types()
+        self.optimizer = optimizer
+        self.last_epoch = last_epoch
+        self.phase_ratio = phase_ratio
+        self.init_lr_scale = init_lr_scale
+        self.final_lr_scale = final_lr_scale
+        self.optimizer_lr = self.optimizer.defaults["lr"]
+
+    def init_tri_stage_scheudler(self, max_update):
+        self.max_update = max_update
+        self.peak_lr = self.optimizer_lr
+        self.init_lr = self.init_lr_scale * self.optimizer_lr
+        self.final_lr = self.final_lr_scale * self.optimizer_lr
+
+        assert self.max_update > 0
+        assert sum(self.phase_ratio) == 1, "phase ratios must add up to 1"
+        assert len(self.phase_ratio) == 3
+        self.warmup_steps = int(self.max_update * self.phase_ratio[0])
+        self.hold_steps = int(self.max_update * self.phase_ratio[1])
+        self.decay_steps = int(self.max_update * self.phase_ratio[2])
+
+        self.warmup_rate = (
+            (self.peak_lr - self.init_lr) / self.warmup_steps
+            if self.warmup_steps != 0
+            else 0
+        )
+        self.decay_factor = -math.log(self.final_lr_scale) / self.decay_steps
+
+        # initial learning rate
+        self.lr = self.init_lr
+
+        # __init__() must be invoked before setting field
+        # because step() is also invoked in __init__()
+        self.set_optimizer_lr(self.lr)
+        super().__init__(self.optimizer, self.last_epoch)
+
+    def _decide_stage(self, update_step):
+        """
+        return stage, and the corresponding steps within the current stage
+        """
+        if update_step < self.warmup_steps:
+            # warmup state
+            return 0, update_step
+
+        offset = self.warmup_steps
+
+        if update_step < offset + self.hold_steps:
+            # hold stage
+            return 1, update_step - offset
+
+        offset += self.hold_steps
+
+        if update_step <= offset + self.decay_steps:
+            # decay stage
+            return 2, update_step - offset
+
+        offset += self.decay_steps
+
+        # still here ? constant lr stage
+        return 3, update_step - offset
+
+    def step_update(self, num_updates):
+        """Update the learning rate after each update."""
+        stage, steps_in_stage = self._decide_stage(num_updates)
+        if stage == 0:
+            self.lr = self.init_lr + self.warmup_rate * steps_in_stage
+        elif stage == 1:
+            self.lr = self.peak_lr
+        elif stage == 2:
+            self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
+        elif stage == 3:
+            self.lr = self.final_lr
+        else:
+            raise ValueError("Undefined stage")
+        self.set_optimizer_lr(self.lr)
+
+    def set_optimizer_lr(self, lr):
+        for param_group in self.optimizer.param_groups:
+            param_group["lr"] = lr
+
+    def get_lr(self):
+        step_num = self.last_epoch + 1
+        self.step_update(step_num)
+        return [self.lr]
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 4e79c63..7899400 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -44,11 +44,13 @@
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
 from funasr.iterators.sequence_iter_factory import SequenceIterFactory
 from funasr.optimizers.sgd import SGD
+from funasr.optimizers.fairseq_adam import FairseqAdam
 from funasr.samplers.build_batch_sampler import BATCH_TYPES
 from funasr.samplers.build_batch_sampler import build_batch_sampler
 from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
 from funasr.schedulers.noam_lr import NoamLR
 from funasr.schedulers.warmup_lr import WarmupLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
 from funasr.torch_utils.load_pretrained_model import load_pretrained_model
 from funasr.torch_utils.model_summary import model_summary
 from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
@@ -83,6 +85,7 @@
 
 optim_classes = dict(
     adam=torch.optim.Adam,
+    fairseq_adam=FairseqAdam,
     adamw=torch.optim.AdamW,
     sgd=SGD,
     adadelta=torch.optim.Adadelta,
@@ -149,6 +152,7 @@
     CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
     noamlr=NoamLR,
     warmuplr=WarmupLR,
+    tri_stage=TriStageLR,
     cycliclr=torch.optim.lr_scheduler.CyclicLR,
     onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
     CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
diff --git a/funasr/tasks/data2vec.py b/funasr/tasks/data2vec.py
new file mode 100644
index 0000000..9a64e1f
--- /dev/null
+++ b/funasr/tasks/data2vec.py
@@ -0,0 +1,376 @@
+import argparse
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.data2vec import Data2VecPretrainModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.tasks.abs_task import AbsTask
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
+    type_check=AbsFrontend,
+    default="default",
+)
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes=dict(specaug=SpecAug),
+    type_check=AbsSpecAug,
+    default=None,
+    optional=True,
+)
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    type_check=AbsNormalize,
+    default=None,
+    optional=True,
+)
+preencoder_choices = ClassChoices(
+    name="preencoder",
+    classes=dict(
+        sinc=LightweightSincConvs,
+    ),
+    type_check=AbsPreEncoder,
+    default=None,
+    optional=True,
+)
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        data2vec_encoder=Data2VecEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="data2vec_encoder",
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        data2vec=Data2VecPretrainModel,
+    ),
+    default="data2vec",
+)
+
+
+class Data2VecTask(AbsTask):
+    # If you need more than one optimizers, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        specaug_choices,
+        # --normalize and --normalize_conf
+        normalize_choices,
+        # --preencoder and --preencoder_conf
+        preencoder_choices,
+        # --encoder and --encoder_conf
+        encoder_choices,
+        # --model and --model_conf
+        model_choices,
+    ]
+
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def add_task_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(description="Task related")
+
+        # NOTE(kamo): add_arguments(..., required=True) can't be used
+        # to provide --print_config mode. Instead of it, do as
+        group.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="A text mapping int-id to token",
+        )
+        group.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+
+        group.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+
+        group = parser.add_argument_group(description="Preprocess related")
+        group.add_argument(
+            "--use_preprocessor",
+            type=str2bool,
+            default=True,
+            help="Apply preprocessing to data or not",
+        )
+        group.add_argument(
+            "--token_type",
+            type=str,
+            default=None,
+            choices=["bpe", "char", "word", "phn"],
+            help="The text will be tokenized " "in the specified level token",
+        )
+
+        group.add_argument(
+            "--feats_type",
+            type=str,
+            default='fbank',
+            help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
+        )
+
+        group.add_argument(
+            "--bpemodel",
+            type=str_or_none,
+            default=None,
+            help="The model file of sentencepiece",
+        )
+        parser.add_argument(
+            "--non_linguistic_symbols",
+            type=str_or_none,
+            help="non_linguistic_symbols file path",
+        )
+        parser.add_argument(
+            "--cleaner",
+            type=str_or_none,
+            choices=[None, "tacotron", "jaconv", "vietnamese"],
+            default=None,
+            help="Apply text cleaning",
+        )
+        parser.add_argument(
+            "--g2p",
+            type=str_or_none,
+            choices=g2p_choices,
+            default=None,
+            help="Specify g2p method if --token_type=phn",
+        )
+        parser.add_argument(
+            "--speech_volume_normalize",
+            type=float_or_none,
+            default=None,
+            help="Scale the maximum amplitude to the given value.",
+        )
+        parser.add_argument(
+            "--rir_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of rir scp file.",
+        )
+        parser.add_argument(
+            "--rir_apply_prob",
+            type=float,
+            default=1.0,
+            help="THe probability for applying RIR convolution.",
+        )
+        parser.add_argument(
+            "--noise_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of noise scp file.",
+        )
+        parser.add_argument(
+            "--noise_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability applying Noise adding.",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of noise decibel level.",
+        )
+        parser.add_argument(
+            "--pred_masked_weight",
+            type=float,
+            default=1.0,
+            help="weight for predictive loss for masked frames",
+        )
+        parser.add_argument(
+            "--pred_nomask_weight",
+            type=float,
+            default=0.0,
+            help="weight for predictive loss for unmasked frames",
+        )
+        parser.add_argument(
+            "--loss_weights",
+            type=float,
+            default=0.0,
+            help="weights for additional loss terms (not first one)",
+        )
+
+        for class_choices in cls.class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(group)
+
+    @classmethod
+    def build_collate_fn(
+            cls, args: argparse.Namespace, train: bool
+    ) -> Callable[
+        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+        Tuple[List[str], Dict[str, torch.Tensor]],
+    ]:
+        assert check_argument_types()
+        return CommonCollateFn(clipping=True)
+
+    @classmethod
+    def build_preprocess_fn(
+            cls, args: argparse.Namespace, train: bool
+    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+        assert check_argument_types()
+        if args.use_preprocessor:
+            retval = CommonPreprocessor(
+                train=train,
+                bpemodel=args.bpemodel,
+                non_linguistic_symbols=args.non_linguistic_symbols,
+                text_cleaner=args.cleaner,
+                g2p_type=args.g2p,
+                # NOTE(kamo): Check attribute existence for backward compatibility
+                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+                rir_apply_prob=args.rir_apply_prob
+                if hasattr(args, "rir_apply_prob")
+                else 1.0,
+                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+                noise_apply_prob=args.noise_apply_prob
+                if hasattr(args, "noise_apply_prob")
+                else 1.0,
+                noise_db_range=args.noise_db_range
+                if hasattr(args, "noise_db_range")
+                else "13_15",
+                speech_volume_normalize=args.speech_volume_normalize
+                if hasattr(args, "rir_scp")
+                else None,
+            )
+        else:
+            retval = None
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def required_data_names(
+            cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        # for pre-training
+        retval = ("speech",)
+        return retval
+
+    @classmethod
+    def optional_data_names(
+            cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        retval = ()
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Pre-encoder input block
+        # NOTE(kan-bayashi): Use getattr to keep the compatibility
+        if getattr(args, "preencoder", None) is not None:
+            preencoder_class = preencoder_choices.get_class(args.preencoder)
+            preencoder = preencoder_class(**args.preencoder_conf)
+            input_size = preencoder.output_size()
+        else:
+            preencoder = None
+
+        # 5. Encoder
+        encoder_class = encoder_choices.get_class(args.encoder)
+        encoder = encoder_class(
+            input_size=input_size,
+            **args.encoder_conf,
+        )
+
+        # 6. Build model
+        try:
+            model_class = model_choices.get_class(args.model)
+        except AttributeError:
+            model_class = model_choices.get_class("data2vec")
+        model = model_class(
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+        )
+
+        # 7. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+        return model

--
Gitblit v1.9.1