From de0ecb446fa429d210397949694d8d9ad6d66112 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 09 二月 2023 10:57:25 +0800
Subject: [PATCH] Merge pull request #79 from alibaba-damo-academy/dev_wjm
---
egs/aishell/data2vec_transformer_finetune/run.sh | 0
funasr/datasets/collate_fn.py | 52 +
funasr/optimizers/fairseq_adam.py | 148 ++++
funasr/datasets/large_datasets/utils/filter.py | 24
egs/aishell/data2vec_paraformer_finetune/utils | 1
egs/aishell/data2vec_transformer_finetune/local/aishell_data_prep.sh | 0
funasr/datasets/large_datasets/utils/clipping.py | 40 +
funasr/datasets/large_datasets/build_dataloader.py | 5
egs/aishell2/data2vec_pretrain/path.sh | 6
egs/aishell/data2vec_paraformer_finetune/conf/decode_asr_transformer_noctc_1best.yaml | 6
egs/aishell/data2vec_transformer_finetune/path.sh | 0
funasr/bin/asr_inference_paraformer.py | 2
egs/aishell/data2vec_transformer_finetune/conf/decode_asr_transformer.yaml | 0
egs/aishell2/data2vec_pretrain/local/prepare_data.sh | 53 +
egs/aishell2/data2vec_pretrain/run.sh | 172 +++++
funasr/models/data2vec.py | 160 +++++
egs/aishell/data2vec_paraformer_finetune/local/aishell_data_prep.sh | 0
egs/aishell/data2vec_transformer_finetune/utils | 0
funasr/datasets/large_datasets/datapipes/batch.py | 219 ++++--
egs/aishell2/data2vec_pretrain/utils | 1
funasr/bin/data2vec_train.py | 45 +
egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh | 0
egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml | 79 ++
funasr/tasks/data2vec.py | 376 +++++++++++
egs/aishell/data2vec_paraformer_finetune/path.sh | 0
funasr/tasks/abs_task.py | 4
egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml | 104 +++
funasr/datasets/large_datasets/dataset.py | 18
egs/aishell/data2vec_paraformer_finetune/run.sh | 252 +++++++
egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml | 0
funasr/schedulers/tri_stage_scheduler.py | 108 +++
egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh | 0
32 files changed, 1,781 insertions(+), 94 deletions(-)
diff --git a/egs/aishell/data2vec_paraformer_finetune/conf/decode_asr_transformer_noctc_1best.yaml b/egs/aishell/data2vec_paraformer_finetune/conf/decode_asr_transformer_noctc_1best.yaml
new file mode 100644
index 0000000..5436b12
--- /dev/null
+++ b/egs/aishell/data2vec_paraformer_finetune/conf/decode_asr_transformer_noctc_1best.yaml
@@ -0,0 +1,6 @@
+beam_size: 1
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.0
+lm_weight: 0.15
diff --git a/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml b/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
new file mode 100644
index 0000000..f9a2cdb
--- /dev/null
+++ b/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
@@ -0,0 +1,104 @@
+# network architecture
+# encoder related
+encoder: data2vec_encoder
+encoder_conf:
+ extractor_mode: layer_norm
+ encoder_layerdrop: 0.1
+ dropout_input: 0.0
+ dropout_features: 0.0
+ feature_grad_mult: 0.0
+ encoder_embed_dim: 768
+
+ mask_prob: 0.65
+ mask_length: 10
+
+ loss_beta: 0
+ loss_scale: null
+
+ instance_norm_target_layer: true
+ average_top_k_layers: 8
+
+ pos_conv_depth: 5
+ conv_pos: 95
+
+ ema_decay: 0.999
+ ema_end_decay: 0.9999
+ ema_anneal_end_step: 30000
+ ema_transformer_only: true
+ ema_layers_only: true
+
+ require_same_masks: true
+ mask_dropout: 0
+
+
+# decoder related
+decoder: paraformer_decoder_san
+decoder_conf:
+ attention_heads: 12
+ linear_units: 3072
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.0
+ src_attention_dropout_rate: 0.0
+
+model: paraformer
+model_conf:
+ ctc_weight: 0.3
+ lsm_weight: 0.1
+ length_normalized_loss: false
+ predictor_weight: 1.0
+ sampling_ratio: 0.4
+
+# minibatch related
+batch_type: length
+batch_bins: 25000
+num_workers: 16
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 50
+val_scheduler_criterion:
+ - valid
+ - acc
+best_model_criterion:
+- - valid
+ - acc
+ - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.00002
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 40
+ num_time_mask: 2
+
+predictor: cif_predictor
+predictor_conf:
+ idim: 768
+ threshold: 1.0
+ l_order: 1
+ r_order: 1
+
+
+log_interval: 50
+unused_parameters: true
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell/data2vec_finetune/local/aishell_data_prep.sh b/egs/aishell/data2vec_paraformer_finetune/local/aishell_data_prep.sh
similarity index 100%
rename from egs/aishell/data2vec_finetune/local/aishell_data_prep.sh
rename to egs/aishell/data2vec_paraformer_finetune/local/aishell_data_prep.sh
diff --git a/egs/aishell/data2vec_finetune/local/prepare_data.sh b/egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh
similarity index 100%
rename from egs/aishell/data2vec_finetune/local/prepare_data.sh
rename to egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh
diff --git a/egs/aishell/data2vec_finetune/path.sh b/egs/aishell/data2vec_paraformer_finetune/path.sh
similarity index 100%
rename from egs/aishell/data2vec_finetune/path.sh
rename to egs/aishell/data2vec_paraformer_finetune/path.sh
diff --git a/egs/aishell/data2vec_paraformer_finetune/run.sh b/egs/aishell/data2vec_paraformer_finetune/run.sh
new file mode 100755
index 0000000..cada164
--- /dev/null
+++ b/egs/aishell/data2vec_paraformer_finetune/run.sh
@@ -0,0 +1,252 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=2
+count=1
+gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="../DATA" #feature output dictionary, for large data
+exp_dir="."
+lang=zh
+dumpdir=dump/fbank
+feats_type=fbank
+token_type=char
+scp=feats.scp
+type=kaldi_ark
+stage=0
+stop_stage=4
+
+# feature configuration
+feats_dim=80
+sample_frequency=16000
+nj=32
+speed_perturb="0.9,1.0,1.1"
+
+# data
+data_aishell=
+
+# exp tag
+tag=""
+
+model_name=damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch
+init_param="$HOME/.cache/modelscope/hub/$model_name/basemodel.pb"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev
+test_sets="dev test"
+
+asr_config=conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_asr_transformer_noctc_1best.yaml
+inference_asr_model=valid.acc.ave_10best.pth
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+ inference_nj=$[${ngpu}*${njob}]
+ _ngpu=1
+else
+ inference_nj=$njob
+ _ngpu=0
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ for x in train dev test; do
+ cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
+ > ${feats_dir}/data/${x}/text
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
+ done
+fi
+
+feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
+feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
+feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Feature Generation"
+ # compute fbank features
+ fbankdir=${feats_dir}/fbank
+ utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
+ ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
+ utils/fix_data_feat.sh ${fbankdir}/train
+ utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
+ ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
+ utils/fix_data_feat.sh ${fbankdir}/dev
+ utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
+ ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
+ utils/fix_data_feat.sh ${fbankdir}/test
+
+ # compute global cmvn
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
+ ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
+
+ # apply cmvn
+ utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
+ utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
+ utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
+
+ cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
+ cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
+ cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
+
+ utils/fix_data_feat.sh ${feat_train_dir}
+ utils/fix_data_feat.sh ${feat_dev_dir}
+ utils/fix_data_feat.sh ${feat_test_dir}
+
+ #generate ark list
+ utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
+ utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+fi
+
+token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: Dictionary Preparation"
+ mkdir -p ${feats_dir}/data/${lang}_token_list/char/
+
+ echo "make a dictionary"
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+ num_token=$(cat ${token_list} | wc -l)
+ echo "<unk>" >> ${token_list}
+ vocab_size=$(cat ${token_list} | wc -l)
+ awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
+ awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
+ mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
+ mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
+ cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
+ cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
+fi
+
+# Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: Training"
+ python utils/download_model.py --model_name ${model_name} # download pretrained model on ModelScope
+ mkdir -p ${exp_dir}/exp/${model_dir}
+ mkdir -p ${exp_dir}/exp/${model_dir}/log
+ INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $gpu_num; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ asr_train_paraformer.py \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --token_type char \
+ --token_list $token_list \
+ --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
+ --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
+ --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
+ --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
+ --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
+ --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
+ --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
+ --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --init_param ${init_param} \
+ --resume true \
+ --output_dir ${exp_dir}/exp/${model_dir} \
+ --config $asr_config \
+ --input_size $feats_dim \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: Inference"
+ for dset in ${test_sets}; do
+ asr_exp=${exp_dir}/exp/${model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/${dumpdir}/${dset}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --gpuid_list ${gpuid_list} \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${asr_exp}"/config.yaml \
+ --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode paraformer \
+ ${_opts}
+
+ for f in token token_int score text; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+ python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+ python utils/proce_text.py ${_data}/text ${_data}/text.proc
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+ done
+fi
diff --git a/egs/aishell/data2vec_paraformer_finetune/utils b/egs/aishell/data2vec_paraformer_finetune/utils
new file mode 120000
index 0000000..fe070dd
--- /dev/null
+++ b/egs/aishell/data2vec_paraformer_finetune/utils
@@ -0,0 +1 @@
+../../aishell/transformer/utils
\ No newline at end of file
diff --git a/egs/aishell/data2vec_finetune/conf/decode_asr_transformer.yaml b/egs/aishell/data2vec_transformer_finetune/conf/decode_asr_transformer.yaml
similarity index 100%
rename from egs/aishell/data2vec_finetune/conf/decode_asr_transformer.yaml
rename to egs/aishell/data2vec_transformer_finetune/conf/decode_asr_transformer.yaml
diff --git a/egs/aishell/data2vec_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml b/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
similarity index 100%
rename from egs/aishell/data2vec_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
rename to egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
diff --git a/egs/aishell/data2vec_finetune/local/aishell_data_prep.sh b/egs/aishell/data2vec_transformer_finetune/local/aishell_data_prep.sh
similarity index 100%
copy from egs/aishell/data2vec_finetune/local/aishell_data_prep.sh
copy to egs/aishell/data2vec_transformer_finetune/local/aishell_data_prep.sh
diff --git a/egs/aishell/data2vec_finetune/local/prepare_data.sh b/egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh
similarity index 100%
copy from egs/aishell/data2vec_finetune/local/prepare_data.sh
copy to egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh
diff --git a/egs/aishell/data2vec_finetune/path.sh b/egs/aishell/data2vec_transformer_finetune/path.sh
similarity index 100%
copy from egs/aishell/data2vec_finetune/path.sh
copy to egs/aishell/data2vec_transformer_finetune/path.sh
diff --git a/egs/aishell/data2vec_finetune/run.sh b/egs/aishell/data2vec_transformer_finetune/run.sh
similarity index 100%
rename from egs/aishell/data2vec_finetune/run.sh
rename to egs/aishell/data2vec_transformer_finetune/run.sh
diff --git a/egs/aishell/data2vec_finetune/utils b/egs/aishell/data2vec_transformer_finetune/utils
similarity index 100%
rename from egs/aishell/data2vec_finetune/utils
rename to egs/aishell/data2vec_transformer_finetune/utils
diff --git a/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
new file mode 100644
index 0000000..4052774
--- /dev/null
+++ b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
@@ -0,0 +1,79 @@
+# network architecture
+# encoder related
+encoder: data2vec_encoder
+encoder_conf:
+ extractor_mode: layer_norm
+ encoder_layerdrop: 0.05
+ dropout_input: 0.0
+ dropout_features: 0.0
+ feature_grad_mult: 1.0
+ encoder_embed_dim: 768
+
+ mask_prob: 0.65
+ mask_length: 10
+
+ loss_beta: 0
+ loss_scale: null
+
+ instance_norm_target_layer: true
+ average_top_k_layers: 8
+
+ pos_conv_depth: 5
+ conv_pos: 95
+
+ ema_decay: 0.999
+ ema_end_decay: 0.9999
+ ema_anneal_end_step: 30000
+ ema_transformer_only: true
+ ema_layers_only: true
+
+ require_same_masks: true
+ mask_dropout: 0
+
+log_interval: 50
+normalize: None
+
+# minibatch related
+batch_type: length
+batch_bins: 64000
+num_workers: 16
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+patience: none
+max_epoch: 600
+val_scheduler_criterion:
+ - valid
+ - acc
+best_model_criterion:
+- - valid
+ - loss
+ - min
+keep_nbest_models: 50
+unused_parameters: true
+
+optim: fairseq_adam
+optim_conf:
+ lr: 0.0005
+ adam_betas: [0.9,0.98]
+ adam_eps: 1.0e-06
+ weight_decay: 0.01
+
+scheduler: tri_stage
+scheduler_conf:
+ phase_ratio: [0.03,0.9,0.07]
+
+# for dataset
+dataset_conf:
+ batch_mode: clipping
+ data_names: speech,none
+ data_types: kaldi_ark,none
+ shuffle: true
+ shuffle_conf:
+ shuffle_size: 12800
+ sort_size: 12800
+ batch_conf:
+ batch_type: token
+ batch_size: 64000
+ num_workers: 8
\ No newline at end of file
diff --git a/egs/aishell2/data2vec_pretrain/local/prepare_data.sh b/egs/aishell2/data2vec_pretrain/local/prepare_data.sh
new file mode 100755
index 0000000..ce6ee19
--- /dev/null
+++ b/egs/aishell2/data2vec_pretrain/local/prepare_data.sh
@@ -0,0 +1,53 @@
+#!/usr/bin/env bash
+# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
+# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
+# Apache 2.0
+
+# transform raw AISHELL-2 data to kaldi format
+
+. ./path.sh || exit 1;
+
+tmp=
+dir=
+
+if [ $# != 3 ]; then
+ echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
+ echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
+ exit 1;
+fi
+
+corpus=$1
+tmp=$2
+dir=$3
+
+echo "prepare_data.sh: Preparing data in $corpus"
+
+mkdir -p $tmp
+mkdir -p $dir
+
+# corpus check
+if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
+ echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
+ exit 1;
+fi
+
+# validate utt-key list, IC0803W0380 is a bad utterance
+awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
+awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
+tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+
+# wav.scp
+awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
+tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+
+# text
+tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+
+# copy prepared resources from tmp_dir to target dir
+mkdir -p $dir
+for f in wav.scp text; do
+ cp $tmp/$f $dir/$f || exit 1;
+done
+
+echo "local/prepare_data.sh succeeded"
+exit 0;
diff --git a/egs/aishell2/data2vec_pretrain/path.sh b/egs/aishell2/data2vec_pretrain/path.sh
new file mode 100755
index 0000000..ea3c0be
--- /dev/null
+++ b/egs/aishell2/data2vec_pretrain/path.sh
@@ -0,0 +1,6 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PYTHONPATH=../../../:$PYTHONPATH
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/egs/aishell2/data2vec_pretrain/run.sh b/egs/aishell2/data2vec_pretrain/run.sh
new file mode 100755
index 0000000..eceb183
--- /dev/null
+++ b/egs/aishell2/data2vec_pretrain/run.sh
@@ -0,0 +1,172 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
+gpu_num=8
+count=1
+
+train_cmd=tools/run.pl
+
+# general configuration
+feats_dir="../DATA" #feature output dictionary
+exp_dir="."
+lang=zh
+dumpdir=dump/fbank
+feats_type=fbank
+token_type=char
+dataset_type=large
+stage=0
+stop_stage=4
+
+# feature configuration
+feats_dim=80
+sample_frequency=16000
+nj=100
+speed_perturb="0.9,1.0,1.1"
+
+# data
+tr_dir=
+dev_tst_dir=
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev_ios
+
+asr_config=conf/train_pretrain_transformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # For training set
+ local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
+ # # For dev and test set
+ for x in Android iOS Mic; do
+ local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
+ local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
+ done
+ # Normalize text to capital letters
+ for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
+ mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+ paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
+ | tr 'A-Z' 'a-z' | tr -d " ") \
+ > ${feats_dir}/data/${x}/text
+ tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
+ done
+fi
+
+feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
+feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Feature Generation"
+ # compute fbank features
+ fbankdir=${feats_dir}/fbank
+ steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
+ ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
+ tools/fix_data_feat.sh ${fbankdir}/train
+ for x in android ios mic; do
+ steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
+ ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
+ tools/fix_data_feat.sh ${fbankdir}/dev_${x}
+ steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
+ ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
+ tools/fix_data_feat.sh ${fbankdir}/test_${x}
+ done
+
+ # compute global cmvn
+ steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
+
+ # apply cmvn
+ steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
+ steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
+ for x in android ios mic; do
+ steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
+ done
+
+ cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
+ tools/fix_data_feat.sh ${feat_train_dir}
+ cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
+ tools/fix_data_feat.sh ${feat_dev_dir}
+ for x in android ios mic; do
+ cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
+ tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
+ done
+fi
+
+token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: Dictionary Preparation"
+ mkdir -p ${feats_dir}/data/${lang}_token_list/char/
+
+ echo "make a dictionary"
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+ num_token=$(cat ${token_list} | wc -l)
+ echo "<unk>" >> ${token_list}
+ vocab_size=$(cat ${token_list} | wc -l)
+ awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
+ awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
+ mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
+ mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
+ cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
+ cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
+fi
+
+# Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: Training"
+ mkdir -p ${exp_dir}/exp/${model_dir}
+ mkdir -p ${exp_dir}/exp/${model_dir}/log
+ INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $gpu_num; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ data2vec_train.py \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --dataset_type $dataset_type \
+ --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
+ --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
+ --resume true \
+ --output_dir ${exp_dir}/exp/${model_dir} \
+ --config $asr_config \
+ --input_size $feats_dim \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+fi
\ No newline at end of file
diff --git a/egs/aishell2/data2vec_pretrain/utils b/egs/aishell2/data2vec_pretrain/utils
new file mode 120000
index 0000000..fe070dd
--- /dev/null
+++ b/egs/aishell2/data2vec_pretrain/utils
@@ -0,0 +1 @@
+../../aishell/transformer/utils
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 5d7d6ea..3769b6c 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -181,7 +181,7 @@
self.nbest = nbest
self.frontend = frontend
self.encoder_downsampling_factor = 1
- if asr_train_args.encoder_conf["input_layer"] == "conv2d":
+ if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
self.encoder_downsampling_factor = 4
@torch.no_grad()
diff --git a/funasr/bin/data2vec_train.py b/funasr/bin/data2vec_train.py
new file mode 100755
index 0000000..b9dbdff
--- /dev/null
+++ b/funasr/bin/data2vec_train.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+
+import os
+
+from funasr.tasks.data2vec import Data2VecTask
+
+
+def parse_args():
+ parser = Data2VecTask.get_parser()
+ parser.add_argument(
+ "--gpu_id",
+ type=int,
+ default=0,
+ help="local gpu id.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main(args=None, cmd=None):
+ # for data2vec Training
+ Data2VecTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ # setup local gpu_id
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+ # DDP settings
+ if args.ngpu > 1:
+ args.distributed = True
+ else:
+ args.distributed = False
+ assert args.num_worker_count == 1
+
+ # re-compute batch size: when dataset type is small
+ if args.dataset_type == "small":
+ if args.batch_size is not None:
+ args.batch_size = args.batch_size * args.ngpu
+ if args.batch_bins is not None:
+ args.batch_bins = args.batch_bins * args.ngpu
+
+ main(args=args)
diff --git a/funasr/datasets/collate_fn.py b/funasr/datasets/collate_fn.py
index d52032f..d34d610 100644
--- a/funasr/datasets/collate_fn.py
+++ b/funasr/datasets/collate_fn.py
@@ -80,4 +80,56 @@
output = (uttids, output)
assert check_return_type(output)
+ return output
+
+def crop_to_max_size(feature, target_size):
+ size = len(feature)
+ diff = size - target_size
+ if diff <= 0:
+ return feature
+
+ start = np.random.randint(0, diff + 1)
+ end = size - diff + start
+ return feature[start:end]
+
+
+def clipping_collate_fn(
+ data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+ max_sample_size=None,
+ not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+ # mainly for pre-training
+ assert check_argument_types()
+ uttids = [u for u, _ in data]
+ data = [d for _, d in data]
+
+ assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+ assert all(
+ not k.endswith("_lengths") for k in data[0]
+ ), f"*_lengths is reserved: {list(data[0])}"
+
+ output = {}
+ for key in data[0]:
+ array_list = [d[key] for d in data]
+ tensor_list = [torch.from_numpy(a) for a in array_list]
+ sizes = [len(s) for s in tensor_list]
+ if max_sample_size is None:
+ target_size = min(sizes)
+ else:
+ target_size = min(min(sizes), max_sample_size)
+ tensor = tensor_list[0].new_zeros(len(tensor_list), target_size, tensor_list[0].shape[1])
+ for i, (source, size) in enumerate(zip(tensor_list, sizes)):
+ diff = size - target_size
+ if diff == 0:
+ tensor[i] = source
+ else:
+ tensor[i] = crop_to_max_size(source, target_size)
+ output[key] = tensor
+
+ if key not in not_sequence:
+ lens = torch.tensor([source.shape[0] for source in tensor], dtype=torch.long)
+ output[key + "_lengths"] = lens
+
+ output = (uttids, output)
+ assert check_return_type(output)
return output
\ No newline at end of file
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 146723d..8f7fd0b 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -35,15 +35,16 @@
class ArkDataLoader(AbsIterFactory):
def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"):
- symbol_table = read_symbol_table(dict_file)
+ symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
if seg_dict_file is not None:
seg_dict = load_seg_dict(seg_dict_file)
else:
seg_dict = None
self.dataset_conf = dataset_conf
logging.info("dataloader config: {}".format(self.dataset_conf))
+ batch_mode = self.dataset_conf.get("batch_mode", "padding")
self.dataset = Dataset(data_list, symbol_table, seg_dict,
- self.dataset_conf, mode=mode)
+ self.dataset_conf, mode=mode, batch_mode=batch_mode)
def build_iter(self, epoch, shuffle=True):
self.dataset.set_epoch(epoch)
diff --git a/funasr/datasets/large_datasets/datapipes/batch.py b/funasr/datasets/large_datasets/datapipes/batch.py
index 9c85d5e..8ec43e9 100644
--- a/funasr/datasets/large_datasets/datapipes/batch.py
+++ b/funasr/datasets/large_datasets/datapipes/batch.py
@@ -24,7 +24,8 @@
batch_size=8000,
len_fn=_default_len_fn,
buffer_size=10240,
- sort_size=500
+ sort_size=500,
+ batch_mode="padding",
):
assert batch_size > 0, "Batch size is required to be larger than 0!"
assert buffer_size >= -1, "Buffer size is required to be larger than -1!"
@@ -35,6 +36,7 @@
self.batch_size = batch_size
self.buffer_size = buffer_size
self.sort_size = sort_size
+ self.batch_mode = batch_mode
def set_epoch(self, epoch):
self.epoch = epoch
@@ -44,55 +46,137 @@
batch = []
bucket = []
max_lengths = 0
+ min_lengths = 999999
batch_lengths = 0
- if self.buffer_size == -1:
- for d in self.datapipe:
- if d[0] > self.batch_size:
- continue
- buffer.append(d)
- buffer.sort()
- for sample in buffer:
- length, _, token = sample
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- bucket.append(batch)
- batch = []
- max_lengths = length
- batch.append(token)
- random.shuffle(bucket)
- if bucket:
- for batch_sample in bucket:
- yield batch_sample
- if batch:
- yield batch
-
- elif self.buffer_size == 0:
- for d in self.datapipe:
- if d[0] > self.batch_size:
- continue
- length, _, token = d
- if length > self.batch_size:
- continue
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- yield batch
- batch = []
- max_lengths = length
- batch.append(token)
- if batch:
- yield batch
-
- else:
+ if self.batch_mode == "clipping":
+ assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1"
for d in self.datapipe:
if d[0] > self.batch_size:
continue
buffer.append(d)
if len(buffer) == self.buffer_size:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if buffer:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if bucket:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length < min_lengths:
+ min_lengths = length
+ batch_lengths = min_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ min_lengths = length
+ batch.append(token)
+ bucket = []
+
+ if batch:
+ yield batch
+
+ else:
+ if self.buffer_size == -1:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ buffer.sort()
+ for sample in buffer:
+ length, _, token = sample
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ bucket.append(batch)
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ random.shuffle(bucket)
+ if bucket:
+ for batch_sample in bucket:
+ yield batch_sample
+ if batch:
+ yield batch
+
+ elif self.buffer_size == 0:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ length, _, token = d
+ if length > self.batch_size:
+ continue
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ if batch:
+ yield batch
+
+ else:
+ for d in self.datapipe:
+ if d[0] > self.batch_size:
+ continue
+ buffer.append(d)
+ if len(buffer) == self.buffer_size:
+ random.shuffle(buffer)
+ for sample in buffer:
+ bucket.append(sample)
+ if len(bucket) == self.sort_size:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
+ buffer = []
+
+ if buffer:
random.shuffle(buffer)
for sample in buffer:
bucket.append(sample)
@@ -111,38 +195,19 @@
bucket = []
buffer = []
- if buffer:
- random.shuffle(buffer)
- for sample in buffer:
- bucket.append(sample)
- if len(bucket) == self.sort_size:
- bucket.sort()
- for x in bucket:
- length, _, token = x
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- yield batch
- batch = []
- max_lengths = length
- batch.append(token)
- bucket = []
- buffer = []
+ if bucket:
+ bucket.sort()
+ for x in bucket:
+ length, _, token = x
+ if length > max_lengths:
+ max_lengths = length
+ batch_lengths = max_lengths * (len(batch) + 1)
+ if batch_lengths > self.batch_size:
+ yield batch
+ batch = []
+ max_lengths = length
+ batch.append(token)
+ bucket = []
- if bucket:
- bucket.sort()
- for x in bucket:
- length, _, token = x
- if length > max_lengths:
- max_lengths = length
- batch_lengths = max_lengths * (len(batch) + 1)
- if batch_lengths > self.batch_size:
- yield batch
- batch = []
- max_lengths = length
- batch.append(token)
- bucket = []
-
- if batch:
- yield batch
+ if batch:
+ yield batch
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 41d34ab..55b0678 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -13,6 +13,7 @@
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
from funasr.datasets.large_datasets.utils.filter import filter
from funasr.datasets.large_datasets.utils.padding import padding
+from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.tokenize import tokenize
@@ -101,6 +102,8 @@
elif data_type == "text" or data_type == "sound":
text_reader = open(data_file, "r")
reader_list.append(text_reader)
+ elif data_type == "none":
+ continue
else:
raise TypeError("Data type {} is not supported".format(data_type))
@@ -143,7 +146,8 @@
dict,
seg_dict,
conf,
- mode="train"):
+ mode="train",
+ batch_mode="padding"):
scp_lists = read_lists(data_list_file)
shuffle = conf.get('shuffle', True)
data_names = conf.get("data_names", "speech,text")
@@ -154,9 +158,10 @@
filter_fn = partial(filter, **filter_conf)
dataset = FilterIterDataPipe(dataset, fn=filter_fn)
- vocab = {'vocab': dict, 'seg_dict': seg_dict}
- tokenize_fn = partial(tokenize, **vocab)
- dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
+ if "text" in data_names:
+ vocab = {'vocab': dict, 'seg_dict': seg_dict}
+ tokenize_fn = partial(tokenize, **vocab)
+ dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
if shuffle:
buffer_conf = conf.get('shuffle_conf', {})
@@ -180,8 +185,9 @@
batch_size=batch_size,
len_fn=len_fn,
buffer_size=buffer_size,
- sort_size=sort_size)
+ sort_size=sort_size,
+ batch_mode=batch_mode)
- dataset = MapperIterDataPipe(dataset, fn=padding)
+ dataset = MapperIterDataPipe(dataset, fn=padding if batch_mode == "padding" else clipping)
return dataset
diff --git a/funasr/datasets/large_datasets/utils/clipping.py b/funasr/datasets/large_datasets/utils/clipping.py
new file mode 100644
index 0000000..f5c2940
--- /dev/null
+++ b/funasr/datasets/large_datasets/utils/clipping.py
@@ -0,0 +1,40 @@
+import numpy as np
+import torch
+
+from funasr.datasets.collate_fn import crop_to_max_size
+
+
+def clipping(data):
+ assert isinstance(data, list)
+ assert "key" in data[0]
+
+ keys = [x["key"] for x in data]
+
+ batch = {}
+ data_names = data[0].keys()
+ for data_name in data_names:
+ if data_name == "key":
+ continue
+ else:
+ if data[0][data_name].dtype.kind == "i":
+ tensor_type = torch.int64
+ else:
+ tensor_type = torch.float32
+
+ tensor_list = [torch.tensor(np.copy(d[data_name]), dtype=tensor_type) for d in data]
+ tensor_lengths = torch.tensor([len(d[data_name]) for d in data], dtype=torch.int32)
+
+ length_clip = min(tensor_lengths)
+ tensor_clip = tensor_list[0].new_zeros(len(tensor_list), length_clip, tensor_list[0].shape[1])
+ for i, (tensor, length) in enumerate(zip(tensor_list, tensor_lengths)):
+ diff = length - length_clip
+ assert diff >= 0
+ if diff == 0:
+ tensor_clip[i] = tensor
+ else:
+ tensor_clip[i] = crop_to_max_size(tensor, length_clip)
+
+ batch[data_name] = tensor_clip
+ batch[data_name + "_lengths"] = torch.tensor([tensor.shape[0] for tensor in tensor_clip], dtype=torch.long)
+
+ return keys, batch
diff --git a/funasr/datasets/large_datasets/utils/filter.py b/funasr/datasets/large_datasets/utils/filter.py
index 91ba4be..1260a47 100644
--- a/funasr/datasets/large_datasets/utils/filter.py
+++ b/funasr/datasets/large_datasets/utils/filter.py
@@ -6,13 +6,21 @@
speech_length_max=15000,
token_length_min=0,
token_length_max=200):
- assert "speech" in data
- assert "text" in data
+ assert "speech" in data or "text" in data
- if "sampling_rate" in data:
- speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+ if "speech" in data and "text" in data:
+ if "sampling_rate" in data:
+ speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+ else:
+ speech_length = data["speech"].shape[0]
+ num_tokens = len(data['text'])
+ return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+ elif "speech" in data:
+ if "sampling_rate" in data:
+ speech_length = (data["speech"].shape[0] / data["sampling_rate"]) * 1000.
+ else:
+ speech_length = data["speech"].shape[0]
+ return speech_length_min < speech_length < speech_length_max
else:
- speech_length = data["speech"].shape[0]
- num_tokens = len(data['text'])
-
- return speech_length_min < speech_length < speech_length_max and token_length_min < num_tokens < token_length_max
+ num_tokens = len(data['text'])
+ return token_length_min < num_tokens < token_length_max
diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
new file mode 100644
index 0000000..fcd6bd2
--- /dev/null
+++ b/funasr/models/data2vec.py
@@ -0,0 +1,160 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+class Data2VecPretrainModel(AbsESPnetModel):
+ """Data2Vec Pretrain model"""
+
+ def __init__(
+ self,
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
+ encoder: AbsEncoder,
+ ):
+ assert check_argument_types()
+
+ super().__init__()
+
+ self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ self.preencoder = preencoder
+ self.encoder = encoder
+ self.num_updates = 0
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape)
+
+ self.encoder.set_num_updates(self.num_updates)
+
+ # 1. Encoder
+ encoder_out = self.encode(speech, speech_lengths)
+
+ losses = encoder_out["losses"]
+ loss = sum(losses.values())
+ sample_size = encoder_out["sample_size"]
+ loss = loss.sum() / sample_size
+
+ target_var = float(encoder_out["target_var"])
+ pred_var = float(encoder_out["pred_var"])
+ ema_decay = float(encoder_out["ema_decay"])
+
+ stats = dict(
+ loss=torch.clone(loss.detach()),
+ target_var=target_var,
+ pred_var=pred_var,
+ ema_decay=ema_decay,
+ )
+
+ loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
+ return loss, stats, weight
+
+ def collect_feats(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor
+ ) -> Dict[str, torch.Tensor]:
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ return {"feats": feats, "feats_lengths": feats_lengths}
+
+ def encode(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ ):
+ """Frontend + Encoder.
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
+ # 1. Extract feats
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+ # 2. Data augmentation
+ if self.specaug is not None and self.training:
+ feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+ # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+ # Pre-encoder, e.g. used for raw input data
+ if self.preencoder is not None:
+ feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+ # 4. Forward encoder
+ if min(speech_lengths) == max(speech_lengths): # for clipping, set speech_lengths as None
+ speech_lengths = None
+ encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
+
+ return encoder_out
+
+ def _extract_feats(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert speech_lengths.dim() == 1, speech_lengths.shape
+
+ # for data-parallel
+ speech = speech[:, : speech_lengths.max()]
+
+ if self.frontend is not None:
+ # Frontend
+ # e.g. STFT and Feature extract
+ # data_loader may send time-domain signal in this case
+ # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ # No frontend and no feature extract
+ feats, feats_lengths = speech, speech_lengths
+ return feats, feats_lengths
+
+ def set_num_updates(self, num_updates):
+ self.num_updates = num_updates
+
+ def get_num_updates(self):
+ return self.num_updates
diff --git a/funasr/optimizers/fairseq_adam.py b/funasr/optimizers/fairseq_adam.py
new file mode 100644
index 0000000..9bdd0f8
--- /dev/null
+++ b/funasr/optimizers/fairseq_adam.py
@@ -0,0 +1,148 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import math
+
+import torch
+import torch.optim
+
+
+class FairseqAdam(torch.optim.Optimizer):
+ r"""Implements Adam algorithm.
+
+ This implementation is modified from torch.optim.Adam based on:
+ `Fixed Weight Decay Regularization in Adam`
+ (see https://arxiv.org/abs/1711.05101)
+
+ It has been proposed in `Adam: A Method for Stochastic Optimization`_.
+
+ Args:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
+
+ .. _Adam\: A Method for Stochastic Optimization:
+ https://arxiv.org/abs/1412.6980
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ adam_betas=(0.9, 0.999),
+ adam_eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ ):
+ defaults = dict(
+ lr=lr, betas=adam_betas, eps=adam_eps, weight_decay=weight_decay, amsgrad=amsgrad
+ )
+ super(FairseqAdam, self).__init__(params, defaults)
+ self.optimizer_lr = lr
+
+ @property
+ def supports_memory_efficient_fp16(self):
+ return True
+
+ @property
+ def supports_flat_params(self):
+ return True
+
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Args:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+ grad = p.grad.data
+ if grad.dtype in {torch.float16, torch.bfloat16}:
+ grad = grad.float()
+ if grad.is_sparse:
+ raise RuntimeError(
+ "Adam does not support sparse gradients, please consider SparseAdam instead"
+ )
+ amsgrad = group.get("amsgrad", False)
+
+ p_data_fp32 = p.data
+ if p.data.dtype in {torch.float16, torch.bfloat16}:
+ p_data_fp32 = p_data_fp32.float()
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(p_data_fp32)
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32)
+ else:
+ state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
+ state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32)
+ if amsgrad:
+ state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to(
+ p_data_fp32
+ )
+
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+ if amsgrad:
+ max_exp_avg_sq = state["max_exp_avg_sq"]
+ beta1, beta2 = group["betas"]
+
+ state["step"] += 1
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ if amsgrad:
+ # Maintains the maximum of all 2nd moment running avg. till now
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+ # Use the max. for normalizing running avg. of gradient
+ denom = max_exp_avg_sq.sqrt().add_(group["eps"])
+ else:
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
+
+ bias_correction1 = 1 - beta1 ** state["step"]
+ bias_correction2 = 1 - beta2 ** state["step"]
+ step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+
+ if group["weight_decay"] != 0:
+ p_data_fp32.add_(
+ p_data_fp32, alpha=-group["weight_decay"] * group["lr"]
+ )
+
+ p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
+
+ if p.data.dtype in {torch.float16, torch.bfloat16}:
+ p.data.copy_(p_data_fp32)
+
+ return loss
+
+ def set_lr(self, lr):
+ """Set the learning rate."""
+ for param_group in self.param_groups:
+ param_group["lr"] = lr
diff --git a/funasr/schedulers/tri_stage_scheduler.py b/funasr/schedulers/tri_stage_scheduler.py
new file mode 100644
index 0000000..8dc71b4
--- /dev/null
+++ b/funasr/schedulers/tri_stage_scheduler.py
@@ -0,0 +1,108 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+from typing import Optional, List
+
+import torch
+from torch.optim.lr_scheduler import _LRScheduler
+from typeguard import check_argument_types
+
+from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
+
+
+class TriStageLR(_LRScheduler, AbsBatchStepScheduler):
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ last_epoch: int = -1,
+ phase_ratio: Optional[List[float]] = None,
+ init_lr_scale: float = 0.01,
+ final_lr_scale: float = 0.01,
+ ):
+ assert check_argument_types()
+ self.optimizer = optimizer
+ self.last_epoch = last_epoch
+ self.phase_ratio = phase_ratio
+ self.init_lr_scale = init_lr_scale
+ self.final_lr_scale = final_lr_scale
+ self.optimizer_lr = self.optimizer.defaults["lr"]
+
+ def init_tri_stage_scheudler(self, max_update):
+ self.max_update = max_update
+ self.peak_lr = self.optimizer_lr
+ self.init_lr = self.init_lr_scale * self.optimizer_lr
+ self.final_lr = self.final_lr_scale * self.optimizer_lr
+
+ assert self.max_update > 0
+ assert sum(self.phase_ratio) == 1, "phase ratios must add up to 1"
+ assert len(self.phase_ratio) == 3
+ self.warmup_steps = int(self.max_update * self.phase_ratio[0])
+ self.hold_steps = int(self.max_update * self.phase_ratio[1])
+ self.decay_steps = int(self.max_update * self.phase_ratio[2])
+
+ self.warmup_rate = (
+ (self.peak_lr - self.init_lr) / self.warmup_steps
+ if self.warmup_steps != 0
+ else 0
+ )
+ self.decay_factor = -math.log(self.final_lr_scale) / self.decay_steps
+
+ # initial learning rate
+ self.lr = self.init_lr
+
+ # __init__() must be invoked before setting field
+ # because step() is also invoked in __init__()
+ self.set_optimizer_lr(self.lr)
+ super().__init__(self.optimizer, self.last_epoch)
+
+ def _decide_stage(self, update_step):
+ """
+ return stage, and the corresponding steps within the current stage
+ """
+ if update_step < self.warmup_steps:
+ # warmup state
+ return 0, update_step
+
+ offset = self.warmup_steps
+
+ if update_step < offset + self.hold_steps:
+ # hold stage
+ return 1, update_step - offset
+
+ offset += self.hold_steps
+
+ if update_step <= offset + self.decay_steps:
+ # decay stage
+ return 2, update_step - offset
+
+ offset += self.decay_steps
+
+ # still here ? constant lr stage
+ return 3, update_step - offset
+
+ def step_update(self, num_updates):
+ """Update the learning rate after each update."""
+ stage, steps_in_stage = self._decide_stage(num_updates)
+ if stage == 0:
+ self.lr = self.init_lr + self.warmup_rate * steps_in_stage
+ elif stage == 1:
+ self.lr = self.peak_lr
+ elif stage == 2:
+ self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage)
+ elif stage == 3:
+ self.lr = self.final_lr
+ else:
+ raise ValueError("Undefined stage")
+ self.set_optimizer_lr(self.lr)
+
+ def set_optimizer_lr(self, lr):
+ for param_group in self.optimizer.param_groups:
+ param_group["lr"] = lr
+
+ def get_lr(self):
+ step_num = self.last_epoch + 1
+ self.step_update(step_num)
+ return [self.lr]
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 4e79c63..7899400 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -44,11 +44,13 @@
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
from funasr.optimizers.sgd import SGD
+from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.samplers.build_batch_sampler import BATCH_TYPES
from funasr.samplers.build_batch_sampler import build_batch_sampler
from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.warmup_lr import WarmupLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
@@ -83,6 +85,7 @@
optim_classes = dict(
adam=torch.optim.Adam,
+ fairseq_adam=FairseqAdam,
adamw=torch.optim.AdamW,
sgd=SGD,
adadelta=torch.optim.Adadelta,
@@ -149,6 +152,7 @@
CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
noamlr=NoamLR,
warmuplr=WarmupLR,
+ tri_stage=TriStageLR,
cycliclr=torch.optim.lr_scheduler.CyclicLR,
onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
diff --git a/funasr/tasks/data2vec.py b/funasr/tasks/data2vec.py
new file mode 100644
index 0000000..9a64e1f
--- /dev/null
+++ b/funasr/tasks/data2vec.py
@@ -0,0 +1,376 @@
+import argparse
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.data2vec import Data2VecPretrainModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.tasks.abs_task import AbsTask
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
+ type_check=AbsFrontend,
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(specaug=SpecAug),
+ type_check=AbsSpecAug,
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ type_check=AbsNormalize,
+ default=None,
+ optional=True,
+)
+preencoder_choices = ClassChoices(
+ name="preencoder",
+ classes=dict(
+ sinc=LightweightSincConvs,
+ ),
+ type_check=AbsPreEncoder,
+ default=None,
+ optional=True,
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ data2vec_encoder=Data2VecEncoder,
+ ),
+ type_check=AbsEncoder,
+ default="data2vec_encoder",
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ data2vec=Data2VecPretrainModel,
+ ),
+ default="data2vec",
+)
+
+
+class Data2VecTask(AbsTask):
+ # If you need more than one optimizers, change this value
+ num_optimizers: int = 1
+
+ # Add variable objects configurations
+ class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --preencoder and --preencoder_conf
+ preencoder_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --model and --model_conf
+ model_choices,
+ ]
+
+ # If you need to modify train() or eval() procedures, change Trainer class here
+ trainer = Trainer
+
+ @classmethod
+ def add_task_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(description="Task related")
+
+ # NOTE(kamo): add_arguments(..., required=True) can't be used
+ # to provide --print_config mode. Instead of it, do as
+ group.add_argument(
+ "--token_list",
+ type=str_or_none,
+ default=None,
+ help="A text mapping int-id to token",
+ )
+ group.add_argument(
+ "--init",
+ type=lambda x: str_or_none(x.lower()),
+ default=None,
+ help="The initialization method",
+ choices=[
+ "chainer",
+ "xavier_uniform",
+ "xavier_normal",
+ "kaiming_uniform",
+ "kaiming_normal",
+ None,
+ ],
+ )
+
+ group.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+
+ group = parser.add_argument_group(description="Preprocess related")
+ group.add_argument(
+ "--use_preprocessor",
+ type=str2bool,
+ default=True,
+ help="Apply preprocessing to data or not",
+ )
+ group.add_argument(
+ "--token_type",
+ type=str,
+ default=None,
+ choices=["bpe", "char", "word", "phn"],
+ help="The text will be tokenized " "in the specified level token",
+ )
+
+ group.add_argument(
+ "--feats_type",
+ type=str,
+ default='fbank',
+ help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
+ )
+
+ group.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The model file of sentencepiece",
+ )
+ parser.add_argument(
+ "--non_linguistic_symbols",
+ type=str_or_none,
+ help="non_linguistic_symbols file path",
+ )
+ parser.add_argument(
+ "--cleaner",
+ type=str_or_none,
+ choices=[None, "tacotron", "jaconv", "vietnamese"],
+ default=None,
+ help="Apply text cleaning",
+ )
+ parser.add_argument(
+ "--g2p",
+ type=str_or_none,
+ choices=g2p_choices,
+ default=None,
+ help="Specify g2p method if --token_type=phn",
+ )
+ parser.add_argument(
+ "--speech_volume_normalize",
+ type=float_or_none,
+ default=None,
+ help="Scale the maximum amplitude to the given value.",
+ )
+ parser.add_argument(
+ "--rir_scp",
+ type=str_or_none,
+ default=None,
+ help="The file path of rir scp file.",
+ )
+ parser.add_argument(
+ "--rir_apply_prob",
+ type=float,
+ default=1.0,
+ help="THe probability for applying RIR convolution.",
+ )
+ parser.add_argument(
+ "--noise_scp",
+ type=str_or_none,
+ default=None,
+ help="The file path of noise scp file.",
+ )
+ parser.add_argument(
+ "--noise_apply_prob",
+ type=float,
+ default=1.0,
+ help="The probability applying Noise adding.",
+ )
+ parser.add_argument(
+ "--noise_db_range",
+ type=str,
+ default="13_15",
+ help="The range of noise decibel level.",
+ )
+ parser.add_argument(
+ "--pred_masked_weight",
+ type=float,
+ default=1.0,
+ help="weight for predictive loss for masked frames",
+ )
+ parser.add_argument(
+ "--pred_nomask_weight",
+ type=float,
+ default=0.0,
+ help="weight for predictive loss for unmasked frames",
+ )
+ parser.add_argument(
+ "--loss_weights",
+ type=float,
+ default=0.0,
+ help="weights for additional loss terms (not first one)",
+ )
+
+ for class_choices in cls.class_choices_list:
+ # Append --<name> and --<name>_conf.
+ # e.g. --encoder and --encoder_conf
+ class_choices.add_arguments(group)
+
+ @classmethod
+ def build_collate_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Callable[
+ [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+ Tuple[List[str], Dict[str, torch.Tensor]],
+ ]:
+ assert check_argument_types()
+ return CommonCollateFn(clipping=True)
+
+ @classmethod
+ def build_preprocess_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+ assert check_argument_types()
+ if args.use_preprocessor:
+ retval = CommonPreprocessor(
+ train=train,
+ bpemodel=args.bpemodel,
+ non_linguistic_symbols=args.non_linguistic_symbols,
+ text_cleaner=args.cleaner,
+ g2p_type=args.g2p,
+ # NOTE(kamo): Check attribute existence for backward compatibility
+ rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+ rir_apply_prob=args.rir_apply_prob
+ if hasattr(args, "rir_apply_prob")
+ else 1.0,
+ noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+ noise_apply_prob=args.noise_apply_prob
+ if hasattr(args, "noise_apply_prob")
+ else 1.0,
+ noise_db_range=args.noise_db_range
+ if hasattr(args, "noise_db_range")
+ else "13_15",
+ speech_volume_normalize=args.speech_volume_normalize
+ if hasattr(args, "rir_scp")
+ else None,
+ )
+ else:
+ retval = None
+ assert check_return_type(retval)
+ return retval
+
+ @classmethod
+ def required_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ # for pre-training
+ retval = ("speech",)
+ return retval
+
+ @classmethod
+ def optional_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ retval = ()
+ assert check_return_type(retval)
+ return retval
+
+ @classmethod
+ def build_model(cls, args: argparse.Namespace):
+ assert check_argument_types()
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # 3. Normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # 4. Pre-encoder input block
+ # NOTE(kan-bayashi): Use getattr to keep the compatibility
+ if getattr(args, "preencoder", None) is not None:
+ preencoder_class = preencoder_choices.get_class(args.preencoder)
+ preencoder = preencoder_class(**args.preencoder_conf)
+ input_size = preencoder.output_size()
+ else:
+ preencoder = None
+
+ # 5. Encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(
+ input_size=input_size,
+ **args.encoder_conf,
+ )
+
+ # 6. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("data2vec")
+ model = model_class(
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ )
+
+ # 7. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ assert check_return_type(model)
+ return model
--
Gitblit v1.9.1