From 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 16 May 2023 15:07:20 +0800
Subject: [PATCH] Merge pull request #518 from alibaba-damo-academy/dev_wjm2
---
egs/aishell/data2vec_transformer_finetune/run.sh | 110
egs/librispeech_100h/conformer/local/spm_train.py | 12
funasr/datasets/small_datasets/length_batch_sampler.py | 147
funasr/models/e2e_asr.py | 13
egs/aishell2/transformer/utils/fix_data.sh | 4
egs/librispeech_100h/conformer/conf/decode_asr_transformer.yaml | 6
funasr/datasets/small_datasets/preprocessor.py | 875 +++++
egs/librispeech_100h/conformer/utils | 1
funasr/build_utils/build_trainer.py | 820 +++++
egs/aishell2/conformer/local/prepare_data.sh | 6
egs/aishell2/transformer/utils/compute_cmvn.py | 88
funasr/layers/abs_normalize.py | 2
funasr/tasks/vad.py | 79
egs/aishell2/paraformerbert/local/extract_embeds.sh | 31
egs/aishell/conformer/run.sh | 103
funasr/tasks/punctuation.py | 2
egs/aishell/transformer/utils/compute_cmvn.sh | 21
egs/librispeech/conformer/local/spm_train.py | 12
egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh | 105
egs/aishell2/transformer/utils/fix_data_feat.sh | 8
funasr/models/encoder/abs_encoder.py | 2
funasr/build_utils/__init__.py | 0
funasr/main_funcs/collect_stats.py | 4
funasr/models/encoder/resnet34_encoder.py | 2
funasr/models/encoder/conformer_encoder.py | 2
egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml | 68
funasr/build_utils/build_scheduler.py | 44
egs/aishell/paraformerbert/local/extract_embeds.sh | 12
funasr/models/encoder/sanm_encoder.py | 3
funasr/tasks/diar.py | 12
egs/aishell/paraformer/local/download_and_untar.sh | 105
funasr/tasks/lm.py | 3
egs/aishell/conformer/conf/train_asr_conformer.yaml | 28
funasr/build_utils/build_distributed.py | 38
funasr/tasks/abs_task.py | 55
egs/aishell2/paraformerbert/local/prepare_data.sh | 7
egs/aishell/data2vec_paraformer_finetune/run.sh | 112
funasr/models/encoder/data2vec_encoder.py | 2
funasr/models/e2e_vad.py | 8
egs/librispeech/conformer/conf/train_asr_conformer.yaml | 14
egs/aishell2/paraformer/run.sh | 105
funasr/main_funcs/calculate_all_attentions.py | 4
funasr/models/e2e_diar_eend_ola.py | 5
funasr/models/e2e_diar_sond.py | 8
egs/aishell2/transformer/conf/train_asr_transformer.yaml | 22
funasr/models/e2e_tp.py | 15
egs/aishell2/transformer/utils/compute_cmvn.sh | 21
funasr/train/abs_model.py | 28
egs/aishell/transformer/path.sh | 2
funasr/models/frontend/windowing.py | 7
egs/aishell/transformer/utils/cmvn_converter.py | 6
funasr/build_utils/build_diar_model.py | 296 +
egs/aishell2/paraformer/local/prepare_data.sh | 6
egs/aishell2/paraformerbert/run.sh | 136
funasr/models/base_model.py | 17
egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml | 29
funasr/datasets/small_datasets/sequence_iter_factory.py | 189 +
funasr/bin/train.py | 572 +++
docs/reference/build_task.md | 2
egs/aishell/conformer/local/download_and_untar.sh | 105
funasr/build_utils/build_vad_model.py | 77
egs/aishell2/data2vec_pretrain/run.sh | 131
funasr/build_utils/build_optimizer.py | 28
funasr/models/encoder/rnn_encoder.py | 3
funasr/models/frontend/wav_frontend_kaldifeat.py | 119
funasr/build_utils/build_pretrain_model.py | 107
funasr/models/e2e_asr_paraformer.py | 162
egs/aishell2/transformer/utils/combine_cmvn_file.py | 27
egs/librispeech_100h/conformer/path.sh | 0
egs/aishell2/transformer/utils/compute_fbank.py | 24
egs/librispeech/conformer/run.sh | 143
egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh | 105
egs/librispeech_100h/conformer/run.sh | 212 +
funasr/utils/prepare_data.py | 209 +
egs/aishell2/transformer/run.sh | 101
egs/librispeech_100h/conformer/local/spm_encode.py | 98
egs/aishell/paraformer/run.sh | 100
funasr/build_utils/build_dataloader.py | 15
egs/aishell/paraformerbert/run.sh | 114
egs/librispeech_100h/conformer/local/data_prep.sh | 0
egs/aishell/paraformerbert/local/download_and_untar.sh | 105
funasr/layers/global_mvn.py | 5
funasr/build_utils/build_args.py | 93
funasr/models/e2e_sv.py | 8
funasr/datasets/large_datasets/build_dataloader.py | 36
funasr/train/trainer.py | 4
funasr/datasets/small_datasets/dataset.py | 258 +
egs/librispeech/conformer/local/spm_encode.py | 98
funasr/models/frontend/abs_frontend.py | 2
funasr/build_utils/build_punc_model.py | 68
egs/aishell/transformer/run.sh | 112
egs/aishell/paraformerbert/local/aishell_data_prep.sh | 23
egs/librispeech/conformer/local/download_and_untar.sh | 97
funasr/models/data2vec.py | 6
egs/aishell2/transformer/utils/download_model.py | 20
funasr/build_utils/build_lm_model.py | 57
funasr/build_utils/build_asr_model.py | 352 ++
egs/aishell/transformer/utils/compute_cmvn.py | 88
egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml | 32
egs/aishell2/transformer/utils/cmvn_converter.py | 51
egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml | 29
egs/aishell/transformer/conf/train_asr_transformer.yaml | 30
funasr/models/vad_realtime_transformer.py | 3
funasr/datasets/large_datasets/dataset.py | 24
egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml | 39
funasr/bin/asr_test.py | 226 +
funasr/datasets/small_datasets/collate_fn.py | 93
funasr/datasets/large_datasets/utils/tokenize.py | 2
egs/aishell/transformer/utils/combine_cmvn_file.py | 27
funasr/fileio/sound_scp.py | 15
funasr/models/e2e_asr_mfcca.py | 138
egs/librispeech_100h/conformer/local/download_and_untar.sh | 97
funasr/layers/inversible_interface.py | 2
funasr/models/encoder/mfcca_encoder.py | 120
egs/aishell2/transformer/local/prepare_data.sh | 6
funasr/models/frontend/fused.py | 2
funasr/models/target_delay_transformer.py | 5
egs/aishell2/transformer/utils/compute_fbank.sh | 5
funasr/models/frontend/s3prl.py | 1
egs/aishell2/transformer/utils/compute_wer.py | 4
funasr/models/e2e_uni_asr.py | 11
funasr/lm/abs_model.py | 4
egs/aishell2/conformer/conf/train_asr_conformer.yaml | 22
egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml | 22
funasr/models/specaug/abs_specaug.py | 2
funasr/tasks/sv.py | 10
egs/aishell/transformer/local/download_and_untar.sh | 105
egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml | 29
funasr/tasks/asr.py | 18
egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml | 39
egs/aishell2/conformer/run.sh | 95
funasr/models/frontend/wav_frontend.py | 2
egs/librispeech/conformer/local/data_prep.sh | 0
/dev/null | 55
funasr/build_utils/build_model.py | 25
funasr/models/frontend/default.py | 4
136 files changed, 7292 insertions(+), 1850 deletions(-)
diff --git a/docs/reference/build_task.md b/docs/reference/build_task.md
index be2d1af..2020860 100644
--- a/docs/reference/build_task.md
+++ b/docs/reference/build_task.md
@@ -103,7 +103,7 @@
)
return model
```
-This function defines the detail of the model. For different speech recognition models, the same speech recognition `Task` can usually be shared and the remaining thing needed to be done is to define a specific model in this function. For example, a speech recognition model with a standard encoder-decoder structure has been shown above. Specifically, it first defines each module of the model, including encoder, decoder, etc. and then combine these modules together to generate a complete model. In FunASR, the model needs to inherit `AbsESPnetModel` and the corresponding code can be seen in `funasr/train/abs_espnet_model.py`. The main function needed to be implemented is the `forward` function.
+This function defines the details of the model. Different speech recognition models can usually share the same speech recognition `Task`; all that remains is to define the specific model in this function. For example, a speech recognition model with a standard encoder-decoder structure is shown above: it first defines each module of the model (encoder, decoder, etc.) and then combines these modules into a complete model. In FunASR, the model needs to inherit `FunASRModel`, whose code can be seen in `funasr/train/abs_model.py`, and the main function that needs to be implemented is `forward`.
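A minimal sketch may help make the pattern concrete. This is an illustration rather than code from FunASR: the import path follows the file mentioned above, the `(loss, stats, weight)` return convention is assumed from ESPnet, and every other name is invented for the example.
```python
import torch

from funasr.train.abs_model import FunASRModel  # assumed location of FunASRModel


class ToyASRModel(FunASRModel):
    """Illustrative encoder-decoder model; only `forward` (and, assuming the
    ESPnet-style interface, `collect_feats`) needs to be implemented."""

    def __init__(self, encoder: torch.nn.Module, decoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, speech, speech_lengths, text, text_lengths):
        # Encode the speech, decode it into token logits, and compute a loss.
        enc_out, enc_lens, _ = self.encoder(speech, speech_lengths)
        logits, _ = self.decoder(enc_out, enc_lens, text, text_lengths)
        loss = torch.nn.functional.cross_entropy(
            logits.transpose(1, 2), text, ignore_index=-1
        )
        stats = {"loss": loss.detach()}
        # Assumed ESPnet-style convention: return (loss, stats, batch size).
        return loss, stats, speech.shape[0]

    def collect_feats(self, speech, speech_lengths, **kwargs):
        return {"feats": speech, "feats_lengths": speech_lengths}
```
With such a model, the `Task`'s model-building function only has to construct the encoder and decoder and return the assembled model; data loading, optimization, and checkpointing stay shared across models.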
Next, we take `SANMEncoder` as an example to show how a custom encoder can be used as part of the model; the corresponding code can be seen in `funasr/models/encoder/sanm_encoder.py`. In addition to inheriting the common encoder class `AbsEncoder`, a custom encoder must define a `forward` function that implements the encoder's forward computation. After the encoder is defined, it must also be registered in the `Task`, as the following code example shows:
```python
diff --git a/egs/aishell/conformer/conf/train_asr_conformer.yaml b/egs/aishell/conformer/conf/train_asr_conformer.yaml
index ddf217e..ea2353e 100644
--- a/egs/aishell/conformer/conf/train_asr_conformer.yaml
+++ b/egs/aishell/conformer/conf/train_asr_conformer.yaml
@@ -29,21 +29,27 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
-# minibatch related
-batch_type: length
-batch_bins: 25000
-num_workers: 16
-
# optimization related
accum_grad: 1
grad_clip: 5
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -76,5 +82,15 @@
- 40
num_time_mask: 2
+dataset_conf:
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
+
log_interval: 50
normalize: None
diff --git a/egs/aishell/conformer/local/download_and_untar.sh b/egs/aishell/conformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/conformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/conformer/local/prepare_data.sh b/egs/aishell/conformer/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/conformer/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
- echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
- echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
- exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
- echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
- exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
- cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/conformer/run.sh b/egs/aishell/conformer/run.sh
index 227b3f2..8db0d97 100755
--- a/egs/aishell/conformer/run.sh
+++ b/egs/aishell/conformer/run.sh
@@ -16,22 +16,20 @@
feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
stop_stage=4
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
# exp tag
tag="exp1"
@@ -49,7 +47,7 @@
test_sets="dev test"
asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -66,10 +64,16 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -79,46 +83,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -126,22 +93,14 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
# Training Stage
@@ -161,26 +120,22 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -203,7 +158,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -224,6 +179,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -244,5 +200,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
-
+fi
\ No newline at end of file
diff --git a/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml b/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
index f9a2cdb..287b088 100644
--- a/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
+++ b/egs/aishell/data2vec_paraformer_finetune/conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
@@ -30,7 +30,6 @@
require_same_masks: true
mask_dropout: 0
-
# decoder related
decoder: paraformer_decoder_san
decoder_conf:
@@ -42,6 +41,18 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+# hybrid CTC/attention
model: paraformer
model_conf:
ctc_weight: 0.3
@@ -50,15 +61,10 @@
predictor_weight: 1.0
sampling_ratio: 0.4
-# minibatch related
-batch_type: length
-batch_bins: 25000
-num_workers: 16
-
# optimization related
accum_grad: 1
grad_clip: 5
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -98,6 +104,15 @@
l_order: 1
r_order: 1
+dataset_conf:
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
log_interval: 50
unused_parameters: true
diff --git a/egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh b/egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/data2vec_paraformer_finetune/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh b/egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/data2vec_paraformer_finetune/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
- echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
- echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
- exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
- echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
- exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
- cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/data2vec_paraformer_finetune/run.sh b/egs/aishell/data2vec_paraformer_finetune/run.sh
index d033ce2..b9d1668 100755
--- a/egs/aishell/data2vec_paraformer_finetune/run.sh
+++ b/egs/aishell/data2vec_paraformer_finetune/run.sh
@@ -8,33 +8,31 @@
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
+njob=1
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
stop_stage=4
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
# exp tag
-tag=""
+tag="exp1"
model_name=damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch
init_param="$HOME/.cache/modelscope/hub/$model_name/basemodel.pb"
@@ -52,7 +50,7 @@
test_sets="dev test"
asr_config=conf/train_asr_paraformer_transformer_12e_6d_3072_768.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -69,10 +67,16 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -82,46 +86,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -129,35 +96,27 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
# Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "stage 3: Training"
- python utils/download_model.py --model_name ${model_name} # download pretrained model on ModelScope
+ python utils/download_model.py --model_name ${model_name} # download pretrained model on ModelScope
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
- fi
+ fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
@@ -165,27 +124,22 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train_paraformer.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
--init_param ${init_param} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -208,7 +162,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -229,6 +183,8 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -249,4 +205,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
+fi
\ No newline at end of file
diff --git a/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml b/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
index 5bc5236..ad3ad2e 100644
--- a/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
+++ b/egs/aishell/data2vec_transformer_finetune/conf/train_asr_transformer_12e_6d_3072_768.yaml
@@ -30,25 +30,28 @@
require_same_masks: true
mask_dropout: 0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model_conf:
ctc_weight: 1.0
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
-# for logger
-log_interval: 50
-
-# minibatch related
-batch_type: length
-batch_bins: 16000
-num_workers: 16
-
# optimization related
accum_grad: 1
grad_clip: 5
patience: none
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -57,8 +60,6 @@
- cer_ctc
- min
keep_nbest_models: 10
-unused_parameters: true
-normalize: None
# NoamLR is deprecated. Use WarmupLR.
# The following is equivalent setting for NoamLR:
@@ -92,4 +93,18 @@
time_mask_width_range:
- 0
- 40
- num_time_mask: 2
\ No newline at end of file
+ num_time_mask: 2
+
+dataset_conf:
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
+
+log_interval: 50
+unused_parameters: true
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh b/egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/data2vec_transformer_finetune/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh b/egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/data2vec_transformer_finetune/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
- echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
- echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
- exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
- echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
- exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
- cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/data2vec_transformer_finetune/run.sh b/egs/aishell/data2vec_transformer_finetune/run.sh
index 26222e6..7bbcf90 100755
--- a/egs/aishell/data2vec_transformer_finetune/run.sh
+++ b/egs/aishell/data2vec_transformer_finetune/run.sh
@@ -13,28 +13,26 @@
infer_cmd=utils/run.pl
# general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
stop_stage=4
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
# exp tag
-tag=""
+tag="exp1"
model_name=damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch
init_param="$HOME/.cache/modelscope/hub/$model_name/basemodel.pb"
@@ -52,7 +50,7 @@
test_sets="dev test"
asr_config=conf/train_asr_transformer_12e_6d_3072_768.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.cer_ctc.ave_10best.pb
@@ -69,10 +67,16 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -82,46 +86,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -129,35 +96,27 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
# Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "stage 3: Training"
- python utils/download_model.py --model_name ${model_name} # download pretrained model on ModelScope
+ python utils/download_model.py --model_name ${model_name} # download pretrained model on ModelScope
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
- fi
+ fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
@@ -165,27 +124,23 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
--init_param ${init_param} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -208,7 +163,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -229,6 +184,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -249,4 +205,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
+fi
\ No newline at end of file
diff --git a/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml b/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
index 24b2620..bac8d04 100644
--- a/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
+++ b/egs/aishell/paraformer/conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
@@ -29,6 +29,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
model: paraformer
model_conf:
ctc_weight: 0.3
@@ -37,15 +48,10 @@
predictor_weight: 1.0
sampling_ratio: 0.4
-# minibatch related
-batch_type: length
-batch_bins: 25000
-num_workers: 16
-
# optimization related
accum_grad: 1
grad_clip: 5
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -78,7 +84,7 @@
- 40
num_time_mask: 2
-predictor: cif_predictor_v2
+predictor: cif_predictor
predictor_conf:
idim: 256
threshold: 1.0
@@ -86,6 +92,15 @@
r_order: 1
tail_threshold: 0.45
+dataset_conf:
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
log_interval: 50
normalize: None
\ No newline at end of file
diff --git a/egs/aishell/paraformer/local/download_and_untar.sh b/egs/aishell/paraformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/paraformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/paraformer/local/prepare_data.sh b/egs/aishell/paraformer/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/paraformer/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
- echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
- echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
- exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
- echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
- exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
- cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/paraformer/run.sh b/egs/aishell/paraformer/run.sh
index 53b5f90..430da6e 100755
--- a/egs/aishell/paraformer/run.sh
+++ b/egs/aishell/paraformer/run.sh
@@ -16,25 +16,23 @@
 feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
stop_stage=4
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
# exp tag
-tag=""
+tag="exp1"
. utils/parse_options.sh || exit 1;
@@ -49,7 +47,7 @@
test_sets="dev test"
asr_config=conf/train_asr_paraformer_conformer_12e_6d_2048_256.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -66,10 +64,16 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -79,46 +83,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -131,17 +98,9 @@
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
# Training Stage
@@ -161,26 +120,22 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train_paraformer.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -203,7 +158,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -224,6 +179,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
diff --git a/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml b/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
index f51a2ea..f2652e8 100644
--- a/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
+++ b/egs/aishell/paraformerbert/conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
@@ -29,6 +29,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model: paraformer_bert
model_conf:
@@ -41,19 +52,10 @@
embed_dims: 768
embeds_loss_weight: 2.0
-
-
-# minibatch related
-#batch_type: length
-#batch_bins: 40000
-batch_type: numel
-batch_bins: 2000000
-num_workers: 16
-
# optimization related
-accum_grad: 4
+accum_grad: 1
grad_clip: 5
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -92,8 +94,19 @@
threshold: 1.0
l_order: 1
r_order: 1
+ tail_threshold: 0.45
+dataset_conf:
+ data_names: speech,text,embed
+ data_types: sound,text,kaldi_ark
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
log_interval: 50
-normalize: None
-allow_variable_data_keys: true
\ No newline at end of file
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell/paraformerbert/local/aishell_data_prep.sh b/egs/aishell/paraformerbert/local/aishell_data_prep.sh
index b6ea36b..83f489b 100755
--- a/egs/aishell/paraformerbert/local/aishell_data_prep.sh
+++ b/egs/aishell/paraformerbert/local/aishell_data_prep.sh
@@ -5,19 +5,20 @@
#. ./path.sh || exit 1;
-if [ $# != 2 ]; then
- echo "Usage: $0 <audio-path> <text-path>"
- echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript"
+if [ $# != 3 ]; then
+ echo "Usage: $0 <audio-path> <text-path> <output-path>"
+ echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
exit 1;
fi
aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
+output_dir=$3
-train_dir=data/local/train
-dev_dir=data/local/dev
-test_dir=data/local/test
-tmp_dir=data/local/tmp
+train_dir=$output_dir/data/local/train
+dev_dir=$output_dir/data/local/dev
+test_dir=$output_dir/data/local/test
+tmp_dir=$output_dir/data/local/tmp
mkdir -p $train_dir
mkdir -p $dev_dir
@@ -53,12 +54,12 @@
sort -u $dir/transcripts.txt > $dir/text
done
-mkdir -p data/train data/dev data/test
+mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
for f in wav.scp text; do
- cp $train_dir/$f data/train/$f || exit 1;
- cp $dev_dir/$f data/dev/$f || exit 1;
- cp $test_dir/$f data/test/$f || exit 1;
+ cp $train_dir/$f $output_dir/data/train/$f || exit 1;
+ cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
+ cp $test_dir/$f $output_dir/data/test/$f || exit 1;
done
echo "$0: AISHELL data preparation succeeded"
diff --git a/egs/aishell/paraformerbert/local/download_and_untar.sh b/egs/aishell/paraformerbert/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/paraformerbert/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/paraformerbert/local/extract_embeds.sh b/egs/aishell/paraformerbert/local/extract_embeds.sh
index 9cf5940..c07e528 100755
--- a/egs/aishell/paraformerbert/local/extract_embeds.sh
+++ b/egs/aishell/paraformerbert/local/extract_embeds.sh
@@ -3,20 +3,16 @@
stage=1
stop_stage=3
-bert_model_root="../../huggingface_models"
bert_model_name="bert-base-chinese"
-#bert_model_name="chinese-roberta-wwm-ext"
-#bert_model_name="mengzi-bert-base"
raw_dataset_path="../DATA"
-model_path=${bert_model_root}/${bert_model_name}
+nj=64
+model_path=${bert_model_name}
. utils/parse_options.sh || exit 1;
-nj=32
-
for data_set in train dev test;do
- scp=$raw_dataset_path/dump/fbank/${data_set}/text
- local_scp_dir_raw=$raw_dataset_path/embeds/$bert_model_name/${data_set}
+ scp=$raw_dataset_path/data/${data_set}/text
+ local_scp_dir_raw=${raw_dataset_path}/data/embeds/${data_set}
local_scp_dir=$local_scp_dir_raw/split$nj
local_records_dir=$local_scp_dir_raw/ark
diff --git a/egs/aishell/paraformerbert/run.sh b/egs/aishell/paraformerbert/run.sh
index 2487eac..a4ad195 100755
--- a/egs/aishell/paraformerbert/run.sh
+++ b/egs/aishell/paraformerbert/run.sh
@@ -8,7 +8,7 @@
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
+njob=1
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
@@ -16,29 +16,26 @@
 feats_dir="../DATA" #feature output directory, for large data
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
stop_stage=4
skip_extract_embed=false
-bert_model_root="../../huggingface_models"
bert_model_name="bert-base-chinese"
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
# exp tag
-tag=""
+tag="exp1"
. utils/parse_options.sh || exit 1;
@@ -53,7 +50,7 @@
test_sets="dev test"
asr_config=conf/train_asr_paraformerbert_conformer_12e_6d_2048_256.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -70,10 +67,17 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -83,46 +87,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -135,17 +102,9 @@
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
# Training Stage
@@ -155,9 +114,9 @@
if ! "${skip_extract_embed}"; then
echo "extract embeddings..."
local/extract_embeds.sh \
- --bert_model_root ${bert_model_root} \
--bert_model_name ${bert_model_name} \
- --raw_dataset_path ${feats_dir}
+ --raw_dataset_path ${feats_dir} \
+ --nj $nj
fi
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
@@ -172,31 +131,23 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train_paraformer.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_data_path_and_name_and_type ${feats_dir}/embeds/${bert_model_name}/${train_set}/embeds.scp,embed,${type} \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --train_shape_file ${feats_dir}/embeds/${bert_model_name}/${train_set}/embeds.shape \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_data_path_and_name_and_type ${feats_dir}/embeds/${bert_model_name}/${valid_set}/embeds.scp,embed,${type} \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
- --valid_shape_file ${feats_dir}/embeds/${bert_model_name}/${valid_set}/embeds.shape \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --embed_path ${feats_dir}/data \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --allow_variable_data_keys true \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -260,5 +211,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
-
+fi
\ No newline at end of file
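
The paraformerbert recipe pairs each transcript with frozen BERT token embeddings (data_names: speech,text,embed and embed_dims: 768 in the config above). Below is a rough sketch of the extraction that local/extract_embeds.sh drives, assuming the Hugging Face transformers API for bert-base-chinese; serializing the matrices as the configured kaldi_ark type (e.g. via kaldiio) is omitted.

import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModel.from_pretrained("bert-base-chinese").eval()

@torch.no_grad()
def embed_text(text):
    """Return per-token hidden states, shape (tokens, 768)."""
    inputs = tokenizer(text, return_tensors="pt")
    return model(**inputs).last_hidden_state.squeeze(0).numpy()
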
diff --git a/egs/aishell/transformer/conf/train_asr_transformer.yaml b/egs/aishell/transformer/conf/train_asr_transformer.yaml
index ce987e7..22e651b 100644
--- a/egs/aishell/transformer/conf/train_asr_transformer.yaml
+++ b/egs/aishell/transformer/conf/train_asr_transformer.yaml
@@ -23,22 +23,28 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
-# minibatch related
-batch_type: length
-batch_bins: 32000
-num_workers: 8
-
# optimization related
accum_grad: 1
grad_clip: 5
-patience: 3
-max_epoch: 20
+patience: none
+max_epoch: 60
val_scheduler_criterion:
- valid
- acc
@@ -66,5 +72,15 @@
scheduler_conf:
warmup_steps: 25000
+dataset_conf:
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 25000
+ num_workers: 8
+
log_interval: 50
normalize: None
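
The wav_frontend blocks added across these configs expose low-frame-rate (LFR) stacking through lfr_m/lfr_n; with both set to 1, as here, fbank features pass through unchanged. A sketch of the stacking rule as commonly defined (lfr_m frames concatenated at every lfr_n-th step) follows; the real frontend's edge padding may differ.

import numpy as np

def apply_lfr(feats, lfr_m, lfr_n):
    """Stack lfr_m consecutive frames every lfr_n frames.

    feats: (T, D) array -> (ceil(T / lfr_n), lfr_m * D) array.
    """
    out = []
    for t in range(0, len(feats), lfr_n):
        chunk = feats[t:t + lfr_m]
        if len(chunk) < lfr_m:  # repeat the last frame to fill the tail
            pad = np.repeat(chunk[-1:], lfr_m - len(chunk), axis=0)
            chunk = np.concatenate([chunk, pad], axis=0)
        out.append(chunk.reshape(-1))
    return np.stack(out)
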
diff --git a/egs/aishell/transformer/local/download_and_untar.sh b/egs/aishell/transformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/transformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+  exit 1;
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/transformer/local/prepare_data.sh b/egs/aishell/transformer/local/prepare_data.sh
deleted file mode 100755
index 77791f9..0000000
--- a/egs/aishell/transformer/local/prepare_data.sh
+++ /dev/null
@@ -1,53 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2018 AIShell-Foundation(Authors:Jiayu DU, Xingyu NA, Bengu WU, Hao ZHENG)
-# 2018 Beijing Shell Shell Tech. Co. Ltd. (Author: Hui BU)
-# Apache 2.0
-
-# transform raw AISHELL-2 data to kaldi format
-
-. ./path.sh || exit 1;
-
-tmp=
-dir=
-
-if [ $# != 3 ]; then
- echo "Usage: $0 <corpus-data-dir> <tmp-dir> <output-dir>"
- echo " $0 /export/AISHELL-2/iOS/train data/local/train data/train"
- exit 1;
-fi
-
-corpus=$1
-tmp=$2
-dir=$3
-
-echo "prepare_data.sh: Preparing data in $corpus"
-
-mkdir -p $tmp
-mkdir -p $dir
-
-# corpus check
-if [ ! -d $corpus ] || [ ! -f $corpus/wav.scp ] || [ ! -f $corpus/trans.txt ]; then
- echo "Error: $0 requires wav.scp and trans.txt under $corpus directory."
- exit 1;
-fi
-
-# validate utt-key list, IC0803W0380 is a bad utterance
-awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
-awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
-
-# wav.scp
-awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
-
-# text
-utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
-
-# copy prepared resources from tmp_dir to target dir
-mkdir -p $dir
-for f in wav.scp text; do
- cp $tmp/$f $dir/$f || exit 1;
-done
-
-echo "local/prepare_data.sh succeeded"
-exit 0;
diff --git a/egs/aishell/transformer/path.sh b/egs/aishell/transformer/path.sh
index 7972642..b4064e1 100755
--- a/egs/aishell/transformer/path.sh
+++ b/egs/aishell/transformer/path.sh
@@ -3,3 +3,5 @@
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
+
+export OMP_NUM_THREADS=1
diff --git a/egs/aishell/transformer/run.sh b/egs/aishell/transformer/run.sh
index f66a338..b7ad9cd 100755
--- a/egs/aishell/transformer/run.sh
+++ b/egs/aishell/transformer/run.sh
@@ -8,33 +8,31 @@
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=8
+njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
-scp=feats.scp
-type=kaldi_ark
-stage=0
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
stop_stage=4
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=32
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_aishell=
+raw_data=
+data_url=www.openslr.org/resources/33
# exp tag
-tag=""
+tag="exp1"
. utils/parse_options.sh || exit 1;
@@ -48,8 +46,8 @@
valid_set=dev
test_sets="dev test"
-asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+asr_config=conf/train_asr_transformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -66,10 +64,16 @@
_ngpu=0
fi
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# Data preparation
- local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
for x in train dev test; do
cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
@@ -79,46 +83,9 @@
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- utils/fix_data_feat.sh ${fbankdir}/train
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
- utils/fix_data_feat.sh ${fbankdir}/dev
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
- utils/fix_data_feat.sh ${fbankdir}/test
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
-
- cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
- cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
-
- utils/fix_data_feat.sh ${feat_train_dir}
- utils/fix_data_feat.sh ${feat_dev_dir}
- utils/fix_data_feat.sh ${feat_test_dir}
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -126,22 +93,14 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
# Training Stage
@@ -153,7 +112,7 @@
INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
if [ -f $INIT_FILE ];then
rm -f $INIT_FILE
- fi
+ fi
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
for ((i = 0; i < $gpu_num; ++i)); do
@@ -161,26 +120,22 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--token_type char \
--token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -203,7 +158,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -224,6 +179,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -244,4 +200,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
+fi
\ No newline at end of file
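
Each iteration of the per-GPU loop launches one train.py process; the processes rendezvous through the shared ddp_init file passed as --dist_init_method together with --dist_world_size and --dist_rank. A sketch of the equivalent torch.distributed handshake is shown below; the actual wiring lives inside train.py.

import torch.distributed as dist

def init_ddp(init_file, world_size, rank):
    """File-based rendezvous matching init_method=file://$INIT_FILE."""
    dist.init_process_group(
        backend="nccl",                     # GPU training
        init_method="file://" + init_file,  # shared-filesystem rendezvous
        world_size=world_size,
        rank=rank,
    )
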
diff --git a/egs/aishell/transformer/utils/cmvn_converter.py b/egs/aishell/transformer/utils/cmvn_converter.py
index cb978af..d405d12 100644
--- a/egs/aishell/transformer/utils/cmvn_converter.py
+++ b/egs/aishell/transformer/utils/cmvn_converter.py
@@ -9,16 +9,14 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--cmvn-json",
- "-c",
+ "--cmvn_json",
default=False,
required=True,
type=str,
help="cmvn json file",
)
parser.add_argument(
- "--am-mvn",
- "-a",
+ "--am_mvn",
default=False,
required=True,
type=str,
diff --git a/egs/aishell/transformer/utils/combine_cmvn_file.py b/egs/aishell/transformer/utils/combine_cmvn_file.py
index b2974a4..c525973 100755
--- a/egs/aishell/transformer/utils/combine_cmvn_file.py
+++ b/egs/aishell/transformer/utils/combine_cmvn_file.py
@@ -1,6 +1,9 @@
import argparse
import json
+import os
+
import numpy as np
+
def get_parser():
parser = argparse.ArgumentParser(
@@ -8,15 +11,13 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--dims",
- "-d",
+ "--dim",
default=80,
type=int,
- help="feature dims",
+ help="feature dim",
)
parser.add_argument(
- "--cmvn-dir",
- "-c",
+ "--cmvn_dir",
default=False,
required=True,
type=str,
@@ -25,15 +26,13 @@
parser.add_argument(
"--nj",
- "-n",
default=1,
required=True,
type=int,
- help="num of cmvn file",
+ help="num of cmvn files",
)
parser.add_argument(
- "--output-dir",
- "-o",
+ "--output_dir",
default=False,
required=True,
type=str,
@@ -46,14 +45,14 @@
parser = get_parser()
args = parser.parse_args()
- total_means = np.zeros(args.dims)
- total_vars = np.zeros(args.dims)
+ total_means = np.zeros(args.dim)
+ total_vars = np.zeros(args.dim)
total_frames = 0
- cmvn_file = args.output_dir + "/cmvn.json"
+ cmvn_file = os.path.join(args.output_dir, "cmvn.json")
- for i in range(1, args.nj+1):
- with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
+ for i in range(1, args.nj + 1):
+ with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(str(i)))) as fin:
cmvn_stats = json.load(fin)
total_means += np.array(cmvn_stats["mean_stats"])
diff --git a/egs/aishell/transformer/utils/compute_cmvn.py b/egs/aishell/transformer/utils/compute_cmvn.py
index 2b96e26..949cc08 100755
--- a/egs/aishell/transformer/utils/compute_cmvn.py
+++ b/egs/aishell/transformer/utils/compute_cmvn.py
@@ -1,8 +1,10 @@
-from kaldiio import ReadHelper
-
import argparse
-import numpy as np
import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
def get_parser():
@@ -11,55 +13,83 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--dims",
- "-d",
+ "--dim",
default=80,
type=int,
- help="feature dims",
+ help="feature dimension",
)
parser.add_argument(
- "--ark-file",
- "-a",
+ "--wav_path",
default=False,
required=True,
type=str,
- help="fbank ark file",
+        help="directory containing the split wav.scp files",
)
parser.add_argument(
- "--ark-index",
- "-i",
+ "--idx",
default=1,
required=True,
type=int,
- help="ark index",
- )
- parser.add_argument(
- "--output-dir",
- "-o",
- default=False,
- required=True,
- type=str,
- help="output dir",
+        help="split index (1-based)",
)
return parser
+
+
+def compute_fbank(wav_file,
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0,
+ resample_rate=16000,
+ speed=1.0,
+ window_type="hamming"):
+ waveform, sample_rate = torchaudio.load(wav_file)
+ if resample_rate != sample_rate:
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+ new_freq=resample_rate)(waveform)
+ if speed != 1.0:
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, resample_rate,
+ [['speed', str(speed)], ['rate', str(resample_rate)]]
+ )
+
+ waveform = waveform * (1 << 15)
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ window_type=window_type,
+ sample_frequency=resample_rate)
+
+ return mat.numpy()
def main():
parser = get_parser()
args = parser.parse_args()
- ark_file = args.ark_file + "/feats." + str(args.ark_index) + ".ark"
- cmvn_file = args.output_dir + "/cmvn." + str(args.ark_index) + ".json"
+ wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+ cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
- mean_stats = np.zeros(args.dims)
- var_stats = np.zeros(args.dims)
+ mean_stats = np.zeros(args.dim)
+ var_stats = np.zeros(args.dim)
total_frames = 0
- with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
- for key, mat in ark_reader:
- mean_stats += np.sum(mat, axis=0)
- var_stats += np.sum(np.square(mat), axis=0)
- total_frames += mat.shape[0]
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ _, wav_file = line.strip().split()
+ fbank = compute_fbank(wav_file, num_mel_bins=args.dim)
+ mean_stats += np.sum(fbank, axis=0)
+ var_stats += np.sum(np.square(fbank), axis=0)
+ total_frames += fbank.shape[0]
cmvn_info = {
'mean_stats': list(mean_stats.tolist()),
diff --git a/egs/aishell/transformer/utils/compute_cmvn.sh b/egs/aishell/transformer/utils/compute_cmvn.sh
index 12173ee..7663df9 100755
--- a/egs/aishell/transformer/utils/compute_cmvn.sh
+++ b/egs/aishell/transformer/utils/compute_cmvn.sh
@@ -11,15 +11,24 @@
. utils/parse_options.sh || exit 1;
fbankdir=$1
-logdir=$2
-output_dir=${fbankdir}/cmvn; mkdir -p ${output_dir}
-mkdir -p ${logdir}
+split_dir=${fbankdir}/cmvn/split_${nj};
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+ split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp $split_scps || exit 1;
+logdir=${fbankdir}/cmvn/log
$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
- python utils/compute_cmvn.py -d ${feats_dim} -a $fbankdir/ark -i JOB -o ${output_dir} \
- || exit 1;
+ python utils/compute_cmvn.py \
+ --dim ${feats_dim} \
+ --wav_path $split_dir \
+ --idx JOB
-python utils/combine_cmvn_file.py -d ${feats_dim} -c ${output_dir} -n $nj -o $fbankdir
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/cmvn.mvn
echo "$0: Succeeded compute global cmvn"
diff --git a/egs/aishell2/conformer/conf/train_asr_conformer.yaml b/egs/aishell2/conformer/conf/train_asr_conformer.yaml
index 02fc5a8..3fd034f 100644
--- a/egs/aishell2/conformer/conf/train_asr_conformer.yaml
+++ b/egs/aishell2/conformer/conf/train_asr_conformer.yaml
@@ -29,6 +29,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
@@ -39,7 +50,7 @@
accum_grad: 2
grad_clip: 5
patience: none
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -72,10 +83,8 @@
- 40
num_time_mask: 2
-log_interval: 50
-normalize: None
-
dataset_conf:
+ data_types: sound,text
shuffle: True
shuffle_conf:
shuffle_size: 2048
@@ -83,4 +92,7 @@
batch_conf:
batch_type: token
batch_size: 25000
- num_workers: 8
\ No newline at end of file
+ num_workers: 8
+
+log_interval: 50
+normalize: None
diff --git a/egs/aishell2/conformer/local/prepare_data.sh b/egs/aishell2/conformer/local/prepare_data.sh
index ce6ee19..77791f9 100755
--- a/egs/aishell2/conformer/local/prepare_data.sh
+++ b/egs/aishell2/conformer/local/prepare_data.sh
@@ -34,14 +34,14 @@
# validate utt-key list, IC0803W0380 is a bad utterance
awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
# wav.scp
awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
# text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
# copy prepared resources from tmp_dir to target dir
mkdir -p $dir
diff --git a/egs/aishell2/conformer/run.sh b/egs/aishell2/conformer/run.sh
index f9ea69a..65fd666 100755
--- a/egs/aishell2/conformer/run.sh
+++ b/egs/aishell2/conformer/run.sh
@@ -9,31 +9,28 @@
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
 feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
dataset_type=large
-scp=feats.scp
-type=kaldi_ark
-stage=0
+stage=3
stop_stage=4
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-tr_dir=
-dev_tst_dir=
+tr_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-2/iOS/data
+dev_tst_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-DEV-TEST-SET
# exp tag
tag="exp1"
@@ -51,13 +48,13 @@
test_sets="dev_ios test_ios"
asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb
# you can set gpu num for decoding here
-gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, e.g., gpuid_list=2,3, the same as training stage by default
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
@@ -73,7 +70,7 @@
# For training set
local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
# # For dev and test set
- for x in Android iOS Mic; do
+ for x in iOS; do
local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
done
@@ -83,51 +80,14 @@
paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
| tr 'A-Z' 'a-z' | tr -d " ") \
> ${feats_dir}/data/${x}/text
- tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- tools/fix_data_feat.sh ${fbankdir}/train
- for x in android ios mic; do
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
- tools/fix_data_feat.sh ${fbankdir}/dev_${x}
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
- tools/fix_data_feat.sh ${fbankdir}/test_${x}
- done
-
- # compute global cmvn
- steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
- for x in android ios mic; do
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
- done
-
- cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
- tools/fix_data_feat.sh ${feat_train_dir}
- cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
- tools/fix_data_feat.sh ${feat_dev_dir}
- for x in android ios mic; do
- cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
- tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
- done
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -140,18 +100,12 @@
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-fi
+fi
# Training Stage
world_size=$gpu_num # run on one machine
@@ -170,21 +124,23 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
- --dataset_type $dataset_type \
--token_type char \
--token_list $token_list \
- --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --dataset_type $dataset_type \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -207,7 +163,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -228,6 +184,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
diff --git a/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
index 4052774..b6e8808 100644
--- a/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
+++ b/egs/aishell2/data2vec_pretrain/conf/train_pretrain_transformer.yaml
@@ -2,47 +2,52 @@
# encoder related
encoder: data2vec_encoder
encoder_conf:
- extractor_mode: layer_norm
- encoder_layerdrop: 0.05
- dropout_input: 0.0
- dropout_features: 0.0
- feature_grad_mult: 1.0
- encoder_embed_dim: 768
+ extractor_mode: layer_norm
+ encoder_layerdrop: 0.05
+ dropout_input: 0.0
+ dropout_features: 0.0
+ feature_grad_mult: 1.0
+ encoder_embed_dim: 768
- mask_prob: 0.65
- mask_length: 10
+ mask_prob: 0.65
+ mask_length: 10
- loss_beta: 0
- loss_scale: null
+ loss_beta: 0
+ loss_scale: null
- instance_norm_target_layer: true
- average_top_k_layers: 8
+ instance_norm_target_layer: true
+ average_top_k_layers: 8
- pos_conv_depth: 5
- conv_pos: 95
+ pos_conv_depth: 5
+ conv_pos: 95
- ema_decay: 0.999
- ema_end_decay: 0.9999
- ema_anneal_end_step: 30000
- ema_transformer_only: true
- ema_layers_only: true
+ ema_decay: 0.999
+ ema_end_decay: 0.9999
+ ema_anneal_end_step: 30000
+ ema_transformer_only: true
+ ema_layers_only: true
- require_same_masks: true
- mask_dropout: 0
+ require_same_masks: true
+ mask_dropout: 0
-log_interval: 50
-normalize: None
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
-# minibatch related
-batch_type: length
-batch_bins: 64000
-num_workers: 16
+model: data2vec
# optimization related
accum_grad: 1
grad_clip: 5
patience: none
-max_epoch: 600
+max_epoch: 1800
val_scheduler_criterion:
- valid
- acc
@@ -68,7 +73,7 @@
dataset_conf:
batch_mode: clipping
data_names: speech,none
- data_types: kaldi_ark,none
+ data_types: sound,none
shuffle: true
shuffle_conf:
shuffle_size: 12800
@@ -76,4 +81,7 @@
batch_conf:
batch_type: token
batch_size: 64000
- num_workers: 8
\ No newline at end of file
+ num_workers: 8
+
+log_interval: 50
+normalize: None
\ No newline at end of file
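
Note: the frontend block added above replaces precomputed kaldi_ark features with on-the-fly fbank extraction from raw audio (lfr_m/lfr_n of 1 leave frames unstacked). As a rough illustration only, not FunASR's actual WavFrontend class, the settings map onto a torchaudio Kaldi-style fbank like this; the helper name and wrapper are assumptions, but the call mirrors the compute_fbank helper added later in this patch:

    import torchaudio
    import torchaudio.compliance.kaldi as kaldi

    def wav_to_fbank(wav_file, fs=16000, n_mels=80, frame_length=25,
                     frame_shift=10, window="hamming"):
        # Defaults mirror frontend_conf: 16 kHz audio, 80 mel bins,
        # 25 ms frames with a 10 ms shift, hamming window.
        waveform, sample_rate = torchaudio.load(wav_file)
        if sample_rate != fs:
            waveform = torchaudio.transforms.Resample(sample_rate, fs)(waveform)
        waveform = waveform * (1 << 15)  # kaldi-style 16-bit sample range
        return kaldi.fbank(waveform,
                           num_mel_bins=n_mels,
                           frame_length=frame_length,
                           frame_shift=frame_shift,
                           window_type=window,
                           sample_frequency=fs)
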
diff --git a/egs/aishell2/data2vec_pretrain/run.sh b/egs/aishell2/data2vec_pretrain/run.sh
index eceb183..2753f00 100755
--- a/egs/aishell2/data2vec_pretrain/run.sh
+++ b/egs/aishell2/data2vec_pretrain/run.sh
@@ -7,28 +7,25 @@
gpu_num=8
count=1
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
# general configuration
feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
+speed_perturb="0.9 1.0 1.1"
dataset_type=large
-stage=0
-stop_stage=4
+stage=3
+stop_stage=3
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-tr_dir=
-dev_tst_dir=
+tr_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-2/iOS/data
+dev_tst_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-DEV-TEST-SET
# exp tag
tag="exp1"
@@ -45,68 +42,31 @@
valid_set=dev_ios
asr_config=conf/train_pretrain_transformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml) _${lang}_${token_type}_${tag}"
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# For training set
local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
# # For dev and test set
- for x in Android iOS Mic; do
+ for x in iOS; do
local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
- done
+ done
    # Normalize text to lower case
- for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
+ for x in train dev_ios test_ios; do
mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
| tr 'A-Z' 'a-z' | tr -d " ") \
> ${feats_dir}/data/${x}/text
- tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- tools/fix_data_feat.sh ${fbankdir}/train
- for x in android ios mic; do
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
- tools/fix_data_feat.sh ${fbankdir}/dev_${x}
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
- tools/fix_data_feat.sh ${fbankdir}/test_${x}
- done
-
- # compute global cmvn
- steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
- for x in android ios mic; do
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
- done
-
- cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
- tools/fix_data_feat.sh ${feat_train_dir}
- cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
- tools/fix_data_feat.sh ${feat_dev_dir}
- for x in android ios mic; do
- cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
- tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
- done
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -114,22 +74,59 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
+ fi
+
+# Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: Training"
+ mkdir -p ${exp_dir}/exp/${model_dir}
+ mkdir -p ${exp_dir}/exp/${model_dir}/log
+ INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $gpu_num; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ train.py \
+ --task_name asr \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --token_type char \
+ --token_list $token_list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --dataset_type $dataset_type \
+ --resume true \
+ --output_dir ${exp_dir}/exp/${model_dir} \
+ --config $asr_config \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
fi
# Training Stage
@@ -149,12 +146,16 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- data2vec_train.py \
+ train.py \
+ --task_name pretrain \
--gpu_id $gpu_id \
--use_preprocessor true \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--dataset_type $dataset_type \
- --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
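
Note: throughout these run.sh rewrites the same launch pattern recurs: one train.py process is forked per GPU, and all of them rendezvous through the ddp_init file that is deleted before each run. A minimal sketch of what that hands to PyTorch, assuming plain torch.distributed (the actual wiring lives inside FunASR's train.py):

    import torch.distributed as dist

    def join_group(init_file, rank, world_size):
        # file:// rendezvous: every rank points at the same, freshly
        # created file -- which is why run.sh removes INIT_FILE first.
        dist.init_process_group(backend="nccl",
                                init_method="file://" + init_file,
                                rank=rank,
                                world_size=world_size)
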
diff --git a/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml b/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
index 450f71a..7fc9794 100644
--- a/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
+++ b/egs/aishell2/paraformer/conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
@@ -29,6 +29,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model: paraformer
model_conf:
@@ -42,7 +53,7 @@
accum_grad: 2
grad_clip: 5
patience: none
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -82,10 +93,8 @@
l_order: 1
r_order: 1
-log_interval: 50
-normalize: None
-
dataset_conf:
+ data_types: sound,text
shuffle: True
shuffle_conf:
shuffle_size: 2048
@@ -93,4 +102,7 @@
batch_conf:
batch_type: token
batch_size: 25000
- num_workers: 8
\ No newline at end of file
+ num_workers: 8
+
+log_interval: 50
+normalize: None
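
Note: the dataset_conf above keeps batch_type: token with batch_size: 25000, i.e. batches are cut by a total-frame budget rather than a fixed utterance count, so short and long utterances pack evenly. A simplified sketch of the idea (the real sampler also sorts by length and shuffles):

    def token_batches(lengths, max_tokens=25000):
        # lengths: iterable of (utt_id, num_frames) pairs
        batch, total = [], 0
        for key, num_frames in lengths:
            if batch and total + num_frames > max_tokens:
                yield batch
                batch, total = [], 0
            batch.append(key)
            total += num_frames
        if batch:
            yield batch
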
diff --git a/egs/aishell2/paraformer/local/prepare_data.sh b/egs/aishell2/paraformer/local/prepare_data.sh
index ce6ee19..77791f9 100755
--- a/egs/aishell2/paraformer/local/prepare_data.sh
+++ b/egs/aishell2/paraformer/local/prepare_data.sh
@@ -34,14 +34,14 @@
# validate utt-key list, IC0803W0380 is a bad utterance
awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
# wav.scp
awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
# text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
# copy prepared resources from tmp_dir to target dir
mkdir -p $dir
diff --git a/egs/aishell2/paraformer/run.sh b/egs/aishell2/paraformer/run.sh
index e1ea4fe..5236af5 100755
--- a/egs/aishell2/paraformer/run.sh
+++ b/egs/aishell2/paraformer/run.sh
@@ -9,31 +9,28 @@
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=1
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
dataset_type=large
-scp=feats.scp
-type=kaldi_ark
-stage=0
+stage=3
stop_stage=4
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-tr_dir=
-dev_tst_dir=
+tr_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-2/iOS/data
+dev_tst_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-DEV-TEST-SET
# exp tag
tag="exp1"
@@ -51,7 +48,7 @@
test_sets="dev_ios test_ios"
asr_config=conf/train_asr_paraformer_conformer_20e_1280_320_6d_1280_320.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -73,61 +70,24 @@
# For training set
local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
# # For dev and test set
- for x in Android iOS Mic; do
+ for x in iOS; do
local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
- done
+ done
    # Normalize text to lower case
- for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
+ for x in train dev_ios test_ios; do
mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
| tr 'A-Z' 'a-z' | tr -d " ") \
> ${feats_dir}/data/${x}/text
- tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- tools/fix_data_feat.sh ${fbankdir}/train
- for x in android ios mic; do
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
- tools/fix_data_feat.sh ${fbankdir}/dev_${x}
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
- tools/fix_data_feat.sh ${fbankdir}/test_${x}
- done
-
- # compute global cmvn
- steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
- for x in android ios mic; do
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
- done
-
- cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
- tools/fix_data_feat.sh ${feat_train_dir}
- cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
- tools/fix_data_feat.sh ${feat_dev_dir}
- for x in android ios mic; do
- cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
- tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
- done
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -135,23 +95,15 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-fi
+ fi
# Training Stage
world_size=$gpu_num # run on one machine
@@ -170,28 +122,30 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train_paraformer.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
- --dataset_type $dataset_type \
--token_type char \
--token_list $token_list \
- --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --dataset_type $dataset_type \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
--local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
} &
- done
- wait
+ done
+ wait
fi
# Testing Stage
@@ -207,7 +161,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -228,6 +182,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
diff --git a/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml b/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
index 19f123e..8968d2d 100644
--- a/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
+++ b/egs/aishell2/paraformerbert/conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
@@ -29,6 +29,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model: paraformer_bert
model_conf:
@@ -36,7 +47,7 @@
lsm_weight: 0.1 # label smoothing option
length_normalized_loss: false
predictor_weight: 1.0
- glat_context_p: 0.4
+ sampling_ratio: 0.4
embeds_id: 3
embed_dims: 768
embeds_loss_weight: 2.0
@@ -45,7 +56,7 @@
accum_grad: 2
grad_clip: 5
patience: none
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -78,24 +89,24 @@
- 40
num_time_mask: 2
-predictor: cif_predictor_sanm
+predictor: cif_predictor
predictor_conf:
idim: 320
threshold: 1.0
l_order: 1
r_order: 1
-log_interval: 50
-normalize: None
-
dataset_conf:
data_names: speech,text,embed
- data_types: kaldi_ark,text,kaldi_ark
+ data_types: sound,text,kaldi_ark
shuffle: True
shuffle_conf:
- shuffle_size: 10240
+ shuffle_size: 2048
sort_size: 500
batch_conf:
batch_type: token
batch_size: 25000
- num_workers: 8
\ No newline at end of file
+ num_workers: 8
+
+log_interval: 50
+normalize: None
\ No newline at end of file
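
Note: dataset_conf now pairs three data streams with three loaders: raw wavs for speech (sound), transcripts for text, and the pre-extracted BERT vectors for embed (kaldi_ark). Schematically, the type field just selects a reader per stream; a sketch assuming torchaudio and kaldiio (the names here are illustrative, not FunASR's internals):

    import kaldiio
    import torchaudio

    LOADERS = {
        "sound": lambda path: torchaudio.load(path)[0],  # waveform tensor
        "text": lambda s: s.split(),                     # token sequence
        "kaldi_ark": kaldiio.load_mat,                   # numpy matrix
    }

    def load_example(names, types, values):
        # names/types come from dataset_conf; values from one manifest entry
        return {n: LOADERS[t](v) for n, t, v in zip(names, types, values)}
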
diff --git a/egs/aishell2/paraformerbert/local/extract_embeds.sh b/egs/aishell2/paraformerbert/local/extract_embeds.sh
index 5f45ff3..ee89924 100755
--- a/egs/aishell2/paraformerbert/local/extract_embeds.sh
+++ b/egs/aishell2/paraformerbert/local/extract_embeds.sh
@@ -3,20 +3,16 @@
stage=1
stop_stage=3
-bert_model_root="../../huggingface_models"
bert_model_name="bert-base-chinese"
-#bert_model_name="chinese-roberta-wwm-ext"
-#bert_model_name="mengzi-bert-base"
raw_dataset_path="../DATA"
-model_path=${bert_model_root}/${bert_model_name}
+nj=64
+model_path=${bert_model_name}
. utils/parse_options.sh || exit 1;
-nj=100
-
-for data_set in train dev_ios test_ios;do
- scp=$raw_dataset_path/dump/fbank/${data_set}/text
- local_scp_dir_raw=$raw_dataset_path/embeds/$bert_model_name/${data_set}
+for data_set in train dev_ios;do
+ scp=$raw_dataset_path/data/${data_set}/text
+ local_scp_dir_raw=${raw_dataset_path}/data/embeds/${data_set}
local_scp_dir=$local_scp_dir_raw/split$nj
local_records_dir=$local_scp_dir_raw/ark
@@ -31,7 +27,7 @@
utils/split_scp.pl $scp ${split_scps}
- for num in {0..24};do
+ for num in {0..7};do
tmp=`expr $num \* 4`
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
@@ -41,20 +37,9 @@
{
beg=0
gpu=`expr $beg + $idx`
- echo $local_scp_dir_raw/log/log.${JOB}
- python tools/extract_embeds.py $local_scp_dir/text.$JOB.txt ${local_records_dir}/embeds.${JOB}.ark ${local_records_dir}/embeds.${JOB}.scp ${local_records_dir}/embeds.${JOB}.shape ${gpu} ${model_path} &> $local_scp_dir_raw/log/log.${JOB}
+ echo ${local_scp_dir}/log.${JOB}
+ python utils/extract_embeds.py $local_scp_dir/data.$JOB.text ${local_records_dir}/embeds.${JOB}.ark ${local_records_dir}/embeds.${JOB}.scp ${local_records_dir}/embeds.${JOB}.shape ${gpu} ${model_path} &> ${local_scp_dir}/log.${JOB}
} &
- done
- wait
- fi
-
- if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
- for idx in {1..4}; do
- JOB=`expr $tmp + $idx`
- echo "upload jobid=$JOB"
- {
- hadoop fs -put -f ${local_records_dir}/embeds.${JOB}.ark ${odps_des_feature_dir}/embeds.${JOB}.ark
- } &
done
wait
fi
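
Note: utils/extract_embeds.py itself is not part of this patch, so the following is only a plausible sketch of what its arguments imply: run a BERT model over each transcript and write per-utterance embedding matrices as a kaldi ark/scp pair (assuming the transformers and kaldiio packages; the real script also writes a shape file and pins a GPU, and which hidden layer the recipe uses is governed by embeds_id in the config):

    import torch
    from kaldiio import WriteHelper
    from transformers import AutoModel, AutoTokenizer

    def dump_embeds(text_file, ark, scp, model_path="bert-base-chinese"):
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModel.from_pretrained(model_path).eval()
        with WriteHelper("ark,scp:{},{}".format(ark, scp)) as writer, \
                open(text_file) as f:
            for line in f:
                utt, text = line.strip().split(maxsplit=1)
                inputs = tokenizer(text, return_tensors="pt")
                with torch.no_grad():
                    hidden = model(**inputs).last_hidden_state[0]
                writer(utt, hidden.numpy())
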
diff --git a/egs/aishell2/paraformerbert/local/prepare_data.sh b/egs/aishell2/paraformerbert/local/prepare_data.sh
index 801dbe5..77791f9 100755
--- a/egs/aishell2/paraformerbert/local/prepare_data.sh
+++ b/egs/aishell2/paraformerbert/local/prepare_data.sh
@@ -17,7 +17,6 @@
fi
corpus=$1
-#dict_dir=$2
tmp=$2
dir=$3
@@ -35,14 +34,14 @@
# validate utt-key list, IC0803W0380 is a bad utterance
awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
# wav.scp
awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
# text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
# copy prepared resources from tmp_dir to target dir
mkdir -p $dir
diff --git a/egs/aishell2/paraformerbert/run.sh b/egs/aishell2/paraformerbert/run.sh
index 239a7e3..26c8063 100755
--- a/egs/aishell2/paraformerbert/run.sh
+++ b/egs/aishell2/paraformerbert/run.sh
@@ -8,36 +8,32 @@
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
-train_cmd=tools/run.pl
+njob=1
+train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
-feats_dir="../DATA" #feature output dictionary, for large data
+feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
dataset_type=large
-scp=feats.scp
-type=kaldi_ark
-stage=0
-stop_stage=5
+stage=3
+stop_stage=4
skip_extract_embed=false
-bert_model_root="../../huggingface_models"
bert_model_name="bert-base-chinese"
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-tr_dir=
-dev_tst_dir=
+tr_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-2/iOS/data
+dev_tst_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-DEV-TEST-SET
# exp tag
tag="exp1"
@@ -55,7 +51,7 @@
test_sets="dev_ios test_ios"
asr_config=conf/train_asr_paraformerbert_conformer_20e_6d_1280_320.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer_noctc_1best.yaml
inference_asr_model=valid.acc.ave_10best.pb
@@ -75,86 +71,42 @@
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "stage 0: Data preparation"
# For training set
- local/prepare_data.sh ${tr_dir} data/local/train data/train || exit 1;
+ local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
# # For dev and test set
- for x in Android iOS Mic; do
- local/prepare_data.sh ${dev_tst_dir}/${x}/dev data/local/dev_${x,,} data/dev_${x,,} || exit 1;
- local/prepare_data.sh ${dev_tst_dir}/${x}/test data/local/test_${x,,} data/test_${x,,} || exit 1;
- done
+ for x in iOS; do
+ local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
+ local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
+ done
    # Normalize text to lower case
- for x in train dev_android dev_ios dev_mic test_android test_ios test_mic; do
- mv data/${x}/text data/${x}/text.org
- paste <(cut -f 1 data/${x}/text.org) <(cut -f 2 data/${x}/text.org | tr '[:lower:]' '[:upper:]') \
- > data/${x}/text
- tools/text2token.py -n 1 -s 1 data/${x}/text > data/${x}/text.org
- mv data/${x}/text.org data/${x}/text
+ for x in train dev_ios test_ios; do
+ mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+ paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
+ | tr 'A-Z' 'a-z' | tr -d " ") \
+ > ${feats_dir}/data/${x}/text
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
- data/train exp/make_fbank/train ${fbankdir}/train
- tools/fix_data_feat.sh ${fbankdir}/train
- for x in android ios mic; do
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- data/dev_${x} exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
- tools/fix_data_feat.sh ${fbankdir}/dev_${x}
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- data/test_${x} exp/make_fbank/test_${x} ${fbankdir}/test_${x}
- tools/fix_data_feat.sh ${fbankdir}/test_${x}
- done
-
- # compute global cmvn
- steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train exp/make_fbank/train
-
- # apply cmvn
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json exp/make_fbank/${train_set} ${feat_train_dir}
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json exp/make_fbank/${valid_set} ${feat_dev_dir}
- for x in android ios mic; do
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
- done
-
- cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
- tools/fix_data_feat.sh ${feat_train_dir}
- cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
- tools/fix_data_feat.sh ${feat_dev_dir}
- for x in android ios mic; do
- cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
- tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
- done
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
- mkdir -p data/${lang}_token_list/char/
-
+ mkdir -p ${feats_dir}/data/${lang}_token_list/char/
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- tools/text2token.py -s 1 -n 1 --space "" data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p asr_stats_fbank_zh_char/${train_set}
- mkdir -p asr_stats_fbank_zh_char/${valid_set}
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char asr_stats_fbank_zh_char/${train_set}
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char asr_stats_fbank_zh_char/${valid_set}
-fi
+ fi
# Training Stage
world_size=$gpu_num # run on one machine
@@ -163,9 +115,9 @@
if ! "${skip_extract_embed}"; then
echo "extract embeddings..."
local/extract_embeds.sh \
- --bert_model_root ${bert_model_root} \
--bert_model_name ${bert_model_name} \
- --raw_dataset_path ${feats_dir}
+ --raw_dataset_path ${feats_dir} \
+ --nj $nj
fi
mkdir -p ${exp_dir}/exp/${model_dir}
mkdir -p ${exp_dir}/exp/${model_dir}/log
@@ -180,22 +132,24 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train_paraformer.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
- --dataset_type $dataset_type \
- --token_type $token_type \
+ --token_type char \
--token_list $token_list \
- --train_data_file $feats_dir/$dumpdir/${train_set}/data_bert.list \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/data_bert.list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --embed_path ${feats_dir}/data \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --dataset_type $dataset_type \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --allow_variable_data_keys true \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -218,7 +172,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -239,6 +193,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
@@ -259,5 +214,4 @@
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
done
-fi
-
+fi
\ No newline at end of file
diff --git a/egs/aishell2/transformer/conf/train_asr_transformer.yaml b/egs/aishell2/transformer/conf/train_asr_transformer.yaml
index 3e2172d..2d16c18 100644
--- a/egs/aishell2/transformer/conf/train_asr_transformer.yaml
+++ b/egs/aishell2/transformer/conf/train_asr_transformer.yaml
@@ -23,6 +23,17 @@
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
@@ -33,7 +44,7 @@
accum_grad: 2
grad_clip: 5
patience: none
-max_epoch: 50
+max_epoch: 150
val_scheduler_criterion:
- valid
- acc
@@ -66,10 +77,8 @@
- 40
num_time_mask: 2
-log_interval: 50
-normalize: None
-
dataset_conf:
+ data_types: sound,text
shuffle: True
shuffle_conf:
shuffle_size: 2048
@@ -77,4 +86,7 @@
batch_conf:
batch_type: token
batch_size: 25000
- num_workers: 8
\ No newline at end of file
+ num_workers: 8
+
+log_interval: 50
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell2/transformer/local/prepare_data.sh b/egs/aishell2/transformer/local/prepare_data.sh
index ce6ee19..77791f9 100755
--- a/egs/aishell2/transformer/local/prepare_data.sh
+++ b/egs/aishell2/transformer/local/prepare_data.sh
@@ -34,14 +34,14 @@
# validate utt-key list, IC0803W0380 is a bad utterance
awk '{print $1}' $corpus/wav.scp | grep -v 'IC0803W0380' > $tmp/wav_utt.list
awk '{print $1}' $corpus/trans.txt > $tmp/trans_utt.list
-tools/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
+utils/filter_scp.pl -f 1 $tmp/wav_utt.list $tmp/trans_utt.list > $tmp/utt.list
# wav.scp
awk -F'\t' -v path_prefix=$corpus '{printf("%s\t%s/%s\n",$1,path_prefix,$2)}' $corpus/wav.scp > $tmp/tmp_wav.scp
-tools/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
+utils/filter_scp.pl -f 1 $tmp/utt.list $tmp/tmp_wav.scp | sort -k 1 | uniq > $tmp/wav.scp
# text
-tools/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
+utils/filter_scp.pl -f 1 $tmp/utt.list $corpus/trans.txt | sort -k 1 | uniq > $tmp/text
# copy prepared resources from tmp_dir to target dir
mkdir -p $dir
diff --git a/egs/aishell2/transformer/run.sh b/egs/aishell2/transformer/run.sh
index 6f2dd4d..d8b16aa 100755
--- a/egs/aishell2/transformer/run.sh
+++ b/egs/aishell2/transformer/run.sh
@@ -9,31 +9,28 @@
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
-train_cmd=tools/run.pl
+train_cmd=utils/run.pl
infer_cmd=utils/run.pl
# general configuration
feats_dir="../DATA" #feature output directory
exp_dir="."
lang=zh
-dumpdir=dump/fbank
-feats_type=fbank
token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
dataset_type=large
-scp=feats.scp
-type=kaldi_ark
-stage=0
+stage=3
stop_stage=4
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-tr_dir=
-dev_tst_dir=
+tr_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-2/iOS/data
+dev_tst_dir=/nfs/wangjiaming.wjm/asr_data/aishell2/AISHELL-DEV-TEST-SET
# exp tag
tag="exp1"
@@ -51,13 +48,13 @@
test_sets="dev_ios test_ios"
asr_config=conf/train_asr_transformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb
# you can set gpu num for decoding here
-gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, e.g., gpuid_list=2,3, the same as training stage by default
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
if ${gpu_inference}; then
@@ -73,61 +70,24 @@
# For training set
local/prepare_data.sh ${tr_dir} ${feats_dir}/data/local/train ${feats_dir}/data/train || exit 1;
# # For dev and test set
- for x in Android iOS Mic; do
+ for x in iOS; do
local/prepare_data.sh ${dev_tst_dir}/${x}/dev ${feats_dir}/data/local/dev_${x,,} ${feats_dir}/data/dev_${x,,} || exit 1;
local/prepare_data.sh ${dev_tst_dir}/${x}/test ${feats_dir}/data/local/test_${x,,} ${feats_dir}/data/test_${x,,} || exit 1;
- done
+ done
    # Normalize text to lower case
for x in train dev_ios test_ios; do
mv ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
paste -d " " <(cut -f 1 ${feats_dir}/data/${x}/text.org) <(cut -f 2- ${feats_dir}/data/${x}/text.org \
| tr 'A-Z' 'a-z' | tr -d " ") \
> ${feats_dir}/data/${x}/text
- tools/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/${train_set}; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/${valid_set}; mkdir -p ${feat_dev_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
- tools/fix_data_feat.sh ${fbankdir}/train
- for x in android ios mic; do
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/dev_${x} ${exp_dir}/exp/make_fbank/dev_${x} ${fbankdir}/dev_${x}
- tools/fix_data_feat.sh ${fbankdir}/dev_${x}
- steps/compute_fbank.sh --cmd "$train_cmd" --nj $nj \
- ${feats_dir}/data/test_${x} ${exp_dir}/exp/make_fbank/test_${x} ${fbankdir}/test_${x}
- tools/fix_data_feat.sh ${fbankdir}/test_${x}
- done
-
- # compute global cmvn
- steps/compute_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
-
- # apply cmvn
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${train_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${train_set} ${feat_train_dir}
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/${valid_set} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${valid_set} ${feat_dev_dir}
- for x in android ios mic; do
- steps/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/test_${x} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test_${x} ${feats_dir}/${dumpdir}/test_${x}
- done
-
- cp ${fbankdir}/${train_set}/text ${fbankdir}/${train_set}/speech_shape ${fbankdir}/${train_set}/text_shape ${feat_train_dir}
- tools/fix_data_feat.sh ${feat_train_dir}
- cp ${fbankdir}/${valid_set}/text ${fbankdir}/${valid_set}/speech_shape ${fbankdir}/${valid_set}/text_shape ${feat_dev_dir}
- tools/fix_data_feat.sh ${feat_dev_dir}
- for x in android ios mic; do
- cp ${fbankdir}/test_${x}/text ${fbankdir}/test_${x}/speech_shape ${fbankdir}/test_${x}/text_shape ${feats_dir}/${dumpdir}/test_${x}
- tools/fix_data_feat.sh ${feats_dir}/${dumpdir}/test_${x}
- done
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
fi
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
@@ -135,23 +95,15 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "stage 2: Dictionary Preparation"
mkdir -p ${feats_dir}/data/${lang}_token_list/char/
-
+
echo "make a dictionary"
echo "<blank>" > ${token_list}
echo "<s>" >> ${token_list}
echo "</s>" >> ${token_list}
- tools/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/${train_set}/text | cut -f 2- -d" " | tr " " "\n" \
| sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
- num_token=$(cat ${token_list} | wc -l)
echo "<unk>" >> ${token_list}
- vocab_size=$(cat ${token_list} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${train_set}
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}
-fi
+ fi
# Training Stage
world_size=$gpu_num # run on one machine
@@ -170,21 +122,23 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
- --dataset_type $dataset_type \
--token_type char \
--token_list $token_list \
- --train_data_file $feats_dir/$dumpdir/${train_set}/data.list \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/data.list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --dataset_type $dataset_type \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
- --multiprocessing_distributed true \
--dist_init_method $init_method \
--dist_world_size $world_size \
--dist_rank $rank \
@@ -207,7 +161,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -228,6 +182,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
diff --git a/egs/aishell2/transformer/utils/cmvn_converter.py b/egs/aishell2/transformer/utils/cmvn_converter.py
new file mode 100644
index 0000000..d405d12
--- /dev/null
+++ b/egs/aishell2/transformer/utils/cmvn_converter.py
@@ -0,0 +1,51 @@
+import argparse
+import json
+import numpy as np
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="cmvn converter",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--cmvn_json",
+ default=False,
+ required=True,
+ type=str,
+ help="cmvn json file",
+ )
+ parser.add_argument(
+ "--am_mvn",
+ default=False,
+ required=True,
+ type=str,
+ help="am mvn file",
+ )
+ return parser
+
+def main():
+ parser = get_parser()
+ args = parser.parse_args()
+
+ with open(args.cmvn_json, "r") as fin:
+ cmvn_dict = json.load(fin)
+
+ mean_stats = np.array(cmvn_dict["mean_stats"])
+ var_stats = np.array(cmvn_dict["var_stats"])
+ total_frame = np.array(cmvn_dict["total_frames"])
+
+ mean = -1.0 * mean_stats / total_frame
+ var = 1.0 / np.sqrt(var_stats / total_frame - mean * mean)
+ dims = mean.shape[0]
+ with open(args.am_mvn, 'w') as fout:
+        fout.write("<Nnet>\n")
+        fout.write("<Splice> {} {}\n".format(dims, dims))
+        fout.write("[ 0 ]\n")
+        fout.write("<AddShift> {} {}\n".format(dims, dims))
+ mean_str = str(list(mean)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
+ fout.write("<LearnRateCoef> 0 " + mean_str + '\n')
+ fout.write("<Rescale> " + str(dims) + " " + str(dims) + '\n')
+ var_str = str(list(var)).replace(',', '').replace('[', '[ ').replace(']', ' ]')
+ fout.write("<LearnRateCoef> 0 " + var_str + '\n')
+ fout.write("</Nnet>" + '\n')
+
+if __name__ == '__main__':
+ main()
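
Note: the converter above encodes global CMVN as a kaldi-nnet pair: <AddShift> carries the negated mean and <Rescale> carries 1/stddev, so applying both reproduces classic mean/variance normalization. A small self-check of that arithmetic:

    import numpy as np

    def check_cmvn(feats):
        # feats: (frames, dim) feature matrix
        n = feats.shape[0]
        mean_stats = feats.sum(axis=0)
        var_stats = np.square(feats).sum(axis=0)
        shift = -mean_stats / n                                        # <AddShift>
        scale = 1.0 / np.sqrt(var_stats / n - (mean_stats / n) ** 2)  # <Rescale>
        normalized = (feats + shift) * scale
        assert np.allclose(normalized.mean(axis=0), 0.0, atol=1e-6)
        assert np.allclose(normalized.std(axis=0), 1.0, atol=1e-4)

    check_cmvn(np.random.RandomState(0).randn(1000, 80) * 3.0 + 5.0)
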
diff --git a/egs/aishell2/transformer/utils/combine_cmvn_file.py b/egs/aishell2/transformer/utils/combine_cmvn_file.py
index b2974a4..c525973 100755
--- a/egs/aishell2/transformer/utils/combine_cmvn_file.py
+++ b/egs/aishell2/transformer/utils/combine_cmvn_file.py
@@ -1,6 +1,9 @@
import argparse
import json
+import os
+
import numpy as np
+
def get_parser():
parser = argparse.ArgumentParser(
@@ -8,15 +11,13 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--dims",
- "-d",
+ "--dim",
default=80,
type=int,
- help="feature dims",
+ help="feature dim",
)
parser.add_argument(
- "--cmvn-dir",
- "-c",
+ "--cmvn_dir",
default=False,
required=True,
type=str,
@@ -25,15 +26,13 @@
parser.add_argument(
"--nj",
- "-n",
default=1,
required=True,
type=int,
- help="num of cmvn file",
+ help="num of cmvn files",
)
parser.add_argument(
- "--output-dir",
- "-o",
+ "--output_dir",
default=False,
required=True,
type=str,
@@ -46,14 +45,14 @@
parser = get_parser()
args = parser.parse_args()
- total_means = np.zeros(args.dims)
- total_vars = np.zeros(args.dims)
+ total_means = np.zeros(args.dim)
+ total_vars = np.zeros(args.dim)
total_frames = 0
- cmvn_file = args.output_dir + "/cmvn.json"
+ cmvn_file = os.path.join(args.output_dir, "cmvn.json")
- for i in range(1, args.nj+1):
- with open(args.cmvn_dir + "/cmvn." + str(i) + ".json", "r") as fin:
+ for i in range(1, args.nj + 1):
+ with open(os.path.join(args.cmvn_dir, "cmvn.{}.json".format(str(i)))) as fin:
cmvn_stats = json.load(fin)
total_means += np.array(cmvn_stats["mean_stats"])
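
Note: the merge above works because sums and sums of squares are additive across the nj shards; the global mean and variance then fall out of the pooled totals. The same identity as code:

    import numpy as np

    def merge_cmvn(shards):
        # shards: dicts with "mean_stats", "var_stats", "total_frames"
        mean_stats = sum(np.asarray(s["mean_stats"]) for s in shards)
        var_stats = sum(np.asarray(s["var_stats"]) for s in shards)
        frames = sum(s["total_frames"] for s in shards)
        mean = mean_stats / frames
        var = var_stats / frames - mean ** 2  # E[x^2] - E[x]^2
        return mean, var
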
diff --git a/egs/aishell2/transformer/utils/compute_cmvn.py b/egs/aishell2/transformer/utils/compute_cmvn.py
index 2b96e26..949cc08 100755
--- a/egs/aishell2/transformer/utils/compute_cmvn.py
+++ b/egs/aishell2/transformer/utils/compute_cmvn.py
@@ -1,8 +1,10 @@
-from kaldiio import ReadHelper
-
import argparse
-import numpy as np
import json
+import os
+
+import numpy as np
+import torchaudio
+import torchaudio.compliance.kaldi as kaldi
def get_parser():
@@ -11,55 +13,83 @@
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--dims",
- "-d",
+ "--dim",
default=80,
type=int,
- help="feature dims",
+ help="feature dimension",
)
parser.add_argument(
- "--ark-file",
- "-a",
+ "--wav_path",
default=False,
required=True,
type=str,
- help="fbank ark file",
+ help="the path of wav scps",
)
parser.add_argument(
- "--ark-index",
- "-i",
+ "--idx",
default=1,
required=True,
type=int,
- help="ark index",
- )
- parser.add_argument(
- "--output-dir",
- "-o",
- default=False,
- required=True,
- type=str,
- help="output dir",
+ help="index",
)
return parser
+
+
+def compute_fbank(wav_file,
+ num_mel_bins=80,
+ frame_length=25,
+ frame_shift=10,
+ dither=0.0,
+ resample_rate=16000,
+ speed=1.0,
+ window_type="hamming"):
+ waveform, sample_rate = torchaudio.load(wav_file)
+ if resample_rate != sample_rate:
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate,
+ new_freq=resample_rate)(waveform)
+ if speed != 1.0:
+ waveform, _ = torchaudio.sox_effects.apply_effects_tensor(
+ waveform, resample_rate,
+ [['speed', str(speed)], ['rate', str(resample_rate)]]
+ )
+
+ waveform = waveform * (1 << 15)
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=num_mel_bins,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ dither=dither,
+ energy_floor=0.0,
+ window_type=window_type,
+ sample_frequency=resample_rate)
+
+ return mat.numpy()
def main():
parser = get_parser()
args = parser.parse_args()
- ark_file = args.ark_file + "/feats." + str(args.ark_index) + ".ark"
- cmvn_file = args.output_dir + "/cmvn." + str(args.ark_index) + ".json"
+ wav_scp_file = os.path.join(args.wav_path, "wav.{}.scp".format(args.idx))
+ cmvn_file = os.path.join(args.wav_path, "cmvn.{}.json".format(args.idx))
- mean_stats = np.zeros(args.dims)
- var_stats = np.zeros(args.dims)
+ mean_stats = np.zeros(args.dim)
+ var_stats = np.zeros(args.dim)
total_frames = 0
- with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
- for key, mat in ark_reader:
- mean_stats += np.sum(mat, axis=0)
- var_stats += np.sum(np.square(mat), axis=0)
- total_frames += mat.shape[0]
+ # with ReadHelper('ark:{}'.format(ark_file)) as ark_reader:
+ # for key, mat in ark_reader:
+ # mean_stats += np.sum(mat, axis=0)
+ # var_stats += np.sum(np.square(mat), axis=0)
+ # total_frames += mat.shape[0]
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ _, wav_file = line.strip().split()
+ fbank = compute_fbank(wav_file, num_mel_bins=args.dim)
+ mean_stats += np.sum(fbank, axis=0)
+ var_stats += np.sum(np.square(fbank), axis=0)
+ total_frames += fbank.shape[0]
cmvn_info = {
'mean_stats': list(mean_stats.tolist()),
diff --git a/egs/aishell2/transformer/utils/compute_cmvn.sh b/egs/aishell2/transformer/utils/compute_cmvn.sh
index 12173ee..7663df9 100755
--- a/egs/aishell2/transformer/utils/compute_cmvn.sh
+++ b/egs/aishell2/transformer/utils/compute_cmvn.sh
@@ -11,15 +11,24 @@
. utils/parse_options.sh || exit 1;
fbankdir=$1
-logdir=$2
-output_dir=${fbankdir}/cmvn; mkdir -p ${output_dir}
-mkdir -p ${logdir}
+split_dir=${fbankdir}/cmvn/split_${nj};
+mkdir -p $split_dir
+split_scps=""
+for n in $(seq $nj); do
+ split_scps="$split_scps $split_dir/wav.$n.scp"
+done
+utils/split_scp.pl ${fbankdir}/wav.scp $split_scps || exit 1;
+logdir=${fbankdir}/cmvn/log
$cmd JOB=1:$nj $logdir/cmvn.JOB.log \
- python utils/compute_cmvn.py -d ${feats_dim} -a $fbankdir/ark -i JOB -o ${output_dir} \
- || exit 1;
+ python utils/compute_cmvn.py \
+ --dim ${feats_dim} \
+ --wav_path $split_dir \
+ --idx JOB
-python utils/combine_cmvn_file.py -d ${feats_dim} -c ${output_dir} -n $nj -o $fbankdir
+python utils/combine_cmvn_file.py --dim ${feats_dim} --cmvn_dir $split_dir --nj $nj --output_dir ${fbankdir}/cmvn
+
+python utils/cmvn_converter.py --cmvn_json ${fbankdir}/cmvn/cmvn.json --am_mvn ${fbankdir}/cmvn/cmvn.mvn
echo "$0: Succeeded compute global cmvn"
diff --git a/egs/aishell2/transformer/utils/compute_fbank.py b/egs/aishell2/transformer/utils/compute_fbank.py
index d03b5a8..9c3904f 100755
--- a/egs/aishell2/transformer/utils/compute_fbank.py
+++ b/egs/aishell2/transformer/utils/compute_fbank.py
@@ -14,7 +14,8 @@
frame_shift=10,
dither=0.0,
resample_rate=16000,
- speed=1.0):
+ speed=1.0,
+ window_type="hamming"):
waveform, sample_rate = torchaudio.load(wav_file)
if resample_rate != sample_rate:
@@ -33,7 +34,7 @@
frame_shift=frame_shift,
dither=dither,
energy_floor=0.0,
- window_type='hamming',
+ window_type=window_type,
sample_frequency=resample_rate)
return mat.numpy()
@@ -68,6 +69,13 @@
help="feature dims",
)
parser.add_argument(
+ "--max-lengths",
+ "-m",
+ default=1500,
+ type=int,
+ help="max frame numbers",
+ )
+ parser.add_argument(
"--sample-frequency",
"-s",
default=16000,
@@ -96,6 +104,13 @@
required=True,
type=str,
help="output dir",
+ )
+ parser.add_argument(
+ "--window-type",
+ default="hamming",
+ required=False,
+ type=str,
+ help="window type"
)
return parser
@@ -131,10 +146,13 @@
fbank = compute_fbank(wav_file,
num_mel_bins=args.dims,
resample_rate=args.sample_frequency,
- speed=float(speed)
+ speed=float(speed),
+ window_type=args.window_type
)
feats_dims = fbank.shape[1]
feats_lens = fbank.shape[0]
+ if feats_lens >= args.max_lengths:
+ continue
txt_lens = len(txt)
if speed == "1.0":
wav_id_sp = wav_id
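
Note: the new --max-lengths guard skips utterances whose fbank exceeds 1500 frames, which at a 25 ms window and 10 ms shift corresponds to roughly 15 seconds of audio:

    def num_frames(duration_ms, frame_length=25, frame_shift=10):
        # kaldi-style frame count with snip-edges windowing
        if duration_ms < frame_length:
            return 0
        return 1 + (duration_ms - frame_length) // frame_shift

    assert num_frames(15_000) == 1498  # ~15 s stays just under 1500 frames
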
diff --git a/egs/aishell2/transformer/utils/compute_fbank.sh b/egs/aishell2/transformer/utils/compute_fbank.sh
index 92a4fe6..8704b31 100755
--- a/egs/aishell2/transformer/utils/compute_fbank.sh
+++ b/egs/aishell2/transformer/utils/compute_fbank.sh
@@ -9,6 +9,8 @@
feats_dim=80
sample_frequency=16000
speed_perturb="1.0"
+window_type="hamming"
+max_lengths=1500
echo "$0 $@"
@@ -29,7 +31,8 @@
$cmd JOB=1:$nj $logdir/make_fbank.JOB.log \
python utils/compute_fbank.py -w $data/split${nj}/JOB/wav.scp -t $data/split${nj}/JOB/text \
- -d $feats_dim -s $sample_frequency -p ${speed_perturb} -a JOB -o ${fbankdir} \
+ -d $feats_dim -s $sample_frequency -m ${max_lengths} -p ${speed_perturb} -a JOB -o ${fbankdir} \
+ --window-type ${window_type} \
|| exit 1;
for n in $(seq $nj); do
diff --git a/egs/aishell2/transformer/utils/compute_wer.py b/egs/aishell2/transformer/utils/compute_wer.py
index 349a3f6..26a9f49 100755
--- a/egs/aishell2/transformer/utils/compute_wer.py
+++ b/egs/aishell2/transformer/utils/compute_wer.py
@@ -45,8 +45,8 @@
if out_item['wrong'] > 0:
rst['wrong_sentences'] += 1
cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + '\n')
- cer_detail_writer.write("ref:" + '\t' + "".join(ref_dict[hyp_key]) + '\n')
- cer_detail_writer.write("hyp:" + '\t' + "".join(hyp_dict[hyp_key]) + '\n')
+ cer_detail_writer.write("ref:" + '\t' + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + '\n')
+ cer_detail_writer.write("hyp:" + '\t' + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + '\n')
if rst['Wrd'] > 0:
rst['Err'] = round(rst['wrong_words'] * 100 / rst['Wrd'], 2)
diff --git a/egs/aishell2/transformer/utils/download_model.py b/egs/aishell2/transformer/utils/download_model.py
new file mode 100755
index 0000000..70ea179
--- /dev/null
+++ b/egs/aishell2/transformer/utils/download_model.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+import argparse
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(
+ description="download model configs",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument("--model_name",
+ type=str,
+ default="damo/speech_data2vec_pretrain-zh-cn-aishell2-16k-pytorch",
+ help="model name in ModelScope")
+ args = parser.parse_args()
+
+ inference_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model=args.model_name)
diff --git a/egs/aishell2/transformer/utils/fix_data.sh b/egs/aishell2/transformer/utils/fix_data.sh
index 32cdde5..b1a2bb8 100755
--- a/egs/aishell2/transformer/utils/fix_data.sh
+++ b/egs/aishell2/transformer/utils/fix_data.sh
@@ -28,8 +28,8 @@
mv ${data_dir}/wav.scp ${data_dir}/wav.scp.bak
mv ${data_dir}/text ${data_dir}/text.bak
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak > ${data_dir}/wav.scp
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak | sort -k1,1 -u > ${data_dir}/wav.scp
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
rm ${data_dir}/wav.scp.bak
rm ${data_dir}/text.bak
diff --git a/egs/aishell2/transformer/utils/fix_data_feat.sh b/egs/aishell2/transformer/utils/fix_data_feat.sh
index 2c92d7f..84eea36 100755
--- a/egs/aishell2/transformer/utils/fix_data_feat.sh
+++ b/egs/aishell2/transformer/utils/fix_data_feat.sh
@@ -40,10 +40,10 @@
mv ${data_dir}/speech_shape ${data_dir}/speech_shape.bak
mv ${data_dir}/text_shape ${data_dir}/text_shape.bak
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak > ${data_dir}/feats.scp
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak > ${data_dir}/text
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak > ${data_dir}/speech_shape
-utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak > ${data_dir}/text_shape
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak | sort -k1,1 -u > ${data_dir}/feats.scp
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak | sort -k1,1 -u > ${data_dir}/speech_shape
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak | sort -k1,1 -u > ${data_dir}/text_shape
rm ${data_dir}/feats.scp.bak
rm ${data_dir}/text.bak
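Both fix_data scripts now pipe the filtered lists through sort -k1,1 -u, which sorts records by the utterance id in field 1 and keeps a single record per id, so downstream tools that expect sorted, duplicate-free scp files do not fail. A rough Python equivalent, for illustration only:

    def sort_unique_by_key(lines):
        # Keep one record per utterance id (field 1) and return the records
        # sorted by id -- roughly what `sort -k1,1 -u` produces.
        kept = {}
        for line in lines:
            kept.setdefault(line.split(maxsplit=1)[0], line)
        return [kept[key] for key in sorted(kept)]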
diff --git a/egs/librispeech/conformer/conf/train_asr_conformer.yaml b/egs/librispeech/conformer/conf/train_asr_conformer.yaml
index 68b127f..2bd3db4 100644
--- a/egs/librispeech/conformer/conf/train_asr_conformer.yaml
+++ b/egs/librispeech/conformer/conf/train_asr_conformer.yaml
@@ -27,13 +27,25 @@
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1
length_normalized_loss: false
accum_grad: 2
-max_epoch: 50
+max_epoch: 150
patience: none
init: none
best_model_criterion:
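The new wav_frontend block moves fbank extraction into the model (80 mel bins, 25 ms window, 10 ms shift); lfr_m/lfr_n control low-frame-rate stacking and are set to 1/1 here, i.e. disabled. A sketch of the usual LFR transform, assuming the common stack-and-subsample definition (padding details inside wav_frontend may differ):

    import numpy as np

    def apply_lfr(feats, lfr_m=1, lfr_n=1):
        # Stack lfr_m neighbouring frames, keeping every lfr_n-th stack.
        # With lfr_m = lfr_n = 1 (as configured above) this is a no-op.
        frames = []
        for i in range(0, len(feats), lfr_n):
            chunk = feats[i:i + lfr_m]
            if len(chunk) < lfr_m:  # pad the tail by repeating the last frame
                pad = np.tile(feats[-1], (lfr_m - len(chunk), 1))
                chunk = np.concatenate([chunk, pad])
            frames.append(chunk.reshape(-1))
        return np.stack(frames)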
diff --git a/egs/librispeech/conformer/local/data_prep_librispeech.sh b/egs/librispeech/conformer/local/data_prep.sh
similarity index 100%
rename from egs/librispeech/conformer/local/data_prep_librispeech.sh
rename to egs/librispeech/conformer/local/data_prep.sh
diff --git a/egs/librispeech/conformer/local/download_and_untar.sh b/egs/librispeech/conformer/local/download_and_untar.sh
new file mode 100755
index 0000000..fe322e4
--- /dev/null
+++ b/egs/librispeech/conformer/local/download_and_untar.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
+ echo " train-clean-100, train-clean-360, train-other-500."
+ exit 1
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1
+fi
+
+part_ok=false
+list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1
+fi
+
+if [ -f $data/LibriSpeech/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0
+fi
+
+
+# Sizes of the archive files in bytes. These are from some older releases.
+sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
+# sizes_new lists the archive file sizes of the final release. Some of these
+# sizes are for parts we probably won't download.
+sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
+
+if [ -f $data/$part.tar.gz ]; then
+ size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tar.gz
+ else
+ echo "$data/$part.tar.gz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tar.gz ]; then
+ if ! which wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1
+ fi
+ full_url=$url/$part.tar.gz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ if ! wget -P $data --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1
+ fi
+fi
+
+if ! tar -C $data -xvzf $data/$part.tar.gz; then
+ echo "$0: error un-tarring archive $data/$part.tar.gz"
+ exit 1
+fi
+
+touch $data/LibriSpeech/$part/.complete
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
+ rm $data/$part.tar.gz
+fi
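Note the script trusts a previously downloaded tarball only if its byte size matches one of the known release sizes. A compact sketch of that check (function name illustrative):

    import os

    def archive_size_ok(tar_path, known_sizes):
        # Accept an existing archive only if its size matches one of the
        # known LibriSpeech release sizes; otherwise it must be re-fetched.
        return os.path.isfile(tar_path) and os.path.getsize(tar_path) in known_sizes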
diff --git a/egs/librispeech/conformer/local/spm_encode.py b/egs/librispeech/conformer/local/spm_encode.py
new file mode 100755
index 0000000..9e1c15f
--- /dev/null
+++ b/egs/librispeech/conformer/local/spm_encode.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+
+
+import argparse
+import contextlib
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", required=True,
+ help="sentencepiece model to use for encoding")
+ parser.add_argument("--inputs", nargs="+", default=['-'],
+ help="input files to filter/encode")
+ parser.add_argument("--outputs", nargs="+", default=['-'],
+ help="path to save encoded outputs")
+ parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
+ parser.add_argument("--min-len", type=int, metavar="N",
+ help="filter sentence pairs with fewer than N tokens")
+ parser.add_argument("--max-len", type=int, metavar="N",
+ help="filter sentence pairs with more than N tokens")
+ args = parser.parse_args()
+
+ assert len(args.inputs) == len(args.outputs), \
+ "number of input and output paths should match"
+
+ sp = spm.SentencePieceProcessor()
+ sp.Load(args.model)
+
+ if args.output_format == "piece":
+ def encode(l):
+ return sp.EncodeAsPieces(l)
+ elif args.output_format == "id":
+ def encode(l):
+ return list(map(str, sp.EncodeAsIds(l)))
+ else:
+ raise NotImplementedError
+
+ if args.min_len is not None or args.max_len is not None:
+ def valid(line):
+ return (
+ (args.min_len is None or len(line) >= args.min_len) and
+ (args.max_len is None or len(line) <= args.max_len)
+ )
+ else:
+ def valid(line):
+ return True
+
+ with contextlib.ExitStack() as stack:
+ inputs = [
+ stack.enter_context(open(input, "r", encoding="utf-8"))
+ if input != "-" else sys.stdin
+ for input in args.inputs
+ ]
+ outputs = [
+ stack.enter_context(open(output, "w", encoding="utf-8"))
+ if output != "-" else sys.stdout
+ for output in args.outputs
+ ]
+
+ stats = {
+ "num_empty": 0,
+ "num_filtered": 0,
+ }
+
+ def encode_line(line):
+ line = line.strip()
+ if len(line) > 0:
+ line = encode(line)
+ if valid(line):
+ return line
+ else:
+ stats["num_filtered"] += 1
+ else:
+ stats["num_empty"] += 1
+ return None
+
+ for i, lines in enumerate(zip(*inputs), start=1):
+ enc_lines = list(map(encode_line, lines))
+ if not any(enc_line is None for enc_line in enc_lines):
+ for enc_line, output_h in zip(enc_lines, outputs):
+ print(" ".join(enc_line), file=output_h)
+ if i % 10000 == 0:
+ print("processed {} lines".format(i), file=sys.stderr)
+
+ print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
+ print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech/conformer/local/spm_train.py b/egs/librispeech/conformer/local/spm_train.py
new file mode 100755
index 0000000..134a0b1
--- /dev/null
+++ b/egs/librispeech/conformer/local/spm_train.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+import sys
+
+import sentencepiece as spm
+
+if __name__ == "__main__":
+ spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
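spm_train.py simply forwards its command line to SentencePiece, and spm_encode.py applies the trained model, which lets stage 2 of run.sh call them as local scripts instead of relying on installed spm_train/spm_encode binaries. A usage sketch (paths are illustrative; run.sh supplies the real --input/--model_prefix values):

    import sentencepiece as spm

    # Train a 5000-piece unigram model on the dumped transcripts.
    spm.SentencePieceTrainer.Train(
        "--input=lang_char/input.txt --vocab_size=5000 "
        "--model_type=unigram --model_prefix=lang_char/train_unigram5000 "
        "--input_sentence_size=100000000"
    )

    sp = spm.SentencePieceProcessor()
    sp.Load("lang_char/train_unigram5000.model")
    print(sp.EncodeAsPieces("SPEECH RECOGNITION"))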
diff --git a/egs/librispeech/conformer/run.sh b/egs/librispeech/conformer/run.sh
index 93d1b46..cc4029c 100755
--- a/egs/librispeech/conformer/run.sh
+++ b/egs/librispeech/conformer/run.sh
@@ -16,30 +16,27 @@
feats_dir="../DATA" #feature output dictionary
exp_dir="."
lang=en
-dumpdir=dump/fbank
-feats_type=fbank
token_type=bpe
-dataset_type=large
-scp=feats.scp
-type=kaldi_ark
-stage=3
-stop_stage=4
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=0
+stop_stage=2
# feature configuration
feats_dim=80
-sample_frequency=16000
-nj=100
-speed_perturb="0.9,1.0,1.1"
+nj=64
# data
-data_librispeech=
+raw_data=
+data_url=www.openslr.org/resources/12
# bpe model
nbpe=5000
bpemode=unigram
# exp tag
-tag=""
+tag="exp1"
. utils/parse_options.sh || exit 1;
@@ -54,8 +51,7 @@
test_sets="test_clean test_other dev_clean dev_other"
asr_config=conf/train_asr_conformer.yaml
-#asr_config=conf/train_asr_conformer_uttnorm.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
inference_config=conf/decode_asr_transformer.yaml
#inference_config=conf/decode_asr_transformer_beam60_ctc0.3.yaml
@@ -73,96 +69,52 @@
_ngpu=0
fi
-if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
- echo "stage 0: Data preparation"
- # Data preparation
- for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
- local/data_prep_librispeech.sh ${data_librispeech}/LibriSpeech/${x} ${feats_dir}/data/${x//-/_}
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
+ local/download_and_untar.sh ${raw_data} ${data_url} ${part}
done
fi
-feat_train_dir=${feats_dir}/${dumpdir}/$train_set; mkdir -p ${feat_train_dir}
-feat_dev_clean_dir=${feats_dir}/${dumpdir}/dev_clean; mkdir -p ${feat_dev_clean_dir}
-feat_dev_other_dir=${feats_dir}/${dumpdir}/dev_other; mkdir -p ${feat_dev_other_dir}
-feat_test_clean_dir=${feats_dir}/${dumpdir}/test_clean; mkdir -p ${feat_test_clean_dir}
-feat_test_other_dir=${feats_dir}/${dumpdir}/test_other; mkdir -p ${feat_test_other_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/$valid_set; mkdir -p ${feat_dev_dir}
-if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
- echo "stage 1: Feature Generation"
- # compute fbank features
- fbankdir=${feats_dir}/fbank
- for x in dev_clean dev_other test_clean test_other; do
- utils/compute_fbank.sh --cmd "$train_cmd" --nj 1 --max_lengths 3000 --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
- ${feats_dir}/data/${x} ${exp_dir}/exp/make_fbank/${x} ${fbankdir}/${x}
- utils/fix_data_feat.sh ${fbankdir}/${x}
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ for x in dev-clean dev-other test-clean test-other train-clean-100 train-clean-360 train-other-500; do
+ local/data_prep.sh ${raw_data}/LibriSpeech/${x} ${feats_dir}/data/${x//-/_}
done
-
- mkdir ${feats_dir}/data/$train_set
+ mkdir $feats_dir/data/$valid_set
+ dev_sets="dev_clean dev_other"
+ for file in wav.scp text; do
+ ( for f in $dev_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$valid_set/$file || exit 1;
+ done
+ mkdir $feats_dir/data/$train_set
train_sets="train_clean_100 train_clean_360 train_other_500"
for file in wav.scp text; do
( for f in $train_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$train_set/$file || exit 1;
done
- utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --max_lengths 3000 --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
- ${feats_dir}/data/$train_set ${exp_dir}/exp/make_fbank/$train_set ${fbankdir}/$train_set
- utils/fix_data_feat.sh ${fbankdir}/$train_set
-
- # compute global cmvn
- utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
- ${fbankdir}/$train_set ${exp_dir}/exp/make_fbank/$train_set
-
- # apply cmvn
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
- ${fbankdir}/$train_set ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/$train_set ${feat_train_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1 \
- ${fbankdir}/dev_clean ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/dev_clean ${feat_dev_clean_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1\
- ${fbankdir}/dev_other ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/dev_other ${feat_dev_other_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1 \
- ${fbankdir}/test_clean ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/test_clean ${feat_test_clean_dir}
- utils/apply_cmvn.sh --cmd "$train_cmd" --nj 1 \
- ${fbankdir}/test_other ${fbankdir}/$train_set/cmvn.json ${exp_dir}/exp/make_fbank/test_other ${feat_test_other_dir}
-
- cp ${fbankdir}/$train_set/text ${fbankdir}/$train_set/speech_shape ${fbankdir}/$train_set/text_shape ${feat_train_dir}
- cp ${fbankdir}/dev_clean/text ${fbankdir}/dev_clean/speech_shape ${fbankdir}/dev_clean/text_shape ${feat_dev_clean_dir}
- cp ${fbankdir}/dev_other/text ${fbankdir}/dev_other/speech_shape ${fbankdir}/dev_other/text_shape ${feat_dev_other_dir}
- cp ${fbankdir}/test_clean/text ${fbankdir}/test_clean/speech_shape ${fbankdir}/test_clean/text_shape ${feat_test_clean_dir}
- cp ${fbankdir}/test_other/text ${fbankdir}/test_other/speech_shape ${fbankdir}/test_other/text_shape ${feat_test_other_dir}
-
- dev_sets="dev_clean dev_other"
- for file in feats.scp text speech_shape text_shape; do
- ( for f in $dev_sets; do cat $feats_dir/${dumpdir}/$f/$file; done ) | sort -k1 > $feat_dev_dir/$file || exit 1;
- done
-
- #generate ark list
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/${train_set} ${feat_train_dir}
- utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/${valid_set} ${feat_dev_dir}
fi
-dict=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
+fi
+
+token_list=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
bpemodel=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}
-echo "dictionary: ${dict}"
+echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
### Task dependent. You have to check non-linguistic symbols used in the corpus.
echo "stage 2: Dictionary and Json Data Preparation"
mkdir -p ${feats_dir}/data/lang_char/
- echo "<blank>" > ${dict}
- echo "<s>" >> ${dict}
- echo "</s>" >> ${dict}
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
cut -f 2- -d" " ${feats_dir}/data/${train_set}/text > ${feats_dir}/data/lang_char/input.txt
- spm_train --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
- spm_encode --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0}' >> ${dict}
- echo "<unk>" >> ${dict}
- wc -l ${dict}
-
- vocab_size=$(cat ${dict} | wc -l)
- awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
- awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/$train_set
- mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/$valid_set
- cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/$train_set
- cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/$valid_set
+ local/spm_train.py --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
+ local/spm_encode.py --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0}' >> ${token_list}
+ echo "<unk>" >> ${token_list}
fi
-
# Training Stage
world_size=$gpu_num # run on one machine
@@ -181,20 +133,22 @@
rank=$i
local_rank=$i
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
+ train.py \
+ --task_name asr \
--gpu_id $gpu_id \
--use_preprocessor true \
--split_with_space false \
--bpemodel ${bpemodel}.model \
--token_type $token_type \
- --dataset_type $dataset_type \
- --token_list $dict \
- --train_data_file $feats_dir/$dumpdir/${train_set}/ark_txt.scp \
- --valid_data_file $feats_dir/$dumpdir/${valid_set}/ark_txt.scp \
+ --token_list $token_list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
--resume true \
--output_dir ${exp_dir}/exp/${model_dir} \
--config $asr_config \
- --input_size $feats_dim \
--ngpu $gpu_num \
--num_worker_count $count \
--multiprocessing_distributed true \
@@ -220,7 +174,7 @@
exit 0
fi
mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
+ _data="${feats_dir}/data/${dset}"
key_file=${_data}/${scp}
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
@@ -241,6 +195,7 @@
--njob ${njob} \
--gpuid_list ${gpuid_list} \
--data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
--key_file "${_logdir}"/keys.JOB.scp \
--asr_train_config "${asr_exp}"/config.yaml \
--asr_model_file "${asr_exp}"/"${inference_asr_model}" \
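run.sh now trains directly from wav.scp with on-the-fly features, a global CMVN file, and speed perturbation ("0.9 1.0 1.1") instead of pre-dumped fbank archives. A sketch of resample-based speed perturbation for a single factor, assuming torchaudio's sox backend (FunASR's on-the-fly implementation may differ in detail):

    import torchaudio

    def speed_perturb(wav, sr, factor=0.9):
        # wav: (channels, time) float tensor. Change speed by factor and
        # resample back to sr; factor 1.0 leaves the audio untouched.
        if factor == 1.0:
            return wav
        effects = [["speed", str(factor)], ["rate", str(sr)]]
        wav, _ = torchaudio.sox_effects.apply_effects_tensor(wav, sr, effects)
        return wav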
diff --git a/egs/librispeech_100h/conformer/conf/decode_asr_transformer.yaml b/egs/librispeech_100h/conformer/conf/decode_asr_transformer.yaml
new file mode 100644
index 0000000..a147fa7
--- /dev/null
+++ b/egs/librispeech_100h/conformer/conf/decode_asr_transformer.yaml
@@ -0,0 +1,6 @@
+beam_size: 10
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.5
+lm_weight: 0.7
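These weights combine the attention decoder, CTC prefix scores, and an external LM during beam search. A sketch of the per-hypothesis score under the standard hybrid CTC/attention formulation, which this config is assumed to follow (see the decoding code for the exact handling):

    def joint_score(att_logp, ctc_logp, lm_logp,
                    ctc_weight=0.5, lm_weight=0.7):
        # log p = (1 - w_ctc) * attention + w_ctc * CTC + w_lm * LM
        return ((1.0 - ctc_weight) * att_logp
                + ctc_weight * ctc_logp
                + lm_weight * lm_logp)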
diff --git a/egs/librispeech/conformer/conf/train_asr_conformer_uttnorm.yaml b/egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml
similarity index 77%
rename from egs/librispeech/conformer/conf/train_asr_conformer_uttnorm.yaml
rename to egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml
index 16b7cc0..bd92bb0 100644
--- a/egs/librispeech/conformer/conf/train_asr_conformer_uttnorm.yaml
+++ b/egs/librispeech_100h/conformer/conf/train_asr_conformer.yaml
@@ -1,8 +1,8 @@
encoder: conformer
encoder_conf:
- output_size: 512
- attention_heads: 8
- linear_units: 2048
+ output_size: 256
+ attention_heads: 4
+ linear_units: 1024
num_blocks: 12
dropout_rate: 0.1
positional_dropout_rate: 0.1
@@ -19,7 +19,7 @@
decoder: transformer
decoder_conf:
- attention_heads: 8
+ attention_heads: 4
linear_units: 2048
num_blocks: 6
dropout_rate: 0.1
@@ -27,13 +27,25 @@
self_attention_dropout_rate: 0.1
src_attention_dropout_rate: 0.1
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+# hybrid CTC/attention
model_conf:
ctc_weight: 0.3
lsm_weight: 0.1
length_normalized_loss: false
-accum_grad: 2
-max_epoch: 50
+accum_grad: 1
+max_epoch: 210
patience: none
init: none
best_model_criterion:
@@ -44,11 +56,11 @@
optim: adam
optim_conf:
- lr: 0.0025
+ lr: 0.002
weight_decay: 0.000001
scheduler: warmuplr
scheduler_conf:
- warmup_steps: 40000
+ warmup_steps: 15000
specaug: specaug
specaug_conf:
@@ -64,7 +76,7 @@
time_mask_width_ratio_range:
- 0.
- 0.05
- num_time_mask: 10
+ num_time_mask: 5
dataset_conf:
shuffle: True
@@ -77,4 +89,4 @@
num_workers: 8
log_interval: 50
-normalize: utterance_mvn
\ No newline at end of file
+normalize: None
\ No newline at end of file
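For the 100 h subset the model is shrunk (256-dim encoder, 4 attention heads), warmup is cut to 15000 steps with a lower peak lr, and utterance MVN is replaced by the global CMVN file. A sketch of the warmuplr schedule, assuming the usual Noam-style definition:

    def warmup_lr(step, base_lr=0.002, warmup_steps=15000):
        # Linear ramp to base_lr over warmup_steps, then inverse
        # square-root decay.
        step = max(step, 1)
        scale = warmup_steps ** 0.5 * min(step ** -0.5,
                                          step * warmup_steps ** -1.5)
        return base_lr * scale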
diff --git a/egs/librispeech/conformer/local/data_prep_librispeech.sh b/egs/librispeech_100h/conformer/local/data_prep.sh
similarity index 100%
copy from egs/librispeech/conformer/local/data_prep_librispeech.sh
copy to egs/librispeech_100h/conformer/local/data_prep.sh
diff --git a/egs/librispeech_100h/conformer/local/download_and_untar.sh b/egs/librispeech_100h/conformer/local/download_and_untar.sh
new file mode 100755
index 0000000..fe322e4
--- /dev/null
+++ b/egs/librispeech_100h/conformer/local/download_and_untar.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
+ echo " train-clean-100, train-clean-360, train-other-500."
+ exit 1
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1
+fi
+
+part_ok=false
+list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1
+fi
+
+if [ -f $data/LibriSpeech/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0
+fi
+
+
+# Sizes of the archive files in bytes. These are from some older releases.
+sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
+# sizes_new lists the archive file sizes of the final release. Some of these
+# sizes are for parts we probably won't download.
+sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
+
+if [ -f $data/$part.tar.gz ]; then
+ size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tar.gz
+ else
+ echo "$data/$part.tar.gz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tar.gz ]; then
+ if ! which wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1
+ fi
+ full_url=$url/$part.tar.gz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ if ! wget -P $data --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1
+ fi
+fi
+
+if ! tar -C $data -xvzf $data/$part.tar.gz; then
+ echo "$0: error un-tarring archive $data/$part.tar.gz"
+ exit 1
+fi
+
+touch $data/LibriSpeech/$part/.complete
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
+ rm $data/$part.tar.gz
+fi
diff --git a/egs/librispeech_100h/conformer/local/spm_encode.py b/egs/librispeech_100h/conformer/local/spm_encode.py
new file mode 100755
index 0000000..9e1c15f
--- /dev/null
+++ b/egs/librispeech_100h/conformer/local/spm_encode.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+
+
+import argparse
+import contextlib
+import sys
+
+import sentencepiece as spm
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--model", required=True,
+ help="sentencepiece model to use for encoding")
+ parser.add_argument("--inputs", nargs="+", default=['-'],
+ help="input files to filter/encode")
+ parser.add_argument("--outputs", nargs="+", default=['-'],
+ help="path to save encoded outputs")
+ parser.add_argument("--output_format", choices=["piece", "id"], default="piece")
+ parser.add_argument("--min-len", type=int, metavar="N",
+ help="filter sentence pairs with fewer than N tokens")
+ parser.add_argument("--max-len", type=int, metavar="N",
+ help="filter sentence pairs with more than N tokens")
+ args = parser.parse_args()
+
+ assert len(args.inputs) == len(args.outputs), \
+ "number of input and output paths should match"
+
+ sp = spm.SentencePieceProcessor()
+ sp.Load(args.model)
+
+ if args.output_format == "piece":
+ def encode(l):
+ return sp.EncodeAsPieces(l)
+ elif args.output_format == "id":
+ def encode(l):
+ return list(map(str, sp.EncodeAsIds(l)))
+ else:
+ raise NotImplementedError
+
+ if args.min_len is not None or args.max_len is not None:
+ def valid(line):
+ return (
+ (args.min_len is None or len(line) >= args.min_len) and
+ (args.max_len is None or len(line) <= args.max_len)
+ )
+ else:
+ def valid(line):
+ return True
+
+ with contextlib.ExitStack() as stack:
+ inputs = [
+ stack.enter_context(open(input, "r", encoding="utf-8"))
+ if input != "-" else sys.stdin
+ for input in args.inputs
+ ]
+ outputs = [
+ stack.enter_context(open(output, "w", encoding="utf-8"))
+ if output != "-" else sys.stdout
+ for output in args.outputs
+ ]
+
+ stats = {
+ "num_empty": 0,
+ "num_filtered": 0,
+ }
+
+ def encode_line(line):
+ line = line.strip()
+ if len(line) > 0:
+ line = encode(line)
+ if valid(line):
+ return line
+ else:
+ stats["num_filtered"] += 1
+ else:
+ stats["num_empty"] += 1
+ return None
+
+ for i, lines in enumerate(zip(*inputs), start=1):
+ enc_lines = list(map(encode_line, lines))
+ if not any(enc_line is None for enc_line in enc_lines):
+ for enc_line, output_h in zip(enc_lines, outputs):
+ print(" ".join(enc_line), file=output_h)
+ if i % 10000 == 0:
+ print("processed {} lines".format(i), file=sys.stderr)
+
+ print("skipped {} empty lines".format(stats["num_empty"]), file=sys.stderr)
+ print("filtered {} lines".format(stats["num_filtered"]), file=sys.stderr)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/librispeech_100h/conformer/local/spm_train.py b/egs/librispeech_100h/conformer/local/spm_train.py
new file mode 100755
index 0000000..134a0b1
--- /dev/null
+++ b/egs/librispeech_100h/conformer/local/spm_train.py
@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# https://github.com/pytorch/fairseq/blob/master/LICENSE
+import sys
+
+import sentencepiece as spm
+
+if __name__ == "__main__":
+ spm.SentencePieceTrainer.Train(" ".join(sys.argv[1:]))
diff --git a/egs/mars/sd/path.sh b/egs/librispeech_100h/conformer/path.sh
similarity index 100%
rename from egs/mars/sd/path.sh
rename to egs/librispeech_100h/conformer/path.sh
diff --git a/egs/librispeech_100h/conformer/run.sh b/egs/librispeech_100h/conformer/run.sh
new file mode 100755
index 0000000..e980806
--- /dev/null
+++ b/egs/librispeech_100h/conformer/run.sh
@@ -0,0 +1,212 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=2
+count=1
+gpu_inference=true # Whether to perform GPU decoding; set false for CPU decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="../DATA" # feature output directory
+exp_dir="."
+lang=en
+token_type=bpe
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=3
+stop_stage=4
+
+# feature configuration
+feats_dim=80
+nj=64
+
+# data
+raw_data=
+data_url=www.openslr.org/resources/12
+
+# bpe model
+nbpe=5000
+bpemode=unigram
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode; it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands'.
+set -e
+set -u
+set -o pipefail
+
+train_set=train_clean_100
+valid_set=dev
+test_sets="test_clean test_other dev_clean dev_other"
+
+asr_config=conf/train_asr_conformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_asr_transformer.yaml
+#inference_config=conf/decode_asr_transformer_beam60_ctc0.3.yaml
+inference_asr_model=valid.acc.ave_10best.pth
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+ inference_nj=$[${ngpu}*${njob}]
+ _ngpu=1
+else
+ inference_nj=$njob
+ _ngpu=0
+fi
+
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ for part in dev-clean test-clean dev-other test-other train-clean-100; do
+ local/download_and_untar.sh ${raw_data} ${data_url} ${part}
+ done
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ for x in dev-clean dev-other test-clean test-other train-clean-100; do
+ local/data_prep.sh ${raw_data}/LibriSpeech/${x} ${feats_dir}/data/${x//-/_}
+ done
+ mkdir $feats_dir/data/$valid_set
+ dev_sets="dev_clean dev_other"
+ for file in wav.scp text; do
+ ( for f in $dev_sets; do cat $feats_dir/data/$f/$file; done ) | sort -k1 > $feats_dir/data/$valid_set/$file || exit 1;
+ done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} ${feats_dir}/data/${train_set}
+fi
+
+token_list=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
+bpemodel=${feats_dir}/data/lang_char/${train_set}_${bpemode}${nbpe}
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ ### Task dependent. You have to check non-linguistic symbols used in the corpus.
+ echo "stage 2: Dictionary and Json Data Preparation"
+ mkdir -p ${feats_dir}/data/lang_char/
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ cut -f 2- -d" " ${feats_dir}/data/${train_set}/text > ${feats_dir}/data/lang_char/input.txt
+ local/spm_train.py --input=${feats_dir}/data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
+ local/spm_encode.py --model=${bpemodel}.model --output_format=piece < ${feats_dir}/data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0}' >> ${token_list}
+ echo "<unk>" >> ${token_list}
+fi
+
+# Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: Training"
+ mkdir -p ${exp_dir}/exp/${model_dir}
+ mkdir -p ${exp_dir}/exp/${model_dir}/log
+ INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $gpu_num; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ train.py \
+ --task_name asr \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --split_with_space false \
+ --bpemodel ${bpemodel}.model \
+ --token_type $token_type \
+ --token_list $token_list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --speed_perturb ${speed_perturb} \
+ --resume true \
+ --output_dir ${exp_dir}/exp/${model_dir} \
+ --config $asr_config \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: Inference"
+ for dset in ${test_sets}; do
+ asr_exp=${exp_dir}/exp/${model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/data/${dset}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --gpuid_list ${gpuid_list} \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/cmvn.mvn \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${asr_exp}"/config.yaml \
+ --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode asr \
+ ${_opts}
+
+ for f in token token_int score text; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+ python utils/compute_wer.py ${_data}/text ${_dir}/text ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+ done
+fi
\ No newline at end of file
diff --git a/egs/librispeech_100h/conformer/utils b/egs/librispeech_100h/conformer/utils
new file mode 120000
index 0000000..fe070dd
--- /dev/null
+++ b/egs/librispeech_100h/conformer/utils
@@ -0,0 +1 @@
+../../aishell/transformer/utils
\ No newline at end of file
diff --git a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml b/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
deleted file mode 100644
index 459a741..0000000
--- a/egs/mars/sd/conf/SOND_ECAPATDNN_None_Dot_SAN_L4N512_FSMN_L6N512_n16k2.yaml
+++ /dev/null
@@ -1,121 +0,0 @@
-model: sond
-model_conf:
- lsm_weight: 0.0
- length_normalized_loss: true
- max_spk_num: 16
-
-# speech encoder
-encoder: ecapa_tdnn
-encoder_conf:
- # pass by model, equal to feature dim
- # input_size: 80
- pool_size: 20
- stride: 1
-speaker_encoder: conv
-speaker_encoder_conf:
- input_units: 256
- num_layers: 3
- num_units: 256
- kernel_size: 1
- dropout_rate: 0.0
- position_encoder: null
- out_units: 256
- out_norm: false
- auxiliary_states: false
- tf2torch_tensor_name_prefix_torch: speaker_encoder
- tf2torch_tensor_name_prefix_tf: EAND/speaker_encoder
-ci_scorer: dot
-ci_scorer_conf: {}
-cd_scorer: san
-cd_scorer_conf:
- input_size: 512
- output_size: 512
- out_units: 1
- attention_heads: 4
- linear_units: 1024
- num_blocks: 4
- dropout_rate: 0.0
- positional_dropout_rate: 0.0
- attention_dropout_rate: 0.0
- # use string "null" to remove input layer
- input_layer: "null"
- pos_enc_class: null
- normalize_before: true
- tf2torch_tensor_name_prefix_torch: cd_scorer
- tf2torch_tensor_name_prefix_tf: EAND/compute_distance_layer
-# post net
-decoder: fsmn
-decoder_conf:
- in_units: 32
- out_units: 2517
- filter_size: 31
- fsmn_num_layers: 6
- dnn_num_layers: 1
- num_memory_units: 512
- ffn_inner_dim: 512
- dropout_rate: 0.0
- tf2torch_tensor_name_prefix_torch: decoder
- tf2torch_tensor_name_prefix_tf: EAND/post_net
-frontend: wav_frontend
-frontend_conf:
- fs: 16000
- window: povey
- n_mels: 80
- frame_length: 25
- frame_shift: 10
- filter_length_min: -1
- filter_length_max: -1
- lfr_m: 1
- lfr_n: 1
- dither: 0.0
- snip_edges: false
-
-# minibatch related
-batch_type: length
-# 16s * 16k * 16 samples
-batch_bins: 4096000
-num_workers: 8
-
-# optimization related
-accum_grad: 1
-grad_clip: 5
-max_epoch: 50
-val_scheduler_criterion:
- - valid
- - acc
-best_model_criterion:
-- - valid
- - der
- - min
-- - valid
- - forward_steps
- - max
-keep_nbest_models: 10
-
-optim: adam
-optim_conf:
- lr: 0.001
-scheduler: warmuplr
-scheduler_conf:
- warmup_steps: 10000
-
-# without spec aug
-specaug: null
-specaug_conf:
- apply_time_warp: true
- time_warp_window: 5
- time_warp_mode: bicubic
- apply_freq_mask: true
- freq_mask_width_range:
- - 0
- - 30
- num_freq_mask: 2
- apply_time_mask: true
- time_mask_width_range:
- - 0
- - 40
- num_time_mask: 2
-
-log_interval: 50
-# without normalize
-normalize: None
diff --git a/egs/mars/sd/local_run.sh b/egs/mars/sd/local_run.sh
deleted file mode 100755
index 4516e9f..0000000
--- a/egs/mars/sd/local_run.sh
+++ /dev/null
@@ -1,171 +0,0 @@
-#!/usr/bin/env bash
-
-. ./path.sh || exit 1;
-
-# machines configuration
-CUDA_VISIBLE_DEVICES="6,7"
-gpu_num=2
-count=1
-gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
-# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
-njob=5
-train_cmd=utils/run.pl
-infer_cmd=utils/run.pl
-
-# general configuration
-feats_dir="." #feature output dictionary
-exp_dir="."
-lang=zh
-dumpdir=dump/raw
-feats_type=raw
-token_type=char
-scp=wav.scp
-type=kaldi_ark
-stage=3
-stop_stage=4
-
-# feature configuration
-feats_dim=
-sample_frequency=16000
-nj=32
-speed_perturb=
-
-# exp tag
-tag="exp1"
-
-. utils/parse_options.sh || exit 1;
-
-# Set bash to 'debug' mode, it will exit on :
-# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
-set -e
-set -u
-set -o pipefail
-
-train_set=train
-valid_set=dev
-test_sets="dev test"
-
-asr_config=conf/train_asr_conformer.yaml
-model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
-
-inference_config=conf/decode_asr_transformer.yaml
-inference_asr_model=valid.acc.ave_10best.pb
-
-# you can set gpu num for decoding here
-gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
-ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
-
-if ${gpu_inference}; then
- inference_nj=$[${ngpu}*${njob}]
- _ngpu=1
-else
- inference_nj=$njob
- _ngpu=0
-fi
-
-feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
-feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
-feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
-
-# Training Stage
-world_size=$gpu_num # run on one machine
-if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
- echo "stage 3: Training"
- mkdir -p ${exp_dir}/exp/${model_dir}
- mkdir -p ${exp_dir}/exp/${model_dir}/log
- INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
- if [ -f $INIT_FILE ];then
- rm -f $INIT_FILE
- fi
- init_method=file://$(readlink -f $INIT_FILE)
- echo "$0: init method is $init_method"
- for ((i = 0; i < $gpu_num; ++i)); do
- {
- rank=$i
- local_rank=$i
- gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
- asr_train.py \
- --gpu_id $gpu_id \
- --use_preprocessor true \
- --token_type char \
- --token_list $token_list \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
- --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
- --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
- --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
- --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
- --resume true \
- --output_dir ${exp_dir}/exp/${model_dir} \
- --config $asr_config \
- --input_size $feats_dim \
- --ngpu $gpu_num \
- --num_worker_count $count \
- --multiprocessing_distributed true \
- --dist_init_method $init_method \
- --dist_world_size $world_size \
- --dist_rank $rank \
- --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
- } &
- done
- wait
-fi
-
-# Testing Stage
-if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
- echo "stage 4: Inference"
- for dset in ${test_sets}; do
- asr_exp=${exp_dir}/exp/${model_dir}
- inference_tag="$(basename "${inference_config}" .yaml)"
- _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
- _logdir="${_dir}/logdir"
- if [ -d ${_dir} ]; then
- echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
- exit 0
- fi
- mkdir -p "${_logdir}"
- _data="${feats_dir}/${dumpdir}/${dset}"
- key_file=${_data}/${scp}
- num_scp_file="$(<${key_file} wc -l)"
- _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
- split_scps=
- for n in $(seq "${_nj}"); do
- split_scps+=" ${_logdir}/keys.${n}.scp"
- done
- # shellcheck disable=SC2086
- utils/split_scp.pl "${key_file}" ${split_scps}
- _opts=
- if [ -n "${inference_config}" ]; then
- _opts+="--config ${inference_config} "
- fi
- ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1: "${_nj}" "${_logdir}"/asr_inference.JOB.log \
- python -m funasr.bin.asr_inference_launch \
- --batch_size 1 \
- --ngpu "${_ngpu}" \
- --njob ${njob} \
- --gpuid_list ${gpuid_list} \
- --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
- --key_file "${_logdir}"/keys.JOB.scp \
- --asr_train_config "${asr_exp}"/config.yaml \
- --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
- --output_dir "${_logdir}"/output.JOB \
- --mode asr \
- ${_opts}
-
- for f in token token_int score text; do
- if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
- for i in $(seq "${_nj}"); do
- cat "${_logdir}/output.${i}/1best_recog/${f}"
- done | sort -k1 >"${_dir}/${f}"
- fi
- done
- python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
- python utils/proce_text.py ${_data}/text ${_data}/text.proc
- python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
- tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
- cat ${_dir}/text.cer.txt
- done
-fi
-
diff --git a/egs/mars/sd/scripts/calculate_shapes.py b/egs/mars/sd/scripts/calculate_shapes.py
deleted file mode 100644
index b207f2d..0000000
--- a/egs/mars/sd/scripts/calculate_shapes.py
+++ /dev/null
@@ -1,45 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser: argparse.ArgumentParser):
- parser.add_argument("--input_scp", type=str, required=True)
- parser.add_argument("--out_path")
- args = parser.parse_args()
-
- if not os.path.exists(os.path.dirname(args.out_path)):
- os.makedirs(os.path.dirname(args.out_path))
-
- task_list = load_scp_as_list(args.input_scp)
- return task_list, None, args
-
- def post(self, result_list, args):
- fd = open(args.out_path, "wt", encoding="utf-8")
- for results in result_list:
- for uttid, shape in results:
- fd.write("{} {}\n".format(uttid, ",".join(shape)))
- fd.close()
-
-
-def process(task_args):
- task_idx, task_list, _, args = task_args
- rst = []
- for uttid, file_path in task_list:
- data = kaldiio.load_mat(file_path)
- shape = [str(x) for x in data.shape]
- rst.append((uttid, shape))
- return rst
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/dump_rttm_to_labels.py b/egs/mars/sd/scripts/dump_rttm_to_labels.py
deleted file mode 100644
index ec1c765..0000000
--- a/egs/mars/sd/scripts/dump_rttm_to_labels.py
+++ /dev/null
@@ -1,140 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser: argparse.ArgumentParser):
- parser.add_argument("--rttm_list", type=str, required=True)
- parser.add_argument("--wav_scp_list", type=str, required=True)
- parser.add_argument("--out_dir", type=str, required=True)
- parser.add_argument("--n_spk", type=int, default=8)
- parser.add_argument("--remove_sil", default=False, action="store_true")
- parser.add_argument("--max_overlap", default=0, type=int)
- parser.add_argument("--frame_shift", type=float, default=0.01)
- args = parser.parse_args()
-
- rttm_list = [x.strip() for x in open(args.rttm_list, "rt", encoding="utf-8").readlines()]
- meeting2rttm = OrderedDict()
- for rttm_path in rttm_list:
- meeting2rttm.update(self.load_rttm(rttm_path))
-
- wav_scp_list = [x.strip() for x in open(args.wav_scp_list, "rt", encoding="utf-8").readlines()]
- meeting_scp = OrderedDict()
- for scp_path in wav_scp_list:
- meeting_scp.update(load_scp_as_dict(scp_path))
-
- if len(meeting_scp) != len(meeting2rttm):
- logging.warning("Number of wav and rttm mismatch {} != {}".format(
- len(meeting_scp), len(meeting2rttm)))
- common_keys = set(meeting_scp.keys()) & set(meeting2rttm.keys())
- logging.warning("Keep {} records.".format(len(common_keys)))
- new_meeting_scp = OrderedDict()
- rm_keys = []
- for key in meeting_scp:
- if key not in common_keys:
- rm_keys.append(key)
- else:
- new_meeting_scp[key] = meeting_scp[key]
- logging.warning("Keys are removed from wav scp: {}".format(" ".join(rm_keys)))
-
- new_meeting2rttm = OrderedDict()
- rm_keys = []
- for key in meeting2rttm:
- if key not in common_keys:
- rm_keys.append(key)
- else:
- new_meeting2rttm[key] = meeting2rttm[key]
- logging.warning("Keys are removed from rttm scp: {}".format(" ".join(rm_keys)))
- meeting_scp, meeting2rttm = new_meeting_scp, new_meeting2rttm
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- task_list = [(mid, meeting_scp[mid], meeting2rttm[mid]) for mid in meeting2rttm.keys()]
- return task_list, None, args
-
- @staticmethod
- def load_rttm(rttm_path):
- meeting2rttm = OrderedDict()
- for one_line in open(rttm_path, "rt", encoding="utf-8"):
- mid = one_line.strip().split(" ")[1]
- if mid not in meeting2rttm:
- meeting2rttm[mid] = []
- meeting2rttm[mid].append(one_line.strip())
-
- return meeting2rttm
-
- def post(self, results_list, args):
- pass
-
-
-def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, max_overlap=0,
- sr=None, frame_shift=0.01):
- frame_shift = int(frame_shift * sr)
- num_frame = int((float(length) + (float(frame_shift) / 2)) / frame_shift)
- multi_label = np.zeros([n_spk, num_frame], dtype=np.float32)
- for _, st, dur, spk in spk_turns:
- idx = spk_list.index(spk)
-
- st, dur = int(st * sr), int(dur * sr)
- frame_st = int((float(st) + (float(frame_shift) / 2)) / frame_shift)
- frame_ed = int((float(st+dur) + (float(frame_shift) / 2)) / frame_shift)
- multi_label[idx, frame_st:frame_ed] = 1
-
- if remove_sil:
- speech_count = np.sum(multi_label, axis=0)
- idx = np.nonzero(speech_count)[0]
- multi_label = multi_label[:, idx]
-
- if max_overlap > 0:
- speech_count = np.sum(multi_label, axis=0)
- idx = np.nonzero(speech_count <= max_overlap)[0]
- multi_label = multi_label[:, idx]
-
- label = multi_label.T
- return label # (T, N)
-
-
-def build_labels(wav_path, rttms, n_spk, remove_sil=False, max_overlap=0,
- sr=16000, frame_shift=0.01):
- wav, sr = soundfile.read(wav_path)
- wav_len = len(wav)
- spk_turns = []
- spk_list = []
- for one_line in rttms:
- parts = one_line.strip().split(" ")
- mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), parts[7]
- if spk not in spk_list:
- spk_list.append(spk)
- spk_turns.append((mid, st, dur, spk))
- labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil, max_overlap, sr, frame_shift)
- return labels, spk_list
-
-
-def process(task_args):
- task_idx, task_list, _, args = task_args
- spk_list_writer = open(os.path.join(args.out_dir, "spk_list.{}.txt".format(task_idx+1)),
- "wt", encoding="utf-8")
- out_path = os.path.join(args.out_dir, "labels.{}".format(task_idx + 1))
- label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
- for mid, wav_path, rttms in task_list:
- meeting_labels, spk_list = build_labels(wav_path, rttms, args.n_spk, args.remove_sil, args.max_overlap,
- args.sr, args.frame_shift)
- label_writer(mid, meeting_labels)
- spk_list_writer.write("{} {}\n".format(mid, " ".join(spk_list)))
-
- spk_list_writer.close()
- label_writer.close()
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py b/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
deleted file mode 100644
index cd1ec7b..0000000
--- a/egs/mars/sd/scripts/extract_nonoverlap_segments_v2.py
+++ /dev/null
@@ -1,115 +0,0 @@
-import numpy as np
-import os
-import argparse
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import soundfile as sf
-from tqdm import tqdm
-
-
-class MyRunner(MultiProcessRunnerV3):
- def prepare(self, parser):
- assert isinstance(parser, argparse.ArgumentParser)
- parser.add_argument("wav_scp", type=str)
- parser.add_argument("rttm", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--min_dur", type=float, default=2.0)
- parser.add_argument("--max_spk_num", type=int, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- wav_scp = load_scp_as_list(args.wav_scp)
- meeting2rttms = {}
- for one_line in open(args.rttm, "rt"):
- parts = [x for x in one_line.strip().split(" ") if x != ""]
- mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- if mid not in meeting2rttms:
- meeting2rttms[mid] = []
- meeting2rttms[mid].append(one_line)
-
- task_list = [(mid, wav_path, meeting2rttms[mid]) for (mid, wav_path) in wav_scp]
- return task_list, None, args
-
- def post(self, result_list, args):
- count = [0, 0]
- for result in result_list:
- count[0] += result[0]
- count[1] += result[1]
- print("Found {} speakers, extracted {}.".format(count[1], count[0]))
-
-
-# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
-def calc_multi_label(rttms, length, sr=8000, max_spk_num=4):
- labels = np.zeros([max_spk_num, length], int)
- spk_list = []
- for one_line in rttms:
- parts = [x for x in one_line.strip().split(" ") if x != ""]
- mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- spk_name = spk_name.replace("spk", "").replace(mid, "").replace("-", "")
- if spk_name.isdigit():
- spk_name = "{}_S{:03d}".format(mid, int(spk_name))
- else:
- spk_name = "{}_{}".format(mid, spk_name)
- if spk_name not in spk_list:
- spk_list.append(spk_name)
- st, dur = int(st*sr), int(dur*sr)
- idx = spk_list.index(spk_name)
- labels[idx, st:st+dur] = 1
- return labels, spk_list
-
-
-def get_nonoverlap_turns(multi_label, spk_list):
- turns = []
- label = np.sum(multi_label, axis=0) == 1
- spk, in_turn, st = None, False, 0
- for i in range(len(label)):
- if not in_turn and label[i]:
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- if not label[i]:
- in_turn = False
- turns.append([st, i, spk])
- elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
- turns.append([st, i, spk])
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- turns.append([st, len(label), spk])
- return turns
-
-
-def process(task_args):
- task_id, task_list, _, args = task_args
- spk_count = [0, 0]
- for mid, wav_path, rttms in task_list:
- wav, sr = sf.read(wav_path, dtype="int16")
- assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
- multi_label, spk_list = calc_multi_label(rttms, len(wav), args.sr, args.max_spk_num)
- turns = get_nonoverlap_turns(multi_label, spk_list)
- extracted_spk = []
- count = 1
- for st, ed, spk in tqdm(turns, total=len(turns), ascii=True, disable=args.no_pbar):
- if (ed - st) >= args.min_dur * args.sr:
- seg = wav[st: ed]
- save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
- if not os.path.exists(os.path.dirname(save_path)):
- os.makedirs(os.path.dirname(save_path))
- sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- count += 1
- if spk not in extracted_spk:
- extracted_spk.append(spk)
- if len(extracted_spk) != len(spk_list):
- print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
- mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
- ))
- spk_count[0] += len(extracted_spk)
- spk_count[1] += len(spk_list)
- return spk_count
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py b/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py
deleted file mode 100644
index e579f51..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/calc_real_meeting_labels.py
+++ /dev/null
@@ -1,73 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("dir", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--n_spk", type=int, default=4)
- parser.add_argument("--remove_sil", default=False, action="store_true")
- args = parser.parse_args()
-
- meeting_scp = load_scp_as_dict(os.path.join(args.dir, "meeting.scp"))
- rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- task_list = [(mid, meeting_scp[mid], rttm_path) for mid, rttm_path in rttm_scp]
- return task_list, None, args
-
- def post(self, results_list, args):
- pass
-
-
-def calc_labels(spk_turns, spk_list, length, n_spk, remove_sil=False, sr=16000):
- multi_label = np.zeros([n_spk, length], dtype=int)
- for _, st, dur, spk in spk_turns:
- st, dur = int(st * sr), int(dur * sr)
- idx = spk_list.index(spk)
- multi_label[idx, st:st+dur] = 1
- if not remove_sil:
- return multi_label.T
-
- speech_count = np.sum(multi_label, axis=0)
- idx = np.nonzero(speech_count)[0]
- label = multi_label[:, idx].T
- return label # (T, N)
-
-
-def build_labels(wav_path, rttm_path, n_spk, remove_sil=False, sr=16000):
- wav_len = int(librosa.get_duration(filename=wav_path, sr=sr) * sr)
- spk_turns = []
- spk_list = []
- for one_line in open(rttm_path, "rt"):
- parts = one_line.strip().split(" ")
- mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
- spk = "{}_S{:03d}".format(mid, spk)
- if spk not in spk_list:
- spk_list.append(spk)
- spk_turns.append((mid, st, dur, spk))
- labels = calc_labels(spk_turns, spk_list, wav_len, n_spk, remove_sil)
- return labels
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- for mid, wav_path, rttm_path in task_list:
- meeting_labels = build_labels(wav_path, rttm_path, args.n_spk, args.remove_sil)
- save_path = os.path.join(args.out_dir, "{}.lbl".format(mid))
- np.save(save_path, meeting_labels.astype(bool))
- print(mid)
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
diff --git a/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py b/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py
deleted file mode 100644
index 11bc395..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/clip_meeting_without_silence.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import soundfile as sf
-from tqdm import tqdm
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("wav_scp", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--chunk_dur", type=float, default=16)
- parser.add_argument("--shift_dur", type=float, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- wav_scp = load_scp_as_list(args.wav_scp)
- return wav_scp, None, args
-
- def post(self, results_list, args):
- pass
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- chunk_len, shift_len = int(args.chunk_dur * args.sr), int(args.shift_dur * args.sr)
- for mid, wav_path in tqdm(task_list, total=len(task_list), ascii=True, disable=args.no_pbar):
- if not os.path.exists(os.path.join(args.out_dir, mid)):
- os.makedirs(os.path.join(args.out_dir, mid))
-
- wav = librosa.load(wav_path, args.sr, True)[0] * 32767
- n_chunk = (len(wav) - chunk_len) // shift_len + 1
- if (len(wav) - chunk_len) % shift_len > 0:
- n_chunk += 1
- for i in range(n_chunk):
- seg = wav[i*shift_len: i*shift_len + chunk_len]
- st = int(float(i*shift_len)/args.sr * 100)
- dur = int(float(len(seg))/args.sr * 100)
- file_name = "{}_S{:04d}_{:07d}_{:07d}.wav".format(mid, i, st, st+dur)
- save_path = os.path.join(args.out_dir, mid, file_name)
- sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
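
The sliding-window arithmetic above keeps a trailing partial chunk whenever the shift does not evenly divide the remainder; a quick check with hypothetical durations:

# Hypothetical chunk-count check for the window above: 16 s chunks with a
# 4 s shift over a 37 s recording; the short tail chunk is kept.
sr = 16000
chunk_len, shift_len = 16 * sr, 4 * sr
wav_len = 37 * sr
n_chunk = (wav_len - chunk_len) // shift_len + 1
if (wav_len - chunk_len) % shift_len > 0:
    n_chunk += 1
print(n_chunk)  # 7: six full 16 s chunks plus one 13 s tail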
diff --git a/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py b/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py
deleted file mode 100644
index 011bd7c..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/convert_rttm_to_seg_file.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("--rttm_scp", type=str)
- parser.add_argument("--seg_file", type=str)
- args = parser.parse_args()
-
- if not os.path.exists(os.path.dirname(args.seg_file)):
- os.makedirs(os.path.dirname(args.seg_file))
-
- task_list = load_scp_as_list(args.rttm_scp)
- return task_list, None, args
-
- def post(self, results_list, args):
- with open(args.seg_file, "wt", encoding="utf-8") as fd:
- for results in results_list:
- fd.writelines(results)
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- outputs = []
- for mid, rttm_path in task_list:
- spk_turns = []
- length = 0
- for one_line in open(rttm_path, 'rt', encoding="utf-8"):
- parts = one_line.strip().split(" ")
- _, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- st, ed = int(st*100), int((st + dur)*100)
- length = ed if ed > length else length
- spk_turns.append([mid, st, ed, spk_name])
- is_sph = np.zeros((length+1, ), dtype=bool)
- for _, st, ed, _ in spk_turns:
- is_sph[st:ed] = True
-
- st, in_speech = 0, False
- for i in range(length+1):
- if not in_speech and is_sph[i]:
- st, in_speech = i, True
- if in_speech and not is_sph[i]:
- in_speech = False
- outputs.append("{}-{:07d}-{:07d} {} {:.2f} {:.2f}\n".format(
- mid, st, i, mid, float(st)/100, float(i)/100
- ))
- return outputs
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
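
convert_rttm_to_seg_file.py merges overlapping speaker turns into contiguous speech segments at 10 ms resolution via a run-length scan over a boolean mask; a toy sketch of that scan:

import numpy as np

# Hypothetical speech mask at 10 ms resolution: speech in [10, 30) and [50, 70).
is_sph = np.zeros(80, dtype=bool)
is_sph[10:30] = is_sph[50:70] = True

segments, st, in_speech = [], 0, False
for i in range(len(is_sph)):
    if not in_speech and is_sph[i]:
        st, in_speech = i, True
    if in_speech and not is_sph[i]:
        in_speech = False
        segments.append((st / 100.0, i / 100.0))  # frames -> seconds
print(segments)  # [(0.1, 0.3), (0.5, 0.7)]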
diff --git a/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py b/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py
deleted file mode 100644
index a2bcd39..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/dump_real_meeting_chunks.py
+++ /dev/null
@@ -1,138 +0,0 @@
-import soundfile
-import kaldiio
-from tqdm import tqdm
-import json
-import os
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import numpy as np
-import argparse
-import random
-
-short_spk_list = []
-def calc_rand_ivc(spk, spk2utt, utt2ivc, utt2frames, total_len=3000):
- all_utts = spk2utt[spk]
- idx_list = list(range(len(all_utts)))
- random.shuffle(idx_list)
- count = 0
- utt_list = []
- for i in idx_list:
- utt_id = all_utts[i]
- utt_list.append(utt_id)
- count += int(utt2frames[utt_id])
- if count >= total_len:
- break
- if count < 300 and spk not in short_spk_list:
- print("Speaker {} has only {} frames, but expect {} frames at least, use them all.".format(spk, count, 300))
- short_spk_list.append(spk)
-
- ivc_list = [kaldiio.load_mat(utt2ivc[utt]) for utt in utt_list]
- ivc_list = [x/np.linalg.norm(x, axis=-1) for x in ivc_list]
- ivc = np.concatenate(ivc_list, axis=0)
- ivc = np.mean(ivc, axis=0, keepdims=False)
- return ivc
-
-
-def process(meeting_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args):
- out_prefix = args.out
-
- ivc_dim = 192
- win_len, win_shift = 400, 160
- label_weights = 2 ** np.array(list(range(args.n_spk)))
- wav_writer = kaldiio.WriteHelper("ark,scp:{}_wav.ark,{}_wav.scp".format(out_prefix, out_prefix))
- ivc_writer = kaldiio.WriteHelper("ark,scp:{}_profile.ark,{}_profile.scp".format(out_prefix, out_prefix))
- label_writer = kaldiio.WriteHelper("ark,scp:{}_label.ark,{}_label.scp".format(out_prefix, out_prefix))
-
-
- frames_list = []
- chunk_size = int(args.chunk_size * args.sr)
- chunk_shift = int(args.chunk_shift * args.sr)
- for mid, meeting_wav_path in tqdm(meeting_scp, total=len(meeting_scp), ascii=True, disable=args.no_pbar):
- meeting_wav, sr = soundfile.read(meeting_wav_path, dtype='float32')
- num_chunk = (len(meeting_wav) - chunk_size) // chunk_shift + 1
- meeting_labels = np.load(labels_scp[mid])
- for i in range(num_chunk):
- st, ed = i*chunk_shift, i*chunk_shift+chunk_size
- seg_id = "{}-{:03d}-{:06d}-{:06d}".format(mid, i, int(st/args.sr*100), int(ed/args.sr*100))
- wav_writer(seg_id, meeting_wav[st: ed])
-
- xvec_list = []
- for spk in meeting2spk_list[mid]:
- spk_xvec = calc_rand_ivc(spk, spk2utt, utt2xvec, utt2frames, 1000)
- xvec_list.append(spk_xvec)
- for _ in range(args.n_spk - len(xvec_list)):
- xvec_list.append(np.zeros((ivc_dim,), dtype=np.float32))
- xvec = np.row_stack(xvec_list)
- ivc_writer(seg_id, xvec)
-
- wav_label = meeting_labels[st:ed, :]
- frame_num = (ed-st) // win_shift
- # wav_label = np.pad(wav_label, ((win_len/2, win_len/2), (0, 0)), "constant")
- feat_label = np.zeros((frame_num, wav_label.shape[1]), dtype=np.float32)
- for i in range(frame_num):
- frame_label = wav_label[i*win_shift: (i+1)*win_shift, :]
- feat_label[i, :] = (np.sum(frame_label, axis=0) > 0).astype(np.float32)
- label_writer(seg_id, feat_label)
-
- frames_list.append((mid, feat_label.shape[0]))
- return frames_list
-
-
-def calc_spk_list(rttm_path):
- spk_list = []
- for one_line in open(rttm_path, "rt"):
- parts = one_line.strip().split(" ")
- mid, st, dur, spk = parts[1], float(parts[3]), float(parts[4]), int(parts[7])
- spk = "{}_S{:03d}".format(mid, spk)
- if spk not in spk_list:
- spk_list.append(spk)
-
- return spk_list
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("--dir", required=True, type=str, default=None,
- help="feats.scp")
- parser.add_argument("--out", required=True, type=str, default=None,
- help="The prefix of dumpped files.")
- parser.add_argument("--n_spk", type=int, default=4)
- parser.add_argument("--use_lfr", default=False, action="store_true")
- parser.add_argument("--no_pbar", default=False, action="store_true")
- parser.add_argument("--sr", type=int, default=16000)
- parser.add_argument("--chunk_size", type=int, default=16)
- parser.add_argument("--chunk_shift", type=int, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(os.path.dirname(args.out)):
- os.makedirs(os.path.dirname(args.out))
-
- meetings_scp = load_scp_as_list(os.path.join(args.dir, "meetings_rmsil.scp"))
- labels_scp = load_scp_as_dict(os.path.join(args.dir, "labels.scp"))
- rttm_scp = load_scp_as_list(os.path.join(args.dir, "rttm.scp"))
- utt2spk = load_scp_as_dict(os.path.join(args.dir, "utt2spk"))
- utt2xvec = load_scp_as_dict(os.path.join(args.dir, "utt2xvec"))
- utt2wav = load_scp_as_dict(os.path.join(args.dir, "wav.scp"))
- utt2frames = {}
- for uttid, wav_path in utt2wav.items():
- wav, sr = soundfile.read(wav_path, dtype="int16")
- utt2frames[uttid] = int(len(wav) / sr * 100)
-
- meeting2spk_list = {}
- for mid, rttm_path in rttm_scp:
- meeting2spk_list[mid] = calc_spk_list(rttm_path)
-
- spk2utt = {}
- for utt, spk in utt2spk.items():
- if utt in utt2xvec and utt in utt2frames and int(utt2frames[utt]) > 25:
- if spk not in spk2utt:
- spk2utt[spk] = []
- spk2utt[spk].append(utt)
-
- # random.shuffle(feat_scp)
- meeting_lens = process(meetings_scp, labels_scp, spk2utt, utt2xvec, utt2frames, meeting2spk_list, args)
- total_frames = sum([x[1] for x in meeting_lens])
- print("Total chunks: {:6d}, total frames: {:10d}".format(len(meeting_lens), total_frames))
-
-
-if __name__ == '__main__':
- main()
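
The chunk dumper above pools sample-level activity into frame labels with a 160-sample shift, marking a frame active if any sample inside it is active. A small sketch with hypothetical values:

import numpy as np

# Pool sample-level activity (T, n_spk) into frame labels, as in process():
# a frame is active if any sample within its 160-sample shift is active.
win_shift = 160
wav_label = np.zeros((1600, 2), dtype=np.float32)
wav_label[400:900, 0] = 1.0                 # speaker 0 active (hypothetical)
frame_num = wav_label.shape[0] // win_shift  # 10 frames
feat_label = np.zeros((frame_num, wav_label.shape[1]), dtype=np.float32)
for i in range(frame_num):
    frame = wav_label[i * win_shift:(i + 1) * win_shift, :]
    feat_label[i, :] = (np.sum(frame, axis=0) > 0).astype(np.float32)
print(feat_label[:, 0])  # [0. 0. 1. 1. 1. 1. 0. 0. 0. 0.]: frames 2-5 active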
diff --git a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py b/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
deleted file mode 100644
index 1d6f53e..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/extract_nonoverlap_segments.py
+++ /dev/null
@@ -1,110 +0,0 @@
-from __future__ import print_function
-import numpy as np
-import os
-import sys
-import argparse
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import librosa
-import soundfile as sf
-from copy import deepcopy
-import json
-from tqdm import tqdm
-
-
-class MyRunner(MultiProcessRunnerV3):
- def prepare(self, parser):
- assert isinstance(parser, argparse.ArgumentParser)
- parser.add_argument("wav_scp", type=str)
- parser.add_argument("rttm_scp", type=str)
- parser.add_argument("out_dir", type=str)
- parser.add_argument("--min_dur", type=float, default=2.0)
- parser.add_argument("--max_spk_num", type=int, default=4)
- args = parser.parse_args()
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- wav_scp = load_scp_as_list(args.wav_scp)
- rttm_scp = load_scp_as_dict(args.rttm_scp)
- task_list = [(mid, wav_path, rttm_scp[mid]) for (mid, wav_path) in wav_scp]
- return task_list, None, args
-
- def post(self, result_list, args):
- count = [0, 0]
- for result in result_list:
- count[0] += result[0]
- count[1] += result[1]
- print("Found {} speakers, extracted {}.".format(count[1], count[0]))
-
-
-# SPEAKER R8001_M8004_MS801 1 6.90 11.39 <NA> <NA> 1 <NA> <NA>
-def calc_multi_label(rttm_path, length, sr=16000, max_spk_num=4):
- labels = np.zeros([max_spk_num, length], int)
- spk_list = []
- for one_line in open(rttm_path, 'rt'):
- parts = one_line.strip().split(" ")
- mid, st, dur, spk_name = parts[1], float(parts[3]), float(parts[4]), parts[7]
- if spk_name.isdigit():
- spk_name = "{}_S{:03d}".format(mid, int(spk_name))
- if spk_name not in spk_list:
- spk_list.append(spk_name)
- st, dur = int(st*sr), int(dur*sr)
- idx = spk_list.index(spk_name)
- labels[idx, st:st+dur] = 1
- return labels, spk_list
-
-
-def get_nonoverlap_turns(multi_label, spk_list):
- turns = []
- label = np.sum(multi_label, axis=0) == 1
- spk, in_turn, st = None, False, 0
- for i in range(len(label)):
- if not in_turn and label[i]:
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- if not label[i]:
- in_turn = False
- turns.append([st, i, spk])
- elif label[i] and spk != spk_list[np.argmax(multi_label[:, i], axis=0)]:
- turns.append([st, i, spk])
- st, in_turn = i, True
- spk = spk_list[np.argmax(multi_label[:, i], axis=0)]
- if in_turn:
- turns.append([st, len(label), spk])
- return turns
-
-
-def process(task_args):
- task_id, task_list, _, args = task_args
- spk_count = [0, 0]
- for mid, wav_path, rttm_path in task_list:
- wav, sr = sf.read(wav_path, dtype="int16")
- assert sr == args.sr, "args.sr {}, file sr {}".format(args.sr, sr)
- multi_label, spk_list = calc_multi_label(rttm_path, len(wav), args.sr, args.max_spk_num)
- turns = get_nonoverlap_turns(multi_label, spk_list)
- extracted_spk = []
- count = 1
- for st, ed, spk in tqdm(turns, total=len(turns), ascii=True):
- if (ed - st) >= args.min_dur * args.sr:
- seg = wav[st: ed]
- save_path = os.path.join(args.out_dir, mid, spk, "{}_U{:04d}.wav".format(spk, count))
- if not os.path.exists(os.path.dirname(save_path)):
- os.makedirs(os.path.dirname(save_path))
- sf.write(save_path, seg.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- count += 1
- if spk not in extracted_spk:
- extracted_spk.append(spk)
- if len(extracted_spk) != len(spk_list):
- print("{}: Found {} speakers, but only extracted {}. {} are filtered due to min_dur".format(
- mid, len(spk_list), len(extracted_spk), " ".join([x for x in spk_list if x not in extracted_spk])
- ))
- spk_count[0] += len(extracted_spk)
- spk_count[1] += len(spk_list)
- return spk_count
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
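
get_nonoverlap_turns keeps only regions where exactly one speaker is active; the column-sum mask is the core of it. A toy check with a hypothetical two-speaker label matrix:

import numpy as np

# Hypothetical labels: speaker 0 active in [0, 6), speaker 1 in [4, 10).
multi_label = np.zeros((2, 10), dtype=int)
multi_label[0, 0:6] = 1
multi_label[1, 4:10] = 1

# Exactly-one-speaker mask, as in get_nonoverlap_turns above.
single = np.sum(multi_label, axis=0) == 1
print(single.astype(int))  # [1 1 1 1 0 0 1 1 1 1]: samples 4-5 overlap, dropped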
diff --git a/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py b/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py
deleted file mode 100644
index 8b3195f..0000000
--- a/egs/mars/sd/scripts/real_meeting_process/remove_silence_from_wav.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import numpy as np
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import librosa
-import soundfile as sf
-import argparse
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser):
- parser.add_argument("dir", type=str)
- parser.add_argument("out_dir", type=str)
- args = parser.parse_args()
-
- meeting_scp = load_scp_as_list(os.path.join(args.dir, "meeting.scp"))
- vad_file = open(os.path.join(args.dir, "segments"), encoding="utf-8")
- meeting2vad = {}
- for one_line in vad_file:
- uid, mid, st, ed = one_line.strip().split(" ")
- st, ed = int(float(st) * args.sr), int(float(ed) * args.sr)
- if mid not in meeting2vad:
- meeting2vad[mid] = []
- meeting2vad[mid].append((uid, st, ed))
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- task_list = [(mid, wav_path, meeting2vad[mid]) for mid, wav_path in meeting_scp]
- return task_list, None, args
-
- def post(self, results_list, args):
- pass
-
-
-def process(task_args):
- _, task_list, _, args = task_args
- for mid, wav_path, vad_list in task_list:
- wav = librosa.load(wav_path, args.sr, True)[0] * 32767
- seg_list = []
- pos_map = []
- offset = 0
- for uid, st, ed in vad_list:
- seg_list.append(wav[st: ed])
- pos_map.append("{} {} {} {} {}\n".format(uid, st, ed, offset, offset+ed-st))
- offset = offset + ed - st
- out = np.concatenate(seg_list, axis=0)
- save_path = os.path.join(args.out_dir, "{}.wav".format(mid))
- sf.write(save_path, out.astype(np.int16), args.sr, "PCM_16", "LITTLE", "WAV", True)
- map_path = os.path.join(args.out_dir, "{}.pos".format(mid))
- with open(map_path, "wt", encoding="utf-8") as fd:
- fd.writelines(pos_map)
- print(mid)
- return None
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
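
The .pos map written above records, per VAD segment, its span in the original recording and in the silence-stripped signal, so positions can be mapped back later. A minimal sketch with hypothetical sample indices:

# Hypothetical VAD list in samples: (uid, start, end) in the original wav.
vad_list = [("u1", 100, 300), ("u2", 500, 600)]
offset, pos_map = 0, []
for uid, st, ed in vad_list:
    # original [st, ed) maps to [offset, offset + ed - st) after removal
    pos_map.append((uid, st, ed, offset, offset + ed - st))
    offset += ed - st
print(pos_map)  # [('u1', 100, 300, 0, 200), ('u2', 500, 600, 200, 300)]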
diff --git a/egs/mars/sd/scripts/simu_chunk_with_labels.py b/egs/mars/sd/scripts/simu_chunk_with_labels.py
deleted file mode 100644
index f61b808..0000000
--- a/egs/mars/sd/scripts/simu_chunk_with_labels.py
+++ /dev/null
@@ -1,261 +0,0 @@
-import logging
-import numpy as np
-import soundfile
-import kaldiio
-from funasr.utils.job_runner import MultiProcessRunnerV3
-from funasr.utils.misc import load_scp_as_list, load_scp_as_dict
-import os
-import argparse
-from collections import OrderedDict
-import random
-from typing import List, Dict
-from copy import deepcopy
-import json
-logging.basicConfig(
- level="INFO",
- format=f"[{os.uname()[1].split('.')[0]}]"
- f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-)
-
-
-class MyRunner(MultiProcessRunnerV3):
-
- def prepare(self, parser: argparse.ArgumentParser):
- parser.add_argument("--label_scp", type=str, required=True)
- parser.add_argument("--wav_scp", type=str, required=True)
- parser.add_argument("--utt2spk", type=str, required=True)
- parser.add_argument("--spk2meeting", type=str, required=True)
- parser.add_argument("--utt2xvec", type=str, required=True)
- parser.add_argument("--out_dir", type=str, required=True)
- parser.add_argument("--chunk_size", type=float, default=16)
- parser.add_argument("--chunk_shift", type=float, default=4)
- parser.add_argument("--frame_shift", type=float, default=0.01)
- parser.add_argument("--embedding_dim", type=int, default=None)
- parser.add_argument("--average_emb_num", type=int, default=0)
- parser.add_argument("--subset", type=int, default=0)
- parser.add_argument("--data_json", type=str, default=None)
- parser.add_argument("--seed", type=int, default=1234)
- parser.add_argument("--log_interval", type=int, default=100)
- args = parser.parse_args()
- random.seed(args.seed)
- np.random.seed(args.seed)
-
- logging.info("Loading data...")
- if not os.path.exists(args.data_json):
- label_list = load_scp_as_list(args.label_scp)
- wav_scp = load_scp_as_dict(args.wav_scp)
- utt2spk = load_scp_as_dict(args.utt2spk)
- utt2xvec = load_scp_as_dict(args.utt2xvec)
- spk2meeting = load_scp_as_dict(args.spk2meeting)
-
- meeting2spks = OrderedDict()
- for spk, meeting in spk2meeting.items():
- if meeting not in meeting2spks:
- meeting2spks[meeting] = []
- meeting2spks[meeting].append(spk)
-
- spk2utts = OrderedDict()
- for utt, spk in utt2spk.items():
- if spk not in spk2utts:
- spk2utts[spk] = []
- spk2utts[spk].append(utt)
-
- os.makedirs(os.path.dirname(args.data_json), exist_ok=True)
- logging.info("Dump data...")
- json.dump({
- "label_list": label_list, "wav_scp": wav_scp, "utt2xvec": utt2xvec,
- "spk2utts": spk2utts, "meeting2spks": meeting2spks
- }, open(args.data_json, "wt", encoding="utf-8"), ensure_ascii=False, indent=4)
- else:
- data_dict = json.load(open(args.data_json, "rt", encoding="utf-8"))
- label_list = data_dict["label_list"]
- wav_scp = data_dict["wav_scp"]
- utt2xvec = data_dict["utt2xvec"]
- spk2utts = data_dict["spk2utts"]
- meeting2spks = data_dict["meeting2spks"]
-
- if not os.path.exists(args.out_dir):
- os.makedirs(args.out_dir)
-
- args.chunk_size = int(args.chunk_size / args.frame_shift)
- args.chunk_shift = int(args.chunk_shift / args.frame_shift)
-
- if args.embedding_dim is None:
- args.embedding_dim = kaldiio.load_mat(next(iter(utt2xvec.values()))).shape[1]
- logging.info("Embedding dim is detected as {}.".format(args.embedding_dim))
-
- logging.info("Number utt: {}, Number speaker: {}, Number meetings: {}".format(
- len(wav_scp), len(spk2utts), len(meeting2spks)
- ))
- return label_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args
-
- def post(self, results_list, args):
- logging.info("[main]: Got {} chunks.".format(sum(results_list)))
-
-
-def simu_wav_chunk(spk, spk2utts, wav_scp, sample_length):
- utt_list = spk2utts[spk]
- wav_list = []
- cur_length = 0
- while cur_length < sample_length:
- uttid = random.choice(utt_list)
- wav, fs = soundfile.read(wav_scp[uttid], dtype='float32')
- wav_list.append(wav)
- cur_length += len(wav)
- concat_wav = np.concatenate(wav_list, axis=0)
- start = random.randint(0, len(concat_wav) - sample_length)
- return concat_wav[start: start+sample_length]
-
-
-def calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num):
- # process for dummy speaker
- if spk == "None":
- return np.zeros((1, embedding_dim), dtype=np.float32)
-
- # calculate averaged speaker embeddings
- utt_list = spk2utts[spk]
- if average_emb_num == 0 or average_emb_num > len(utt_list):
- xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in utt_list]
- else:
- xvec_list = [kaldiio.load_mat(utt2xvec[utt]) for utt in random.sample(utt_list, average_emb_num)]
- xvec = np.concatenate(xvec_list, axis=0)
- xvec = xvec / np.linalg.norm(xvec, axis=-1, keepdims=True)
- xvec = np.mean(xvec, axis=0)
-
- return xvec
-
-
-def simu_chunk(
- frame_label: np.ndarray,
- sample_label: np.ndarray,
- wav_scp: Dict[str, str],
- utt2xvec: Dict[str, str],
- spk2utts: Dict[str, List[str]],
- meeting2spks: Dict[str, List[str]],
- all_speaker_list: List[str],
- meeting_list: List[str],
- embedding_dim: int,
- average_emb_num: int,
-):
- frame_length, max_spk_num = frame_label.shape
- sample_length = sample_label.shape[0]
- positive_speaker_num = int(np.sum(frame_label.sum(axis=0) > 0))
- pos_speaker_list = deepcopy(meeting2spks[random.choice(meeting_list)])
-
- # get positive speakers
- if len(pos_speaker_list) >= positive_speaker_num:
- pos_speaker_list = random.sample(pos_speaker_list, positive_speaker_num)
- else:
- while len(pos_speaker_list) < positive_speaker_num:
- _spk = random.choice(all_speaker_list)
- if _spk not in pos_speaker_list:
- pos_speaker_list.append(_spk)
-
- # get negative speakers
- negative_speaker_num = random.randint(0, max_spk_num - positive_speaker_num)
- neg_speaker_list = []
- while len(neg_speaker_list) < negative_speaker_num:
- _spk = random.choice(all_speaker_list)
- if _spk not in pos_speaker_list and _spk not in neg_speaker_list:
- neg_speaker_list.append(_spk)
- neg_speaker_list.extend(["None"] * (max_spk_num - positive_speaker_num - negative_speaker_num))
-
- random.shuffle(pos_speaker_list)
- random.shuffle(neg_speaker_list)
- seperated_wav = np.zeros(sample_label.shape, dtype=np.float32)
- this_spk_list = []
- for idx, frame_num in enumerate(frame_label.sum(axis=0)):
- if frame_num > 0:
- spk = pos_speaker_list.pop(0)
- this_spk_list.append(spk)
- simu_spk_wav = simu_wav_chunk(spk, spk2utts, wav_scp, sample_length)
- seperated_wav[:, idx] = simu_spk_wav
- else:
- spk = neg_speaker_list.pop(0)
- this_spk_list.append(spk)
-
- # calculate mixed wav
- mixed_wav = np.sum(seperated_wav * sample_label, axis=1)
-
- # shuffle the order of speakers
- shuffle_idx = list(range(max_spk_num))
- random.shuffle(shuffle_idx)
- this_spk_list = [this_spk_list[x] for x in shuffle_idx]
- seperated_wav = seperated_wav.transpose()[shuffle_idx].transpose()
- frame_label = frame_label.transpose()[shuffle_idx].transpose()
-
- # calculate profile
- profile = [calculate_embedding(spk, spk2utts, utt2xvec, embedding_dim, average_emb_num)
- for spk in this_spk_list]
- profile = np.vstack(profile)
- # pse_weights = 2 ** np.arange(max_spk_num)
- # pse_label = np.sum(frame_label * pse_weights[np.newaxis, :], axis=1)
- # pse_label = pse_label.astype(str).tolist()
-
- return mixed_wav, seperated_wav, profile, frame_label
-
-
-def process(task_args):
- task_idx, task_list, (wav_scp, utt2xvec, spk2utts, meeting2spks), args = task_args
- logging.info("{:02d}/{:02d}: Start simulation...".format(task_idx+1, args.nj))
-
- out_path = os.path.join(args.out_dir, "wav_mix.{}".format(task_idx+1))
- wav_mix_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- # out_path = os.path.join(args.out_dir, "wav_sep.{}".format(task_idx + 1))
- # wav_sep_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- out_path = os.path.join(args.out_dir, "profile.{}".format(task_idx + 1))
- profile_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- out_path = os.path.join(args.out_dir, "frame_label.{}".format(task_idx + 1))
- label_writer = kaldiio.WriteHelper('ark,scp:{}.ark,{}.scp'.format(out_path, out_path))
-
- speaker_list, meeting_list = list(spk2utts.keys()), list(meeting2spks.keys())
-
- labels_list = []
- total_chunks = 0
- for org_mid, label_path in task_list:
- whole_label = kaldiio.load_mat(label_path)
- # random offset to keep diversity
- rand_shift = random.randint(0, args.chunk_shift)
- num_chunk = (whole_label.shape[0] - rand_shift - args.chunk_size) // args.chunk_shift + 1
- labels_list.append((org_mid, whole_label, rand_shift, num_chunk))
- total_chunks += num_chunk
-
- idx = 0
- simu_chunk_count = 0
- for org_mid, whole_label, rand_shift, num_chunk in labels_list:
- for i in range(num_chunk):
- idx = idx + 1
- st = i * args.chunk_shift + rand_shift
- ed = i * args.chunk_shift + args.chunk_size + rand_shift
- utt_id = "subset{}_part{}_{}_{:06d}_{:06d}".format(
- args.subset + 1, task_idx + 1, org_mid, st, ed
- )
- frame_label = whole_label[st: ed, :]
- sample_label = frame_label.repeat(int(args.sr * args.frame_shift), axis=0)
- mix_wav, seg_wav, profile, frame_label = simu_chunk(
- frame_label, sample_label, wav_scp, utt2xvec, spk2utts, meeting2spks,
- speaker_list, meeting_list, args.embedding_dim, args.average_emb_num
- )
- wav_mix_writer(utt_id, mix_wav)
- # wav_sep_writer(utt_id, seg_wav)
- profile_writer(utt_id, profile)
- label_writer(utt_id, frame_label)
-
- simu_chunk_count += 1
- if simu_chunk_count % args.log_interval == 0:
- logging.info("{:02d}/{:02d}: Complete {}/{} simulation, {}.".format(
- task_idx + 1, args.nj, simu_chunk_count, total_chunks, utt_id))
- wav_mix_writer.close()
- # wav_sep_writer.close()
- profile_writer.close()
- label_writer.close()
- logging.info("[{}/{}]: Simulate {} chunks.".format(task_idx+1, args.nj, simu_chunk_count))
- return simu_chunk_count
-
-
-if __name__ == '__main__':
- my_runner = MyRunner(process)
- my_runner.run()
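
In simu_chunk, per-slot waveforms are gated by the sample-level labels and summed into a single mixture channel. A toy sketch of that mixing step, with stand-in audio:

import numpy as np

# Hypothetical mixing: slot 0 speaks in samples 0-3, slot 1 in samples 2-7;
# constant-valued audio stands in for the simulated per-slot waveforms.
sample_label = np.zeros((8, 2), dtype=np.float32)
sample_label[0:4, 0] = 1.0
sample_label[2:8, 1] = 1.0
separated = np.ones((8, 2), dtype=np.float32)
mixed = np.sum(separated * sample_label, axis=1)
print(mixed)  # [1. 1. 2. 2. 1. 1. 1. 1.]: overlap regions sum to 2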
diff --git a/funasr/bin/asr_test.py b/funasr/bin/asr_test.py
new file mode 100644
index 0000000..389bd22
--- /dev/null
+++ b/funasr/bin/asr_test.py
@@ -0,0 +1,226 @@
+import argparse
+import logging
+import os
+import sys
+
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+
+
+def get_parser():
+ parser = config_argparse.ArgumentParser(
+ description="ASR Decoding",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+
+ # Note(kamo): Use '_' instead of '-' as separator.
+ # '-' is confusing if written in yaml.
+ parser.add_argument(
+ "--log_level",
+ type=lambda x: x.upper(),
+ default="INFO",
+ choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+ help="The verbose level of logging",
+ )
+
+ parser.add_argument("--output_dir", type=str, required=True)
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument(
+ "--njob",
+ type=int,
+ default=1,
+ help="The number of jobs for each gpu",
+ )
+ parser.add_argument(
+ "--gpuid_list",
+ type=str,
+ default="",
+ help="The visible gpus",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument(
+ "--dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type",
+ )
+ parser.add_argument(
+ "--num_workers",
+ type=int,
+ default=1,
+ help="The number of workers used for DataLoader",
+ )
+
+ group = parser.add_argument_group("Input data related")
+ group.add_argument(
+ "--data_path_and_name_and_type",
+ type=str2triple_str,
+ required=True,
+ action="append",
+ )
+ group.add_argument("--key_file", type=str_or_none)
+ group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+ group = parser.add_argument_group("The model configuration related")
+ group.add_argument(
+ "--vad_infer_config",
+ type=str,
+ help="VAD infer configuration",
+ )
+ group.add_argument(
+ "--vad_model_file",
+ type=str,
+ help="VAD model parameter file",
+ )
+ group.add_argument(
+ "--cmvn_file",
+ type=str,
+ help="Global CMVN file",
+ )
+ group.add_argument(
+ "--asr_train_config",
+ type=str,
+ help="ASR training configuration",
+ )
+ group.add_argument(
+ "--asr_model_file",
+ type=str,
+ help="ASR model parameter file",
+ )
+ group.add_argument(
+ "--lm_train_config",
+ type=str,
+ help="LM training configuration",
+ )
+ group.add_argument(
+ "--lm_file",
+ type=str,
+ help="LM parameter file",
+ )
+ group.add_argument(
+ "--word_lm_train_config",
+ type=str,
+ help="Word LM training configuration",
+ )
+ group.add_argument(
+ "--word_lm_file",
+ type=str,
+ help="Word LM parameter file",
+ )
+ group.add_argument(
+ "--ngram_file",
+ type=str,
+ help="N-gram parameter file",
+ )
+ group.add_argument(
+ "--model_tag",
+ type=str,
+ help="Pretrained model tag. If specify this option, *_train_config and "
+ "*_file will be overwritten",
+ )
+
+ group = parser.add_argument_group("Beam-search related")
+ group.add_argument(
+ "--batch_size",
+ type=int,
+ default=1,
+ help="The batch size for inference",
+ )
+ group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
+ group.add_argument("--beam_size", type=int, default=20, help="Beam size")
+ group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
+ group.add_argument(
+ "--maxlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain max output length. "
+ "If maxlenratio=0.0 (default), it uses a end-detect "
+ "function "
+ "to automatically find maximum hypothesis lengths."
+ "If maxlenratio<0.0, its absolute value is interpreted"
+ "as a constant max output length",
+ )
+ group.add_argument(
+ "--minlenratio",
+ type=float,
+ default=0.0,
+ help="Input length ratio to obtain min output length",
+ )
+ group.add_argument(
+ "--ctc_weight",
+ type=float,
+ default=0.0,
+ help="CTC weight in joint decoding",
+ )
+ group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
+ group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
+ group.add_argument("--streaming", type=str2bool, default=False)
+
+ group = parser.add_argument_group("Text converter related")
+ group.add_argument(
+ "--token_type",
+ type=str_or_none,
+ default=None,
+ choices=["char", "bpe", None],
+ help="The token type for ASR model. "
+ "If not given, refers from the training args",
+ )
+ group.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The model path of sentencepiece. "
+ "If not given, refers from the training args",
+ )
+ group.add_argument("--token_num_relax", type=int, default=1, help="")
+ group.add_argument("--decoding_ind", type=int, default=0, help="")
+ group.add_argument("--decoding_mode", type=str, default="model1", help="")
+ group.add_argument(
+ "--ctc_weight2",
+ type=float,
+ default=0.0,
+ help="CTC weight in joint decoding",
+ )
+ return parser
+
+
+def main(cmd=None):
+ print(get_commandline_args(), file=sys.stderr)
+ parser = get_parser()
+ parser.add_argument(
+ "--mode",
+ type=str,
+ default="asr",
+ help="The decoding mode",
+ )
+ args = parser.parse_args(cmd)
+ kwargs = vars(args)
+ kwargs.pop("config", None)
+
+ # set logging messages
+ logging.basicConfig(
+ level=args.log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.info("Decoding args: {}".format(kwargs))
+
+ # gpu setting
+ if args.ngpu > 0:
+ jobid = int(args.output_dir.split(".")[-1])
+ gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+
+ # inference_launch_funasr(**kwargs)
+
+
+if __name__ == "__main__":
+ main()
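
A hypothetical programmatic invocation of the new entry point; all paths below are placeholders. Note that the job id is parsed from the suffix of --output_dir ("decode.1" -> job 1), and the actual inference launch is still commented out in main():

from funasr.bin.asr_test import main

# Hypothetical arguments; paths are placeholders, not real files.
main(cmd=[
    "--output_dir", "exp/decode.1",
    "--ngpu", "1",
    "--gpuid_list", "0",
    "--data_path_and_name_and_type", "data/test/wav.scp,speech,sound",
    "--asr_train_config", "exp/asr/config.yaml",
    "--asr_model_file", "exp/asr/model.pb",
])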
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
deleted file mode 100644
index df3434f..0000000
--- a/funasr/bin/build_trainer.py
+++ /dev/null
@@ -1,151 +0,0 @@
-# -*- encoding: utf-8 -*-
-#!/usr/bin/env python3
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import os
-
-import yaml
-
-
-def update_dct(fin_configs, root):
- if root == {}:
- return {}
- for root_key, root_value in root.items():
- if not isinstance(root[root_key], dict):
- fin_configs[root_key] = root[root_key]
- else:
- if root_key in fin_configs.keys():
- result = update_dct(fin_configs[root_key], root[root_key])
- fin_configs[root_key] = result
- else:
- fin_configs[root_key] = root[root_key]
- return fin_configs
-
-
-def parse_args(mode):
- if mode == "asr":
- from funasr.tasks.asr import ASRTask as ASRTask
- elif mode == "paraformer":
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- elif mode == "paraformer_vad_punc":
- from funasr.tasks.asr import ASRTaskParaformer as ASRTask
- elif mode == "uniasr":
- from funasr.tasks.asr import ASRTaskUniASR as ASRTask
- elif mode == "mfcca":
- from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
- elif mode == "tp":
- from funasr.tasks.asr import ASRTaskAligner as ASRTask
- else:
- raise ValueError("Unknown mode: {}".format(mode))
- parser = ASRTask.get_parser()
- args = parser.parse_args()
- return args, ASRTask
-
-
-def build_trainer(modelscope_dict,
- data_dir,
- output_dir,
- train_set="train",
- dev_set="validation",
- distributed=False,
- dataset_type="small",
- batch_bins=None,
- max_epoch=None,
- optim=None,
- lr=None,
- scheduler=None,
- scheduler_conf=None,
- specaug=None,
- specaug_conf=None,
- param_dict=None,
- **kwargs):
- mode = modelscope_dict['mode']
- args, ASRTask = parse_args(mode=mode)
- # ddp related
- if args.local_rank is not None:
- distributed = True
- else:
- distributed = False
- args.local_rank = args.local_rank if args.local_rank is not None else 0
- local_rank = args.local_rank
- if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
- gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
- os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[args.local_rank])
- else:
- os.environ['CUDA_VISIBLE_DEVICES'] = str(args.local_rank)
-
- config = modelscope_dict['am_model_config']
- finetune_config = modelscope_dict['finetune_config']
- init_param = modelscope_dict['init_model']
- cmvn_file = modelscope_dict['cmvn_file']
- seg_dict_file = modelscope_dict['seg_dict']
-
- # overwrite parameters
- with open(config) as f:
- configs = yaml.safe_load(f)
- with open(finetune_config) as f:
- finetune_configs = yaml.safe_load(f)
- # set data_types
- if dataset_type == "large":
- if 'data_types' not in finetune_configs['dataset_conf']:
- finetune_configs["dataset_conf"]["data_types"] = "sound,text"
- finetune_configs = update_dct(configs, finetune_configs)
- for key, value in finetune_configs.items():
- if hasattr(args, key):
- setattr(args, key, value)
-
- # prepare data
- args.dataset_type = dataset_type
- if args.dataset_type == "small":
- args.train_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, train_set), "speech", "sound"],
- ["{}/{}/text".format(data_dir, train_set), "text", "text"]]
- args.valid_data_path_and_name_and_type = [["{}/{}/wav.scp".format(data_dir, dev_set), "speech", "sound"],
- ["{}/{}/text".format(data_dir, dev_set), "text", "text"]]
- elif args.dataset_type == "large":
- args.train_data_file = None
- args.valid_data_file = None
- else:
- raise ValueError(f"Not supported dataset_type={args.dataset_type}")
- args.init_param = [init_param]
- args.cmvn_file = cmvn_file
- if os.path.exists(seg_dict_file):
- args.seg_dict_file = seg_dict_file
- else:
- args.seg_dict_file = None
- args.data_dir = data_dir
- args.train_set = train_set
- args.dev_set = dev_set
- args.output_dir = output_dir
- args.gpu_id = args.local_rank
- args.config = finetune_config
- if optim is not None:
- args.optim = optim
- if lr is not None:
- args.optim_conf["lr"] = lr
- if scheduler is not None:
- args.scheduler = scheduler
- if scheduler_conf is not None:
- args.scheduler_conf = scheduler_conf
- if specaug is not None:
- args.specaug = specaug
- if specaug_conf is not None:
- args.specaug_conf = specaug_conf
- if max_epoch is not None:
- args.max_epoch = max_epoch
- if batch_bins is not None:
- if args.dataset_type == "small":
- args.batch_bins = batch_bins
- elif args.dataset_type == "large":
- args.dataset_conf["batch_conf"]["batch_size"] = batch_bins
- else:
- raise ValueError(f"Not supported dataset_type={args.dataset_type}")
- if args.normalize in ["null", "none", "None"]:
- args.normalize = None
- if args.patience in ["null", "none", "None"]:
- args.patience = None
- args.local_rank = local_rank
- args.distributed = distributed
- ASRTask.finetune_args = args
-
- return ASRTask
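
The removed build_trainer.py relied on update_dct for a recursive config merge. Note that, as written, update_dct returns {} when the override dict is empty, which can wipe a nested section; a simplified equivalent that merges instead, with toy configs:

# Simplified recursive config merge (hypothetical dicts): nested keys are
# merged, scalar values are overwritten by the override.
def merge_config(base, override):
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(base.get(key), dict):
            base[key] = merge_config(base[key], value)
        else:
            base[key] = value
    return base

cfg = {"optim_conf": {"lr": 0.001, "betas": [0.9, 0.98]}, "max_epoch": 40}
print(merge_config(cfg, {"optim_conf": {"lr": 0.0005}, "max_epoch": 20}))
# {'optim_conf': {'lr': 0.0005, 'betas': [0.9, 0.98]}, 'max_epoch': 20}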
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
new file mode 100755
index 0000000..53e5bde
--- /dev/null
+++ b/funasr/bin/train.py
@@ -0,0 +1,572 @@
+#!/usr/bin/env python3
+
+import argparse
+import logging
+import os
+import sys
+from io import BytesIO
+
+import torch
+
+from funasr.build_utils.build_args import build_args
+from funasr.build_utils.build_dataloader import build_dataloader
+from funasr.build_utils.build_distributed import build_distributed
+from funasr.build_utils.build_model import build_model
+from funasr.build_utils.build_optimizer import build_optimizer
+from funasr.build_utils.build_scheduler import build_scheduler
+from funasr.build_utils.build_trainer import build_trainer
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.load_pretrained_model import load_pretrained_model
+from funasr.torch_utils.model_summary import model_summary
+from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.prepare_data import prepare_data
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ description="FunASR Common Training Parser",
+ )
+
+ # common configuration
+ parser.add_argument("--output_dir", help="model save path")
+ parser.add_argument(
+ "--ngpu",
+ type=int,
+ default=0,
+ help="The number of gpus. 0 indicates CPU mode",
+ )
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
+ parser.add_argument("--task_name", type=str, default="asr", help="Name for different tasks")
+
+ # ddp related
+ parser.add_argument(
+ "--dist_backend",
+ default="nccl",
+ type=str,
+ help="distributed backend",
+ )
+ parser.add_argument(
+ "--dist_init_method",
+ type=str,
+ default="env://",
+ help='if init_method="env://", env values of "MASTER_PORT", "MASTER_ADDR", '
+ '"WORLD_SIZE", and "RANK" are referred.',
+ )
+ parser.add_argument(
+ "--dist_world_size",
+ type=int,
+ default=1,
+ help="number of nodes for distributed training",
+ )
+ parser.add_argument(
+ "--dist_rank",
+ type=int,
+ default=None,
+ help="node rank for distributed training",
+ )
+ parser.add_argument(
+ "--local_rank",
+ type=int,
+ default=None,
+ help="local rank for distributed training",
+ )
+ parser.add_argument(
+ "--dist_master_addr",
+ default=None,
+ type=str_or_none,
+ help="The master address for distributed training. "
+ "This value is used when dist_init_method == 'env://'",
+ )
+ parser.add_argument(
+ "--dist_master_port",
+ default=None,
+ type=int_or_none,
+ help="The master port for distributed training"
+ "This value is used when dist_init_method == 'env://'",
+ )
+ parser.add_argument(
+ "--dist_launcher",
+ default=None,
+ type=str_or_none,
+ choices=["slurm", "mpi", None],
+ help="The launcher type for distributed training",
+ )
+ parser.add_argument(
+ "--multiprocessing_distributed",
+ default=True,
+ type=str2bool,
+ help="Use multi-processing distributed training to launch "
+ "N processes per node, which has N GPUs. This is the "
+ "fastest way to use PyTorch for either single node or "
+ "multi node data parallel training",
+ )
+ parser.add_argument(
+ "--unused_parameters",
+ type=str2bool,
+ default=False,
+ help="Whether to use the find_unused_parameters in "
+ "torch.nn.parallel.DistributedDataParallel ",
+ )
+ parser.add_argument(
+ "--gpu_id",
+ type=int,
+ default=0,
+ help="local gpu id.",
+ )
+
+ # cudnn related
+ parser.add_argument(
+ "--cudnn_enabled",
+ type=str2bool,
+ default=torch.backends.cudnn.enabled,
+ help="Enable CUDNN",
+ )
+ parser.add_argument(
+ "--cudnn_benchmark",
+ type=str2bool,
+ default=torch.backends.cudnn.benchmark,
+ help="Enable cudnn-benchmark mode",
+ )
+ parser.add_argument(
+ "--cudnn_deterministic",
+ type=str2bool,
+ default=True,
+ help="Enable cudnn-deterministic mode",
+ )
+
+ # trainer related
+ parser.add_argument(
+ "--max_epoch",
+ type=int,
+ default=40,
+ help="The maximum number epoch to train",
+ )
+ parser.add_argument(
+ "--max_update",
+ type=int,
+ default=sys.maxsize,
+ help="The maximum number update step to train",
+ )
+ parser.add_argument(
+ "--batch_interval",
+ type=int,
+ default=10000,
+ help="The batch interval for saving model.",
+ )
+ parser.add_argument(
+ "--patience",
+ type=int_or_none,
+ default=None,
+ help="Number of epochs to wait without improvement "
+ "before stopping the training",
+ )
+ parser.add_argument(
+ "--val_scheduler_criterion",
+ type=str,
+ nargs=2,
+ default=("valid", "loss"),
+ help="The criterion used for the value given to the lr scheduler. "
+ 'Give a pair referring to the phase, "train" or "valid", '
+ 'and the criterion name. The mode specifying "min" or "max" can '
+ "be changed by --scheduler_conf",
+ )
+ parser.add_argument(
+ "--early_stopping_criterion",
+ type=str,
+ nargs=3,
+ default=("valid", "loss", "min"),
+ help="The criterion used for judging of early stopping. "
+ 'Give a pair referring the phase, "train" or "valid",'
+ 'the criterion name and the mode, "min" or "max", e.g. "acc,max".',
+ )
+ parser.add_argument(
+ "--best_model_criterion",
+ nargs="+",
+ default=[
+ ("train", "loss", "min"),
+ ("valid", "loss", "min"),
+ ("train", "acc", "max"),
+ ("valid", "acc", "max"),
+ ],
+ help="The criterion used for judging of the best model. "
+ 'Give a pair referring the phase, "train" or "valid",'
+ 'the criterion name, and the mode, "min" or "max", e.g. "acc,max".',
+ )
+ parser.add_argument(
+ "--keep_nbest_models",
+ type=int,
+ nargs="+",
+ default=[10],
+ help="Remove previous snapshots excluding the n-best scored epochs",
+ )
+ parser.add_argument(
+ "--nbest_averaging_interval",
+ type=int,
+ default=0,
+ help="The epoch interval to apply model averaging and save nbest models",
+ )
+ parser.add_argument(
+ "--grad_clip",
+ type=float,
+ default=5.0,
+ help="Gradient norm threshold to clip",
+ )
+ parser.add_argument(
+ "--grad_clip_type",
+ type=float,
+ default=2.0,
+ help="The type of the used p-norm for gradient clip. Can be inf",
+ )
+ parser.add_argument(
+ "--grad_noise",
+ type=str2bool,
+ default=False,
+ help="The flag to switch to use noise injection to "
+ "gradients during training",
+ )
+ parser.add_argument(
+ "--accum_grad",
+ type=int,
+ default=1,
+ help="The number of gradient accumulation",
+ )
+ parser.add_argument(
+ "--resume",
+ type=str2bool,
+ default=False,
+ help="Enable resuming if checkpoint is existing",
+ )
+ parser.add_argument(
+ "--train_dtype",
+ default="float32",
+ choices=["float16", "float32", "float64"],
+ help="Data type for training.",
+ )
+ parser.add_argument(
+ "--use_amp",
+ type=str2bool,
+ default=False,
+ help="Enable Automatic Mixed Precision. This feature requires pytorch>=1.6",
+ )
+ parser.add_argument(
+ "--log_interval",
+ default=None,
+ help="Show the logs every the number iterations in each epochs at the "
+ "training phase. If None is given, it is decided according the number "
+ "of training samples automatically .",
+ )
+ parser.add_argument(
+ "--use_tensorboard",
+ type=str2bool,
+ default=True,
+ help="Enable tensorboard logging",
+ )
+
+ # pretrained model related
+ parser.add_argument(
+ "--init_param",
+ type=str,
+ default=[],
+ nargs="*",
+ help="Specify the file path used for initialization of parameters. "
+ "The format is '<file_path>:<src_key>:<dst_key>:<exclude_keys>', "
+ "where file_path is the model file path, "
+ "src_key specifies the key of model states to be used in the model file, "
+ "dst_key specifies the attribute of the model to be initialized, "
+ "and exclude_keys excludes keys of model states for the initialization."
+ "e.g.\n"
+ " # Load all parameters"
+ " --init_param some/where/model.pb\n"
+ " # Load only decoder parameters"
+ " --init_param some/where/model.pb:decoder:decoder\n"
+ " # Load only decoder parameters excluding decoder.embed"
+ " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n"
+ " --init_param some/where/model.pb:decoder:decoder:decoder.embed\n",
+ )
+ parser.add_argument(
+ "--ignore_init_mismatch",
+ type=str2bool,
+ default=False,
+ help="Ignore size mismatch when loading pre-trained model",
+ )
+ parser.add_argument(
+ "--freeze_param",
+ type=str,
+ default=[],
+ nargs="*",
+ help="Freeze parameters",
+ )
+
+ # dataset related
+ parser.add_argument(
+ "--dataset_type",
+ type=str,
+ default="small",
+ help="whether to use dataloader for large dataset",
+ )
+ parser.add_argument(
+ "--dataset_conf",
+ action=NestedDictAction,
+ default=dict(),
+ help=f"The keyword arguments for dataset",
+ )
+ parser.add_argument(
+ "--data_dir",
+ type=str,
+ default=None,
+ help="root path of data",
+ )
+ parser.add_argument(
+ "--train_set",
+ type=str,
+ default="train",
+ help="train dataset",
+ )
+ parser.add_argument(
+ "--valid_set",
+ type=str,
+ default="validation",
+ help="dev dataset",
+ )
+ parser.add_argument(
+ "--speed_perturb",
+ type=float,
+ nargs="+",
+ default=None,
+ help="speed perturb",
+ )
+ parser.add_argument(
+ "--use_preprocessor",
+ type=str2bool,
+ default=True,
+ help="Apply preprocessing to data or not",
+ )
+ parser.add_argument(
+ "--embed_path",
+ type=str,
+ default=None,
+ help="for model which requires embeds",
+ )
+
+ # optimization related
+ parser.add_argument(
+ "--optim",
+ type=lambda x: x.lower(),
+ default="adam",
+ help="The optimizer type",
+ )
+ parser.add_argument(
+ "--optim_conf",
+ action=NestedDictAction,
+ default=dict(),
+ help="The keyword arguments for optimizer",
+ )
+ parser.add_argument(
+ "--scheduler",
+ type=lambda x: str_or_none(x.lower()),
+ default=None,
+ help="The lr scheduler type",
+ )
+ parser.add_argument(
+ "--scheduler_conf",
+ action=NestedDictAction,
+ default=dict(),
+ help="The keyword arguments for lr scheduler",
+ )
+
+ # most task related
+ parser.add_argument(
+ "--init",
+ type=lambda x: str_or_none(x.lower()),
+ default=None,
+ help="The initialization method",
+ choices=[
+ "chainer",
+ "xavier_uniform",
+ "xavier_normal",
+ "kaiming_uniform",
+ "kaiming_normal",
+ None,
+ ],
+ )
+ parser.add_argument(
+ "--token_list",
+ type=str_or_none,
+ default=None,
+ help="A text mapping int-id to token",
+ )
+ parser.add_argument(
+ "--token_type",
+ type=str,
+ default="bpe",
+ choices=["bpe", "char", "word"],
+ help="",
+ )
+ parser.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The model file fo sentencepiece",
+ )
+ parser.add_argument(
+ "--cleaner",
+ type=str_or_none,
+ choices=[None, "tacotron", "jaconv", "vietnamese"],
+ default=None,
+ help="Apply text cleaning",
+ )
+ parser.add_argument(
+ "--g2p",
+ type=str_or_none,
+ choices=g2p_choices,
+ default=None,
+ help="Specify g2p method if --token_type=phn",
+ )
+
+ # pai related
+ parser.add_argument(
+ "--use_pai",
+ type=str2bool,
+ default=False,
+ help="flag to indicate whether training on PAI",
+ )
+ parser.add_argument(
+ "--simple_ddp",
+ type=str2bool,
+ default=False,
+ )
+ parser.add_argument(
+ "--num_worker_count",
+ type=int,
+ default=1,
+ help="The number of machines on PAI.",
+ )
+ parser.add_argument(
+ "--access_key_id",
+ type=str,
+ default=None,
+ help="The username for oss.",
+ )
+ parser.add_argument(
+ "--access_key_secret",
+ type=str,
+ default=None,
+ help="The password for oss.",
+ )
+ parser.add_argument(
+ "--endpoint",
+ type=str,
+ default=None,
+ help="The endpoint for oss.",
+ )
+ parser.add_argument(
+ "--bucket_name",
+ type=str,
+ default=None,
+ help="The bucket name for oss.",
+ )
+ parser.add_argument(
+ "--oss_bucket",
+ default=None,
+ help="oss bucket.",
+ )
+
+ return parser
+
+
+if __name__ == '__main__':
+ parser = get_parser()
+ args, extra_task_params = parser.parse_known_args()
+ if extra_task_params:
+ args = build_args(args, parser, extra_task_params)
+
+ # set random seed
+ set_all_random_seed(args.seed)
+ torch.backends.cudnn.enabled = args.cudnn_enabled
+ torch.backends.cudnn.benchmark = args.cudnn_benchmark
+ torch.backends.cudnn.deterministic = args.cudnn_deterministic
+
+ # ddp init
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+ args.distributed = args.ngpu > 1 or args.dist_world_size > 1
+ distributed_option = build_distributed(args)
+
+ # for logging
+ if not distributed_option.distributed or distributed_option.dist_rank == 0:
+ logging.basicConfig(
+ level="INFO",
+ format=f"[{os.uname()[1].split('.')[0]}]"
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ else:
+ logging.basicConfig(
+ level="ERROR",
+ format=f"[{os.uname()[1].split('.')[0]}]"
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ # prepare files for dataloader
+ prepare_data(args, distributed_option)
+
+ model = build_model(args)
+ model = model.to(
+ dtype=getattr(torch, args.train_dtype),
+ device="cuda" if args.ngpu > 0 else "cpu",
+ )
+ optimizers = build_optimizer(args, model=model)
+ schedulers = build_scheduler(args, optimizers)
+
+ logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+ distributed_option.dist_rank,
+ distributed_option.local_rank))
+ logging.info(pytorch_cudnn_version())
+ logging.info("Args: {}".format(args))
+ logging.info(model_summary(model))
+ logging.info("Optimizer: {}".format(optimizers))
+ logging.info("Scheduler: {}".format(schedulers))
+
+ # dump args to config.yaml
+ if not distributed_option.distributed or distributed_option.dist_rank == 0:
+ os.makedirs(args.output_dir, exist_ok=True)
+ with open(os.path.join(args.output_dir, "config.yaml"), "w") as f:
+ logging.info("Saving the configuration in {}/{}".format(args.output_dir, "config.yaml"))
+ if args.use_pai:
+ buffer = BytesIO()
+ torch.save({"config": vars(args)}, buffer)
+ args.oss_bucket.put_object(os.path.join(args.output_dir, "config.dict"), buffer.getvalue())
+ else:
+ yaml_no_alias_safe_dump(vars(args), f, indent=4, sort_keys=False)
+
+ for p in args.init_param:
+ logging.info(f"Loading pretrained params from {p}")
+ load_pretrained_model(
+ model=model,
+ init_param=p,
+ ignore_init_mismatch=args.ignore_init_mismatch,
+ map_location=f"cuda:{torch.cuda.current_device()}"
+ if args.ngpu > 0
+ else "cpu",
+ oss_bucket=args.oss_bucket,
+ )
+
+ # dataloader for training/validation
+ train_dataloader, valid_dataloader = build_dataloader(args)
+
+ # Trainer, including model, optimizers, etc.
+ trainer = build_trainer(
+ args=args,
+ model=model,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ train_dataloader=train_dataloader,
+ valid_dataloader=valid_dataloader,
+ distributed_option=distributed_option
+ )
+
+ trainer.run()
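
A hypothetical single-GPU launch of the refactored trainer; the paths, set names, and the --encoder choice are placeholders, and the funasr package is assumed to be importable. Task-specific flags such as --encoder are the "extra" params consumed by build_args:

import subprocess

# Hypothetical launch; --encoder is routed to the task parser via
# parse_known_args and build_args.
subprocess.run([
    "python", "-m", "funasr.bin.train",
    "--task_name", "asr",
    "--output_dir", "exp/asr_train",
    "--ngpu", "1",
    "--data_dir", "data",
    "--train_set", "train",
    "--valid_set", "validation",
    "--encoder", "conformer",
], check=True)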
diff --git a/funasr/build_utils/__init__.py b/funasr/build_utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/build_utils/__init__.py
diff --git a/funasr/build_utils/build_args.py b/funasr/build_utils/build_args.py
new file mode 100644
index 0000000..517c85b
--- /dev/null
+++ b/funasr/build_utils/build_args.py
@@ -0,0 +1,93 @@
+from funasr.models.ctc import CTC
+from funasr.utils import config_argparse
+from funasr.utils.get_default_kwargs import get_default_kwargs
+from funasr.utils.nested_dict_action import NestedDictAction
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+
+def build_args(args, parser, extra_task_params):
+ task_parser = config_argparse.ArgumentParser("Task related config")
+ if args.task_name == "asr":
+ from funasr.build_utils.build_asr_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+ task_parser.add_argument(
+ "--split_with_space",
+ type=str2bool,
+ default=True,
+ help="whether to split text using <space>",
+ )
+ task_parser.add_argument(
+ "--seg_dict_file",
+ type=str,
+ default=None,
+ help="seg_dict_file for text processing",
+ )
+ task_parser.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+ task_parser.add_argument(
+ "--ctc_conf",
+ action=NestedDictAction,
+ default=get_default_kwargs(CTC),
+ help="The keyword arguments for CTC class.",
+ )
+ task_parser.add_argument(
+ "--cmvn_file",
+ type=str_or_none,
+ default=None,
+ help="The file path of noise scp file.",
+ )
+
+ elif args.task_name == "pretrain":
+ from funasr.build_utils.build_pretrain_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+ task_parser.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+
+ elif args.task_name == "lm":
+ from funasr.build_utils.build_lm_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+
+ elif args.task_name == "punc":
+ from funasr.build_utils.build_punc_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+
+ elif args.task_name == "vad":
+ from funasr.build_utils.build_vad_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+ task_parser.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+
+ elif args.task_name == "diar":
+ from funasr.build_utils.build_diar_model import class_choices_list
+ for class_choices in class_choices_list:
+ class_choices.add_arguments(task_parser)
+
+ else:
+ raise NotImplementedError("Not supported task: {}".format(args.task_name))
+
+ for action in parser._actions:
+ if not any(action.dest == a.dest for a in task_parser._actions):
+ task_parser._add_action(action)
+
+ task_parser.set_defaults(**vars(args))
+ task_args = task_parser.parse_args(extra_task_params)
+ return task_args
diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
new file mode 100644
index 0000000..d8cbba5
--- /dev/null
+++ b/funasr/build_utils/build_asr_model.py
@@ -0,0 +1,352 @@
+import logging
+
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.ctc import CTC
+from funasr.models.decoder.abs_decoder import AbsDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.rnn_decoder import RNNDecoder
+from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder, FsmnDecoderSCAMAOpt
+from funasr.models.decoder.transformer_decoder import (
+ DynamicConvolution2DTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
+from funasr.models.decoder.transformer_decoder import (
+ LightweightConvolution2DTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import (
+ LightweightConvolutionTransformerDecoder, # noqa: H301
+)
+from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
+from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.e2e_asr import ASRModel
+from funasr.models.e2e_asr_mfcca import MFCCA
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_tp import TimestampPredictor
+from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.mfcca_encoder import MFCCAEncoder
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.default import MultiChannelFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.predictor.cif import CifPredictor, CifPredictorV2, CifPredictorV3
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.subsampling import Conv1dSubsampling
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(
+ default=DefaultFrontend,
+ sliding_window=SlidingWindow,
+ s3prl=S3prlFrontend,
+ fused=FusedFrontends,
+ wav_frontend=WavFrontend,
+ multichannelfrontend=MultiChannelFrontend,
+ ),
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(
+ specaug=SpecAug,
+ specaug_lfr=SpecAugLFR,
+ ),
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ default=None,
+ optional=True,
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ asr=ASRModel,
+ uniasr=UniASR,
+ paraformer=Paraformer,
+ paraformer_bert=ParaformerBert,
+ bicif_paraformer=BiCifParaformer,
+ contextual_paraformer=ContextualParaformer,
+ mfcca=MFCCA,
+ timestamp_prediction=TimestampPredictor,
+ ),
+ default="asr",
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ mfcca_enc=MFCCAEncoder,
+ ),
+ default="rnn",
+)
+encoder_choices2 = ClassChoices(
+ "encoder2",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ ),
+ default="rnn",
+)
+decoder_choices = ClassChoices(
+ "decoder",
+ classes=dict(
+ transformer=TransformerDecoder,
+ lightweight_conv=LightweightConvolutionTransformerDecoder,
+ lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+ dynamic_conv=DynamicConvolutionTransformerDecoder,
+ dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+ rnn=RNNDecoder,
+ fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+ paraformer_decoder_sanm=ParaformerSANMDecoder,
+ paraformer_decoder_san=ParaformerDecoderSAN,
+ contextual_paraformer_decoder=ContextualParaformerDecoder,
+ ),
+ default="rnn",
+)
+decoder_choices2 = ClassChoices(
+ "decoder2",
+ classes=dict(
+ transformer=TransformerDecoder,
+ lightweight_conv=LightweightConvolutionTransformerDecoder,
+ lightweight_conv2d=LightweightConvolution2DTransformerDecoder,
+ dynamic_conv=DynamicConvolutionTransformerDecoder,
+ dynamic_conv2d=DynamicConvolution2DTransformerDecoder,
+ rnn=RNNDecoder,
+ fsmn_scama_opt=FsmnDecoderSCAMAOpt,
+ paraformer_decoder_sanm=ParaformerSANMDecoder,
+ ),
+ type_check=AbsDecoder,
+ default="rnn",
+)
+predictor_choices = ClassChoices(
+ name="predictor",
+ classes=dict(
+ cif_predictor=CifPredictor,
+ ctc_predictor=None,
+ cif_predictor_v2=CifPredictorV2,
+ cif_predictor_v3=CifPredictorV3,
+ ),
+ default="cif_predictor",
+ optional=True,
+)
+predictor_choices2 = ClassChoices(
+ name="predictor2",
+ classes=dict(
+ cif_predictor=CifPredictor,
+ ctc_predictor=None,
+ cif_predictor_v2=CifPredictorV2,
+ ),
+ default="cif_predictor",
+ optional=True,
+)
+stride_conv_choices = ClassChoices(
+ name="stride_conv",
+ classes=dict(
+ stride_conv1d=Conv1dSubsampling
+ ),
+ default="stride_conv1d",
+ optional=True,
+)
+class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --model and --model_conf
+ model_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+ # --predictor and --predictor_conf
+ predictor_choices,
+ # --encoder2 and --encoder2_conf
+ encoder_choices2,
+ # --decoder2 and --decoder2_conf
+ decoder_choices2,
+ # --predictor2 and --predictor2_conf
+ predictor_choices2,
+ # --stride_conv and --stride_conv_conf
+ stride_conv_choices,
+]
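+ # Illustrative CLI mapping (a sketch, not an exhaustive reference): each entry in
+ # class_choices_list adds a "--<name>" option selecting the class and a
+ # "--<name>_conf" option carrying its constructor kwargs, e.g.
+ #   --encoder conformer --encoder_conf '{"num_blocks": 12}'
+ # The accepted *_conf keys depend on the constructor of the chosen class.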
+
+
+def build_asr_model(args):
+ # token_list
+ if args.token_list is not None:
+ with open(args.token_list) as f:
+ token_list = [line.rstrip() for line in f]
+ args.token_list = list(token_list)
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+ else:
+ vocab_size = None
+
+ # frontend
+ if args.input_size is None:
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend':
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+ # decoder
+ decoder_class = decoder_choices.get_class(args.decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder.output_size(),
+ **args.decoder_conf,
+ )
+
+ # ctc
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder.output_size(), **args.ctc_conf
+ )
+
+ if args.model in ["asr", "mfcca"]:
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ **args.model_conf,
+ )
+ elif args.model in ["paraformer", "paraformer_bert", "bicif_paraformer", "contextual_paraformer"]:
+ # predictor
+ predictor_class = predictor_choices.get_class(args.predictor)
+ predictor = predictor_class(**args.predictor_conf)
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ predictor=predictor,
+ **args.model_conf,
+ )
+ elif args.model == "uniasr":
+ # stride_conv
+ stride_conv_class = stride_conv_choices.get_class(args.stride_conv)
+ stride_conv = stride_conv_class(**args.stride_conv_conf, idim=input_size + encoder.output_size(),
+ odim=input_size + encoder.output_size())
+ stride_conv_output_size = stride_conv.output_size()
+
+ # encoder2
+ encoder_class2 = encoder_choices2.get_class(args.encoder2)
+ encoder2 = encoder_class2(input_size=stride_conv_output_size, **args.encoder2_conf)
+
+ # decoder2
+ decoder_class2 = decoder_choices2.get_class(args.decoder2)
+ decoder2 = decoder_class2(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder2.output_size(),
+ **args.decoder2_conf,
+ )
+
+ # ctc2
+ ctc2 = CTC(
+ odim=vocab_size, encoder_output_size=encoder2.output_size(), **args.ctc_conf
+ )
+
+ # predictor
+ predictor_class = predictor_choices.get_class(args.predictor)
+ predictor = predictor_class(**args.predictor_conf)
+
+ # predictor2
+ predictor_class = predictor_choices2.get_class(args.predictor2)
+ predictor2 = predictor_class(**args.predictor2_conf)
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ decoder=decoder,
+ ctc=ctc,
+ token_list=token_list,
+ predictor=predictor,
+ ctc2=ctc2,
+ encoder2=encoder2,
+ decoder2=decoder2,
+ predictor2=predictor2,
+ stride_conv=stride_conv,
+ **args.model_conf,
+ )
+ elif args.model == "timestamp_prediction":
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ frontend=frontend,
+ encoder=encoder,
+ token_list=token_list,
+ **args.model_conf,
+ )
+ else:
+ raise NotImplementedError("Not supported model: {}".format(args.model))
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
\ No newline at end of file
diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
new file mode 100644
index 0000000..c95c40d
--- /dev/null
+++ b/funasr/build_utils/build_dataloader.py
@@ -0,0 +1,15 @@
+from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
+from funasr.datasets.small_datasets.sequence_iter_factory import SequenceIterFactory
+
+
+def build_dataloader(args):
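+ # Dispatch on args.dataset_type: "small" builds an epoch-wise SequenceIterFactory
+ # per split, while "large" builds a streaming LargeDataLoader. Either object is
+ # later consumed via build_iter(epoch) in the trainer (see build_trainer.py).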
+ if args.dataset_type == "small":
+ train_iter_factory = SequenceIterFactory(args, mode="train")
+ valid_iter_factory = SequenceIterFactory(args, mode="valid")
+ elif args.dataset_type == "large":
+ train_iter_factory = LargeDataLoader(args, mode="train")
+ valid_iter_factory = LargeDataLoader(args, mode="valid")
+ else:
+ raise ValueError(f"Not supported dataset_type={args.dataset_type}")
+
+ return train_iter_factory, valid_iter_factory
diff --git a/funasr/build_utils/build_diar_model.py b/funasr/build_utils/build_diar_model.py
new file mode 100644
index 0000000..6406404
--- /dev/null
+++ b/funasr/build_utils/build_diar_model.py
@@ -0,0 +1,296 @@
+import logging
+
+import torch
+
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.label_aggregation import LabelAggregate
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
+from funasr.models.e2e_diar_sond import DiarSondModel
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.encoder.ecapa_tdnn_encoder import ECAPA_TDNN
+from funasr.models.encoder.opennmt_encoders.ci_scorers import DotScorer, CosScorer
+from funasr.models.encoder.opennmt_encoders.conv_encoder import ConvEncoder
+from funasr.models.encoder.opennmt_encoders.fsmn_encoder import FsmnEncoder
+from funasr.models.encoder.opennmt_encoders.self_attention_encoder import SelfAttentionEncoder
+from funasr.models.encoder.resnet34_encoder import ResNet34Diar, ResNet34SpL2RegDiar
+from funasr.models.encoder.rnn_encoder import RNNEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.modules.eend_ola.encoder import EENDOLATransformerEncoder
+from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(
+ default=DefaultFrontend,
+ sliding_window=SlidingWindow,
+ s3prl=S3prlFrontend,
+ fused=FusedFrontends,
+ wav_frontend=WavFrontend,
+ wav_frontend_mel23=WavFrontendMel23,
+ ),
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(
+ specaug=SpecAug,
+ specaug_lfr=SpecAugLFR,
+ ),
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ default=None,
+ optional=True,
+)
+label_aggregator_choices = ClassChoices(
+ "label_aggregator",
+ classes=dict(
+ label_aggregator=LabelAggregate
+ ),
+ default=None,
+ optional=True,
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ sond=DiarSondModel,
+ eend_ola=DiarEENDOLAModel,
+ ),
+ default="sond",
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ san=SelfAttentionEncoder,
+ fsmn=FsmnEncoder,
+ conv=ConvEncoder,
+ resnet34=ResNet34Diar,
+ resnet34_sp_l2reg=ResNet34SpL2RegDiar,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ ecapa_tdnn=ECAPA_TDNN,
+ eend_ola_transformer=EENDOLATransformerEncoder,
+ ),
+ default="resnet34",
+)
+speaker_encoder_choices = ClassChoices(
+ "speaker_encoder",
+ classes=dict(
+ conformer=ConformerEncoder,
+ transformer=TransformerEncoder,
+ rnn=RNNEncoder,
+ sanm=SANMEncoder,
+ san=SelfAttentionEncoder,
+ fsmn=FsmnEncoder,
+ conv=ConvEncoder,
+ sanm_chunk_opt=SANMEncoderChunkOpt,
+ data2vec_encoder=Data2VecEncoder,
+ ),
+ default=None,
+ optional=True
+)
+cd_scorer_choices = ClassChoices(
+ "cd_scorer",
+ classes=dict(
+ san=SelfAttentionEncoder,
+ ),
+ default=None,
+ optional=True,
+)
+ci_scorer_choices = ClassChoices(
+ "ci_scorer",
+ classes=dict(
+ dot=DotScorer,
+ cosine=CosScorer,
+ conv=ConvEncoder,
+ ),
+ type_check=torch.nn.Module,
+ default=None,
+ optional=True,
+)
+# decoder is used for output (e.g. post_net in SOND)
+decoder_choices = ClassChoices(
+ "decoder",
+ classes=dict(
+ rnn=RNNEncoder,
+ fsmn=FsmnEncoder,
+ ),
+ type_check=torch.nn.Module,
+ default="fsmn",
+)
+# encoder_decoder_attractor is used for EEND-OLA
+encoder_decoder_attractor_choices = ClassChoices(
+ "encoder_decoder_attractor",
+ classes=dict(
+ eda=EncoderDecoderAttractor,
+ ),
+ type_check=torch.nn.Module,
+ default="eda",
+)
+class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --label_aggregator and --label_aggregator_conf
+ label_aggregator_choices,
+ # --model and --model_conf
+ model_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --speaker_encoder and --speaker_encoder_conf
+ speaker_encoder_choices,
+ # --cd_scorer and cd_scorer_conf
+ cd_scorer_choices,
+ # --ci_scorer and ci_scorer_conf
+ ci_scorer_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+ # --eda and --eda_conf
+ encoder_decoder_attractor_choices,
+]
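+ # Note: the SOND branch below consumes speaker_encoder/ci_scorer/cd_scorer/
+ # decoder/label_aggregator, while the EEND-OLA branch only uses the encoder plus
+ # the encoder_decoder_attractor; choices left at their None defaults stay unused.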
+
+
+def build_diar_model(args):
+ # token_list
+ if args.token_list is not None:
+ with open(args.token_list) as f:
+ token_list = [line.rstrip() for line in f]
+ args.token_list = list(token_list)
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+ else:
+ vocab_size = None
+
+ # frontend
+ if args.input_size is None:
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend':
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+ if args.model_name == "sond":
+ # data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # speaker encoder
+ if getattr(args, "speaker_encoder", None) is not None:
+ speaker_encoder_class = speaker_encoder_choices.get_class(args.speaker_encoder)
+ speaker_encoder = speaker_encoder_class(**args.speaker_encoder_conf)
+ else:
+ speaker_encoder = None
+
+ # ci scorer
+ if getattr(args, "ci_scorer", None) is not None:
+ ci_scorer_class = ci_scorer_choices.get_class(args.ci_scorer)
+ ci_scorer = ci_scorer_class(**args.ci_scorer_conf)
+ else:
+ ci_scorer = None
+
+ # cd scorer
+ if getattr(args, "cd_scorer", None) is not None:
+ cd_scorer_class = cd_scorer_choices.get_class(args.cd_scorer)
+ cd_scorer = cd_scorer_class(**args.cd_scorer_conf)
+ else:
+ cd_scorer = None
+
+ # decoder
+ decoder_class = decoder_choices.get_class(args.decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder.output_size(),
+ **args.decoder_conf,
+ )
+
+ # label aggregator
+ if getattr(args, "label_aggregator", None) is not None:
+ label_aggregator_class = label_aggregator_choices.get_class(args.label_aggregator)
+ label_aggregator = label_aggregator_class(**args.label_aggregator_conf)
+ else:
+ label_aggregator = None
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ vocab_size=vocab_size,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ label_aggregator=label_aggregator,
+ encoder=encoder,
+ speaker_encoder=speaker_encoder,
+ ci_scorer=ci_scorer,
+ cd_scorer=cd_scorer,
+ decoder=decoder,
+ token_list=token_list,
+ **args.model_conf,
+ )
+
+ elif args.model_name == "eend_ola":
+ # encoder-decoder attractor
+ encoder_decoder_attractor_class = encoder_decoder_attractor_choices.get_class(args.encoder_decoder_attractor)
+ encoder_decoder_attractor = encoder_decoder_attractor_class(**args.encoder_decoder_attractor_conf)
+
+ # 9. Build model
+ model_class = model_choices.get_class(args.model)
+ model = model_class(
+ frontend=frontend,
+ encoder=encoder,
+ encoder_decoder_attractor=encoder_decoder_attractor,
+ **args.model_conf,
+ )
+
+ else:
+ raise NotImplementedError("Not supported model: {}".format(args.model))
+
+ # 10. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
diff --git a/funasr/build_utils/build_distributed.py b/funasr/build_utils/build_distributed.py
new file mode 100644
index 0000000..b64b4c0
--- /dev/null
+++ b/funasr/build_utils/build_distributed.py
@@ -0,0 +1,38 @@
+import logging
+import os
+
+import torch
+
+from funasr.train.distributed_utils import DistributedOption
+from funasr.utils.build_dataclass import build_dataclass
+
+
+def build_distributed(args):
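+ # Three initialization paths: PAI cloud jobs (args.use_pai), a launcher-managed
+ # process group (init_torch_distributed), and a simple single-node DDP mode that
+ # re-derives args.ngpu from the torch.distributed world size.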
+ distributed_option = build_dataclass(DistributedOption, args)
+ if args.use_pai:
+ distributed_option.init_options_pai()
+ distributed_option.init_torch_distributed_pai(args)
+ elif not args.simple_ddp:
+ distributed_option.init_torch_distributed(args)
+ elif args.distributed and args.simple_ddp:
+ distributed_option.init_torch_distributed_pai(args)
+ args.ngpu = torch.distributed.get_world_size()
+
+ for handler in logging.root.handlers[:]:
+ logging.root.removeHandler(handler)
+ if not distributed_option.distributed or distributed_option.dist_rank == 0:
+ logging.basicConfig(
+ level="INFO",
+ format=f"[{os.uname()[1].split('.')[0]}]"
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ else:
+ logging.basicConfig(
+ level="ERROR",
+ format=f"[{os.uname()[1].split('.')[0]}]"
+ f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+ logging.info("world size: {}, rank: {}, local_rank: {}".format(distributed_option.dist_world_size,
+ distributed_option.dist_rank,
+ distributed_option.local_rank))
+ return distributed_option
diff --git a/funasr/build_utils/build_lm_model.py b/funasr/build_utils/build_lm_model.py
new file mode 100644
index 0000000..aaa4fb7
--- /dev/null
+++ b/funasr/build_utils/build_lm_model.py
@@ -0,0 +1,57 @@
+import logging
+
+from funasr.lm.abs_model import AbsLM
+from funasr.lm.abs_model import LanguageModel
+from funasr.lm.seq_rnn_lm import SequentialRNNLM
+from funasr.lm.transformer_lm import TransformerLM
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+lm_choices = ClassChoices(
+ "lm",
+ classes=dict(
+ seq_rnn=SequentialRNNLM,
+ transformer=TransformerLM,
+ ),
+ type_check=AbsLM,
+ default="seq_rnn",
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ lm=LanguageModel,
+ ),
+ default="lm",
+)
+
+class_choices_list = [
+ # --lm and --lm_conf
+ lm_choices,
+ # --model and --model_conf
+ model_choices
+]
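+ # Illustrative usage (sketch): "--lm transformer --lm_conf '{...}'" selects
+ # TransformerLM; the valid lm_conf keys follow the constructor of the chosen
+ # LM class rather than any fixed schema.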
+
+
+def build_lm_model(args):
+ # token_list
+ if args.token_list is not None:
+ with open(args.token_list) as f:
+ token_list = [line.rstrip() for line in f]
+ args.token_list = list(token_list)
+ vocab_size = len(token_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+ else:
+ vocab_size = None
+
+ # lm
+ lm_class = lm_choices.get_class(args.lm)
+ lm = lm_class(vocab_size=vocab_size, **args.lm_conf)
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(lm=lm, vocab_size=vocab_size, **args.model_conf)
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
diff --git a/funasr/build_utils/build_model.py b/funasr/build_utils/build_model.py
new file mode 100644
index 0000000..13a6faa
--- /dev/null
+++ b/funasr/build_utils/build_model.py
@@ -0,0 +1,25 @@
+from funasr.build_utils.build_asr_model import build_asr_model
+from funasr.build_utils.build_lm_model import build_lm_model
+from funasr.build_utils.build_pretrain_model import build_pretrain_model
+from funasr.build_utils.build_punc_model import build_punc_model
+from funasr.build_utils.build_vad_model import build_vad_model
+from funasr.build_utils.build_diar_model import build_diar_model
+
+
+def build_model(args):
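+ # Single dispatch point for all tasks: args.task_name selects the task-specific
+ # builder, and every builder returns a ready-to-train torch.nn.Module.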
+ if args.task_name == "asr":
+ model = build_asr_model(args)
+ elif args.task_name == "pretrain":
+ model = build_pretrain_model(args)
+ elif args.task_name == "lm":
+ model = build_lm_model(args)
+ elif args.task_name == "punc":
+ model = build_punc_model(args)
+ elif args.task_name == "vad":
+ model = build_vad_model(args)
+ elif args.task_name == "diar":
+ model = build_diar_model(args)
+ else:
+ raise NotImplementedError("Not supported task: {}".format(args.task_name))
+
+ return model
diff --git a/funasr/build_utils/build_optimizer.py b/funasr/build_utils/build_optimizer.py
new file mode 100644
index 0000000..bd0b73d
--- /dev/null
+++ b/funasr/build_utils/build_optimizer.py
@@ -0,0 +1,28 @@
+import torch
+
+from funasr.optimizers.fairseq_adam import FairseqAdam
+from funasr.optimizers.sgd import SGD
+
+
+def build_optimizer(args, model):
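+ # Map the --optim name to an optimizer class and forward --optim_conf verbatim
+ # as keyword arguments (e.g. {"lr": 0.001} would be a typical, illustrative
+ # config). A list is returned because the trainer pairs one scheduler with
+ # each optimizer.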
+ optim_classes = dict(
+ adam=torch.optim.Adam,
+ fairseq_adam=FairseqAdam,
+ adamw=torch.optim.AdamW,
+ sgd=SGD,
+ adadelta=torch.optim.Adadelta,
+ adagrad=torch.optim.Adagrad,
+ adamax=torch.optim.Adamax,
+ asgd=torch.optim.ASGD,
+ lbfgs=torch.optim.LBFGS,
+ rmsprop=torch.optim.RMSprop,
+ rprop=torch.optim.Rprop,
+ )
+
+ optim_class = optim_classes.get(args.optim)
+ if optim_class is None:
+ raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
+ optimizer = optim_class(model.parameters(), **args.optim_conf)
+
+ optimizers = [optimizer]
+ return optimizers
\ No newline at end of file
diff --git a/funasr/build_utils/build_pretrain_model.py b/funasr/build_utils/build_pretrain_model.py
new file mode 100644
index 0000000..629937f
--- /dev/null
+++ b/funasr/build_utils/build_pretrain_model.py
@@ -0,0 +1,107 @@
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.data2vec import Data2VecPretrainModel
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.specaug import SpecAug
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(specaug=SpecAug),
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ default=None,
+ optional=True,
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ data2vec_encoder=Data2VecEncoder,
+ ),
+ default="data2vec_encoder",
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ data2vec=Data2VecPretrainModel,
+ ),
+ default="data2vec",
+)
+class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --model and --model_conf
+ model_choices,
+]
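+ # Pretraining (data2vec) is self-supervised, so unlike build_asr_model there is
+ # no token_list/vocab handling here; only the frontend, specaug, normalization,
+ # and encoder are configured.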
+
+
+def build_pretrain_model(args):
+ # frontend
+ if args.input_size is None:
+ frontend_class = frontend_choices.get_class(args.frontend)
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(
+ input_size=input_size,
+ **args.encoder_conf,
+ )
+
+ if args.model == "data2vec":
+ model_class = model_choices.get_class("data2vec")
+ model = model_class(
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ encoder=encoder,
+ )
+ else:
+ raise NotImplementedError("Not supported model: {}".format(args.model))
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
diff --git a/funasr/build_utils/build_punc_model.py b/funasr/build_utils/build_punc_model.py
new file mode 100644
index 0000000..62ccaf2
--- /dev/null
+++ b/funasr/build_utils/build_punc_model.py
@@ -0,0 +1,68 @@
+import logging
+
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.torch_utils.initialize import initialize
+from funasr.train.abs_model import PunctuationModel
+from funasr.train.class_choices import ClassChoices
+
+punc_choices = ClassChoices(
+ "punctuation",
+ classes=dict(
+ target_delay=TargetDelayTransformer,
+ vad_realtime=VadRealtimeTransformer
+ ),
+ default="target_delay",
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ punc=PunctuationModel,
+ ),
+ default="punc",
+)
+class_choices_list = [
+ # --punc and --punc_conf
+ punc_choices,
+ # --model and --model_conf
+ model_choices
+]
+
+
+def build_punc_model(args):
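+ # punc_list may be given as a file with one "<punc>:<weight>" pair per line
+ # (e.g. "，:1.0" as an illustrative entry); column 0 is the punctuation token
+ # and column 1 the per-class loss weight parsed below.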
+ # token_list and punc list
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+ args.token_list = token_list.copy()
+ if isinstance(args.punc_list, str):
+ with open(args.punc_list, encoding="utf-8") as f2:
+ pairs = [line.rstrip().split(":") for line in f2]
+ punc_list = [pair[0] for pair in pairs]
+ punc_weight_list = [float(pair[1]) for pair in pairs]
+ args.punc_list = punc_list.copy()
+ elif isinstance(args.punc_list, list):
+ punc_list = args.punc_list.copy()
+ punc_weight_list = [1] * len(punc_list)
+ if isinstance(args.token_list, (tuple, list)):
+ token_list = args.token_list.copy()
+ else:
+ raise RuntimeError("token_list must be str or dict")
+
+ vocab_size = len(token_list)
+ punc_size = len(punc_list)
+ logging.info(f"Vocabulary size: {vocab_size}")
+
+ # punc
+ punc_class = punc_choices.get_class(args.punctuation)
+ punc = punc_class(vocab_size=vocab_size, punc_size=punc_size, **args.punctuation_conf)
+
+ if "punc_weight" in args.model_conf:
+ args.model_conf.pop("punc_weight")
+ model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
diff --git a/funasr/build_utils/build_scheduler.py b/funasr/build_utils/build_scheduler.py
new file mode 100644
index 0000000..4b9990e
--- /dev/null
+++ b/funasr/build_utils/build_scheduler.py
@@ -0,0 +1,44 @@
+import torch
+import torch.multiprocessing
+import torch.nn
+import torch.optim
+
+from funasr.schedulers.noam_lr import NoamLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
+
+
+def build_scheduler(args, optimizers):
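+ # One scheduler per optimizer: the first is read from --scheduler/--scheduler_conf,
+ # later ones from the suffixed --scheduler2, --scheduler3, ... so the i-th
+ # scheduler is paired with the i-th optimizer from build_optimizer.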
+ scheduler_classes = dict(
+ ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
+ lambdalr=torch.optim.lr_scheduler.LambdaLR,
+ steplr=torch.optim.lr_scheduler.StepLR,
+ multisteplr=torch.optim.lr_scheduler.MultiStepLR,
+ exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
+ CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
+ noamlr=NoamLR,
+ warmuplr=WarmupLR,
+ tri_stage=TriStageLR,
+ cycliclr=torch.optim.lr_scheduler.CyclicLR,
+ onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
+ CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
+ )
+
+ schedulers = []
+ for i, optim in enumerate(optimizers, 1):
+ suf = "" if i == 1 else str(i)
+ name = getattr(args, f"scheduler{suf}")
+ conf = getattr(args, f"scheduler{suf}_conf")
+ if name is not None:
+ cls_ = scheduler_classes.get(name)
+ if cls_ is None:
+ raise ValueError(
+ f"must be one of {list(scheduler_classes)}: {name}"
+ )
+ scheduler = cls_(optim, **conf)
+ else:
+ scheduler = None
+
+ schedulers.append(scheduler)
+
+ return schedulers
\ No newline at end of file
diff --git a/funasr/build_utils/build_trainer.py b/funasr/build_utils/build_trainer.py
new file mode 100644
index 0000000..aff99b5
--- /dev/null
+++ b/funasr/build_utils/build_trainer.py
@@ -0,0 +1,820 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Trainer module."""
+import argparse
+import dataclasses
+import logging
+import os
+import time
+from contextlib import contextmanager
+from dataclasses import is_dataclass
+from distutils.version import LooseVersion
+from io import BytesIO
+from pathlib import Path
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+import humanfriendly
+import oss2
+import torch
+import torch.nn
+import torch.optim
+from typeguard import check_argument_types
+
+from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.main_funcs.average_nbest_models import average_nbest_models
+from funasr.models.base_model import FunASRModel
+from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
+from funasr.schedulers.abs_scheduler import AbsEpochStepScheduler
+from funasr.schedulers.abs_scheduler import AbsScheduler
+from funasr.schedulers.abs_scheduler import AbsValEpochStepScheduler
+from funasr.torch_utils.add_gradient_noise import add_gradient_noise
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.recursive_op import recursive_average
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.train.distributed_utils import DistributedOption
+from funasr.train.reporter import Reporter
+from funasr.train.reporter import SubReporter
+from funasr.utils.build_dataclass import build_dataclass
+
+if torch.distributed.is_available():
+ from torch.distributed import ReduceOp
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+ from torch.cuda.amp import GradScaler
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+ GradScaler = None
+
+try:
+ import fairscale
+except ImportError:
+ fairscale = None
+
+
+@dataclasses.dataclass
+class TrainerOptions:
+ ngpu: int
+ resume: bool
+ use_amp: bool
+ train_dtype: str
+ grad_noise: bool
+ accum_grad: int
+ grad_clip: float
+ grad_clip_type: float
+ log_interval: Optional[int]
+ # no_forward_run: bool
+ use_tensorboard: bool
+ # use_wandb: bool
+ output_dir: Union[Path, str]
+ max_epoch: int
+ max_update: int
+ seed: int
+ # sharded_ddp: bool
+ patience: Optional[int]
+ keep_nbest_models: Union[int, List[int]]
+ nbest_averaging_interval: int
+ early_stopping_criterion: Sequence[str]
+ best_model_criterion: Sequence[Sequence[str]]
+ val_scheduler_criterion: Sequence[str]
+ unused_parameters: bool
+ # wandb_model_log_interval: int
+ use_pai: bool
+ oss_bucket: Union[oss2.Bucket, None]
+
+
+class Trainer:
+ """Trainer
+
+ """
+
+ def __init__(self,
+ args,
+ model: FunASRModel,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ train_dataloader: AbsIterFactory,
+ valid_dataloader: AbsIterFactory,
+ distributed_option: DistributedOption):
+ self.trainer_options = self.build_options(args)
+ self.model = model
+ self.optimizers = optimizers
+ self.schedulers = schedulers
+ self.train_dataloader = train_dataloader
+ self.valid_dataloader = valid_dataloader
+ self.distributed_option = distributed_option
+
+ def build_options(self, args: argparse.Namespace) -> TrainerOptions:
+ """Build options consumed by train(), eval()"""
+ assert check_argument_types()
+ return build_dataclass(TrainerOptions, args)
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ """Reserved for future development of another Trainer"""
+ pass
+
+ def resume(self,
+ checkpoint: Union[str, Path],
+ model: torch.nn.Module,
+ reporter: Reporter,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ scaler: Optional[GradScaler],
+ ngpu: int = 0,
+ ):
+ states = torch.load(
+ checkpoint,
+ map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
+ )
+ model.load_state_dict(states["model"])
+ reporter.load_state_dict(states["reporter"])
+ for optimizer, state in zip(optimizers, states["optimizers"]):
+ optimizer.load_state_dict(state)
+ for scheduler, state in zip(schedulers, states["schedulers"]):
+ if scheduler is not None:
+ scheduler.load_state_dict(state)
+ if scaler is not None:
+ if states["scaler"] is None:
+ logging.warning("scaler state is not found")
+ else:
+ scaler.load_state_dict(states["scaler"])
+
+ logging.info(f"The training was resumed using {checkpoint}")
+
+ def run(self) -> None:
+ """Perform training. This method performs the main process of training."""
+ assert check_argument_types()
+ # NOTE(kamo): Don't check the types more strictly here; it is enough that trainer_options is a dataclass
+ model = self.model
+ optimizers = self.optimizers
+ schedulers = self.schedulers
+ train_dataloader = self.train_dataloader
+ valid_dataloader = self.valid_dataloader
+ trainer_options = self.trainer_options
+ distributed_option = self.distributed_option
+ assert is_dataclass(trainer_options), type(trainer_options)
+ assert len(optimizers) == len(schedulers), (len(optimizers), len(schedulers))
+
+ if isinstance(trainer_options.keep_nbest_models, int):
+ keep_nbest_models = [trainer_options.keep_nbest_models]
+ else:
+ if len(trainer_options.keep_nbest_models) == 0:
+ logging.warning("No keep_nbest_models is given. Change to [1]")
+ trainer_options.keep_nbest_models = [1]
+ keep_nbest_models = trainer_options.keep_nbest_models
+
+ output_dir = Path(trainer_options.output_dir)
+ reporter = Reporter()
+ if trainer_options.use_amp:
+ if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):
+ raise RuntimeError(
+ "Require torch>=1.6.0 for Automatic Mixed Precision"
+ )
+ # if trainer_options.sharded_ddp:
+ # if fairscale is None:
+ # raise RuntimeError(
+ # "Requiring fairscale. Do 'pip install fairscale'"
+ # )
+ # scaler = fairscale.optim.grad_scaler.ShardedGradScaler()
+ # else:
+ scaler = GradScaler()
+ else:
+ scaler = None
+
+ if trainer_options.resume and (output_dir / "checkpoint.pb").exists():
+ self.resume(
+ checkpoint=output_dir / "checkpoint.pb",
+ model=model,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ reporter=reporter,
+ scaler=scaler,
+ ngpu=trainer_options.ngpu,
+ )
+
+ start_epoch = reporter.get_epoch() + 1
+ if start_epoch == trainer_options.max_epoch + 1:
+ logging.warning(
+ f"The training has already reached at max_epoch: {start_epoch}"
+ )
+
+ if distributed_option.distributed:
+ dp_model = torch.nn.parallel.DistributedDataParallel(
+ model, find_unused_parameters=trainer_options.unused_parameters)
+ elif distributed_option.ngpu > 1:
+ dp_model = torch.nn.parallel.DataParallel(
+ model,
+ device_ids=list(range(distributed_option.ngpu)),
+ )
+ else:
+ # NOTE(kamo): DataParallel also should work with ngpu=1,
+ # but for debuggability it's better to keep this block.
+ dp_model = model
+
+ if trainer_options.use_tensorboard and (
+ not distributed_option.distributed or distributed_option.dist_rank == 0
+ ):
+ from torch.utils.tensorboard import SummaryWriter
+ if trainer_options.use_pai:
+ train_summary_writer = SummaryWriter(
+ os.path.join(trainer_options.output_dir, "tensorboard/train")
+ )
+ valid_summary_writer = SummaryWriter(
+ os.path.join(trainer_options.output_dir, "tensorboard/valid")
+ )
+ else:
+ train_summary_writer = SummaryWriter(
+ str(output_dir / "tensorboard" / "train")
+ )
+ valid_summary_writer = SummaryWriter(
+ str(output_dir / "tensorboard" / "valid")
+ )
+ else:
+ train_summary_writer = None
+
+ start_time = time.perf_counter()
+ for iepoch in range(start_epoch, trainer_options.max_epoch + 1):
+ if iepoch != start_epoch:
+ logging.info(
+ "{}/{}epoch started. Estimated time to finish: {}".format(
+ iepoch,
+ trainer_options.max_epoch,
+ humanfriendly.format_timespan(
+ (time.perf_counter() - start_time)
+ / (iepoch - start_epoch)
+ * (trainer_options.max_epoch - iepoch + 1)
+ ),
+ )
+ )
+ else:
+ logging.info(f"{iepoch}/{trainer_options.max_epoch}epoch started")
+ set_all_random_seed(trainer_options.seed + iepoch)
+
+ reporter.set_epoch(iepoch)
+ # 1. Train and validation for one-epoch
+ with reporter.observe("train") as sub_reporter:
+ all_steps_are_invalid, max_update_stop = self.train_one_epoch(
+ model=dp_model,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ iterator=train_dataloader.build_iter(iepoch),
+ reporter=sub_reporter,
+ scaler=scaler,
+ summary_writer=train_summary_writer,
+ options=trainer_options,
+ distributed_option=distributed_option,
+ )
+
+ with reporter.observe("valid") as sub_reporter:
+ self.validate_one_epoch(
+ model=dp_model,
+ iterator=valid_dataloader.build_iter(iepoch),
+ reporter=sub_reporter,
+ options=trainer_options,
+ distributed_option=distributed_option,
+ )
+
+ # 2. LR Scheduler step
+ for scheduler in schedulers:
+ if isinstance(scheduler, AbsValEpochStepScheduler):
+ scheduler.step(
+ reporter.get_value(*trainer_options.val_scheduler_criterion)
+ )
+ elif isinstance(scheduler, AbsEpochStepScheduler):
+ scheduler.step()
+ # if trainer_options.sharded_ddp:
+ # for optimizer in optimizers:
+ # if isinstance(optimizer, fairscale.optim.oss.OSS):
+ # optimizer.consolidate_state_dict()
+
+ if not distributed_option.distributed or distributed_option.dist_rank == 0:
+ # 3. Report the results
+ logging.info(reporter.log_message())
+ if train_summary_writer is not None:
+ reporter.tensorboard_add_scalar(train_summary_writer, key1="train")
+ reporter.tensorboard_add_scalar(valid_summary_writer, key1="valid")
+ # if trainer_options.use_wandb:
+ # reporter.wandb_log()
+
+ # save tensorboard on oss
+ if trainer_options.use_pai and train_summary_writer is not None:
+ def write_tensorboard_summary(summary_writer_path, oss_bucket):
+ file_list = []
+ for root, dirs, files in os.walk(summary_writer_path, topdown=False):
+ for name in files:
+ file_full_path = os.path.join(root, name)
+ file_list.append(file_full_path)
+
+ for file_full_path in file_list:
+ with open(file_full_path, "rb") as f:
+ oss_bucket.put_object(file_full_path, f)
+
+ write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/train"),
+ trainer_options.oss_bucket)
+ write_tensorboard_summary(os.path.join(trainer_options.output_dir, "tensorboard/valid"),
+ trainer_options.oss_bucket)
+
+ # 4. Save/Update the checkpoint
+ if trainer_options.use_pai:
+ buffer = BytesIO()
+ torch.save(
+ {
+ "model": model.state_dict(),
+ "reporter": reporter.state_dict(),
+ "optimizers": [o.state_dict() for o in optimizers],
+ "schedulers": [
+ s.state_dict() if s is not None else None
+ for s in schedulers
+ ],
+ "scaler": scaler.state_dict() if scaler is not None else None,
+ "ema_model": model.encoder.ema.model.state_dict()
+ if hasattr(model.encoder, "ema") and model.encoder.ema is not None else None,
+ },
+ buffer,
+ )
+ trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir, "checkpoint.pb"),
+ buffer.getvalue())
+ else:
+ torch.save(
+ {
+ "model": model.state_dict(),
+ "reporter": reporter.state_dict(),
+ "optimizers": [o.state_dict() for o in optimizers],
+ "schedulers": [
+ s.state_dict() if s is not None else None
+ for s in schedulers
+ ],
+ "scaler": scaler.state_dict() if scaler is not None else None,
+ },
+ output_dir / "checkpoint.pb",
+ )
+
+ # 5. Save and log the model and update the link to the best model
+ if trainer_options.use_pai:
+ buffer = BytesIO()
+ torch.save(model.state_dict(), buffer)
+ trainer_options.oss_bucket.put_object(os.path.join(trainer_options.output_dir,
+ f"{iepoch}epoch.pb"), buffer.getvalue())
+ else:
+ torch.save(model.state_dict(), output_dir / f"{iepoch}epoch.pb")
+
+ # Creates a symlink latest.pb -> {iepoch}epoch.pb
+ if trainer_options.use_pai:
+ p = os.path.join(trainer_options.output_dir, "latest.pb")
+ if trainer_options.oss_bucket.object_exists(p):
+ trainer_options.oss_bucket.delete_object(p)
+ trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
+ os.path.join(trainer_options.output_dir,
+ f"{iepoch}epoch.pb"), p)
+ else:
+ p = output_dir / "latest.pb"
+ if p.is_symlink() or p.exists():
+ p.unlink()
+ p.symlink_to(f"{iepoch}epoch.pb")
+
+ _improved = []
+ for _phase, k, _mode in trainer_options.best_model_criterion:
+ # e.g. _phase, k, _mode = "train", "loss", "min"
+ if reporter.has(_phase, k):
+ best_epoch = reporter.get_best_epoch(_phase, k, _mode)
+ # Creates symlinks if this is the best result
+ if best_epoch == iepoch:
+ if trainer_options.use_pai:
+ p = os.path.join(trainer_options.output_dir, f"{_phase}.{k}.best.pb")
+ if trainer_options.oss_bucket.object_exists(p):
+ trainer_options.oss_bucket.delete_object(p)
+ trainer_options.oss_bucket.copy_object(trainer_options.oss_bucket.bucket_name,
+ os.path.join(trainer_options.output_dir,
+ f"{iepoch}epoch.pb"), p)
+ else:
+ p = output_dir / f"{_phase}.{k}.best.pb"
+ if p.is_symlink() or p.exists():
+ p.unlink()
+ p.symlink_to(f"{iepoch}epoch.pb")
+ _improved.append(f"{_phase}.{k}")
+ if len(_improved) == 0:
+ logging.info("There are no improvements in this epoch")
+ else:
+ logging.info(
+ "The best model has been updated: " + ", ".join(_improved)
+ )
+
+ # log_model = (
+ # trainer_options.wandb_model_log_interval > 0
+ # and iepoch % trainer_options.wandb_model_log_interval == 0
+ # )
+ # if log_model and trainer_options.use_wandb:
+ # import wandb
+ #
+ # logging.info("Logging Model on this epoch :::::")
+ # artifact = wandb.Artifact(
+ # name=f"model_{wandb.run.id}",
+ # type="model",
+ # metadata={"improved": _improved},
+ # )
+ # artifact.add_file(str(output_dir / f"{iepoch}epoch.pb"))
+ # aliases = [
+ # f"epoch-{iepoch}",
+ # "best" if best_epoch == iepoch else "",
+ # ]
+ # wandb.log_artifact(artifact, aliases=aliases)
+
+ # 6. Remove model files other than the n-best epochs and the latest epoch
+ _removed = []
+ # Get the union set of the n-best among multiple criterion
+ nbests = set().union(
+ *[
+ set(reporter.sort_epochs(ph, k, m)[: max(keep_nbest_models)])
+ for ph, k, m in trainer_options.best_model_criterion
+ if reporter.has(ph, k)
+ ]
+ )
+
+ # Generate the n-best averaged model
+ if (
+ trainer_options.nbest_averaging_interval > 0
+ and iepoch % trainer_options.nbest_averaging_interval == 0
+ ):
+ average_nbest_models(
+ reporter=reporter,
+ output_dir=output_dir,
+ best_model_criterion=trainer_options.best_model_criterion,
+ nbest=keep_nbest_models,
+ suffix=f"till{iepoch}epoch",
+ oss_bucket=trainer_options.oss_bucket,
+ pai_output_dir=trainer_options.output_dir,
+ )
+
+ for e in range(1, iepoch):
+ if trainer_options.use_pai:
+ p = os.path.join(trainer_options.output_dir, f"{e}epoch.pb")
+ if trainer_options.oss_bucket.object_exists(p) and e not in nbests:
+ trainer_options.oss_bucket.delete_object(p)
+ _removed.append(str(p))
+ else:
+ p = output_dir / f"{e}epoch.pb"
+ if p.exists() and e not in nbests:
+ p.unlink()
+ _removed.append(str(p))
+ if len(_removed) != 0:
+ logging.info("The model files were removed: " + ", ".join(_removed))
+
+ # 7. If no parameter update happened in this epoch, stop the training
+ if all_steps_are_invalid:
+ logging.warning(
+ f"The gradients at all steps are invalid in this epoch. "
+ f"Something seems wrong. This training was stopped at {iepoch}epoch"
+ )
+ break
+
+ if max_update_stop:
+ logging.info(
+ f"Stopping training due to "
+ f"num_updates: {trainer_options.num_updates} >= max_update: {trainer_options.max_update}"
+ )
+ break
+
+ # 8. Check early stopping
+ if trainer_options.patience is not None:
+ if reporter.check_early_stopping(
+ trainer_options.patience, *trainer_options.early_stopping_criterion
+ ):
+ break
+
+ else:
+ logging.info(
+ f"The training was finished at {trainer_options.max_epoch} epochs "
+ )
+
+ # Generate the final n-best averaged model
+ if not distributed_option.distributed or distributed_option.dist_rank == 0:
+ average_nbest_models(
+ reporter=reporter,
+ output_dir=output_dir,
+ best_model_criterion=trainer_options.best_model_criterion,
+ nbest=keep_nbest_models,
+ oss_bucket=trainer_options.oss_bucket,
+ pai_output_dir=trainer_options.output_dir,
+ )
+
+ def train_one_epoch(
+ self,
+ model: torch.nn.Module,
+ iterator: Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ scaler: Optional[GradScaler],
+ reporter: SubReporter,
+ summary_writer,
+ options: TrainerOptions,
+ distributed_option: DistributedOption,
+ ) -> Tuple[bool, bool]:
+ assert check_argument_types()
+
+ grad_noise = options.grad_noise
+ accum_grad = options.accum_grad
+ grad_clip = options.grad_clip
+ grad_clip_type = options.grad_clip_type
+ log_interval = options.log_interval
+ # no_forward_run = options.no_forward_run
+ ngpu = options.ngpu
+ # use_wandb = options.use_wandb
+ distributed = distributed_option.distributed
+
+ if log_interval is None:
+ try:
+ log_interval = max(len(iterator) // 20, 10)
+ except TypeError:
+ log_interval = 100
+
+ model.train()
+ all_steps_are_invalid = True
+ max_update_stop = False
+ # [For distributed] Because iteration counts are not always equal between
+ # processes, send a stop-flag to the other processes once the iterator is finished
+ iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
+
+ start_time = time.perf_counter()
+ for iiter, (_, batch) in enumerate(
+ reporter.measure_iter_time(iterator, "iter_time"), 1
+ ):
+ assert isinstance(batch, dict), type(batch)
+
+ if distributed:
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+ if iterator_stop > 0:
+ break
+
+ batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
+ # if no_forward_run:
+ # all_steps_are_invalid = False
+ # continue
+
+ with autocast(scaler is not None):
+ with reporter.measure_time("forward_time"):
+ retval = model(**batch)
+
+ # Note(kamo):
+ # Supporting two patterns for the returned value from the model
+ # a. dict type
+ if isinstance(retval, dict):
+ loss = retval["loss"]
+ stats = retval["stats"]
+ weight = retval["weight"]
+ optim_idx = retval.get("optim_idx")
+ if optim_idx is not None and not isinstance(optim_idx, int):
+ if not isinstance(optim_idx, torch.Tensor):
+ raise RuntimeError(
+ "optim_idx must be int or 1dim torch.Tensor, "
+ f"but got {type(optim_idx)}"
+ )
+ if optim_idx.dim() >= 2:
+ raise RuntimeError(
+ "optim_idx must be int or 1dim torch.Tensor, "
+ f"but got {optim_idx.dim()}dim tensor"
+ )
+ if optim_idx.dim() == 1:
+ for v in optim_idx:
+ if v != optim_idx[0]:
+ raise RuntimeError(
+ "optim_idx must be 1dim tensor "
+ "having same values for all entries"
+ )
+ optim_idx = optim_idx[0].item()
+ else:
+ optim_idx = optim_idx.item()
+
+ # b. tuple or list type
+ else:
+ loss, stats, weight = retval
+ optim_idx = None
+
+ stats = {k: v for k, v in stats.items() if v is not None}
+ if ngpu > 1 or distributed:
+ # Apply weighted averaging for loss and stats
+ loss = (loss * weight.type(loss.dtype)).sum()
+
+ # if distributed, this method can also apply all_reduce()
+ stats, weight = recursive_average(stats, weight, distributed)
+
+ # Now weight is summation over all workers
+ loss /= weight
+ if distributed:
+ # NOTE(kamo): Multiply world_size because DistributedDataParallel
+ # automatically normalizes the gradient by world_size.
+ loss *= torch.distributed.get_world_size()
+
+ loss /= accum_grad
+
+ reporter.register(stats, weight)
+
+ with reporter.measure_time("backward_time"):
+ if scaler is not None:
+ # Scales loss. Calls backward() on scaled loss
+ # to create scaled gradients.
+ # Backward passes under autocast are not recommended.
+ # Backward ops run in the same dtype autocast chose
+ # for corresponding forward ops.
+ scaler.scale(loss).backward()
+ else:
+ loss.backward()
+
+ if iiter % accum_grad == 0:
+ if scaler is not None:
+ # Unscales the gradients of optimizer's assigned params in-place
+ for iopt, optimizer in enumerate(optimizers):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ scaler.unscale_(optimizer)
+
+ # gradient noise injection
+ if grad_noise:
+ add_gradient_noise(
+ model,
+ reporter.get_total_count(),
+ duration=100,
+ eta=1.0,
+ scale_factor=0.55,
+ )
+
+ # compute the gradient norm to check if it is normal or not
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ model.parameters(),
+ max_norm=grad_clip,
+ norm_type=grad_clip_type,
+ )
+ # In PyTorch<=1.4, clip_grad_norm_ returns a float value
+ if not isinstance(grad_norm, torch.Tensor):
+ grad_norm = torch.tensor(grad_norm)
+
+ if not torch.isfinite(grad_norm):
+ logging.warning(
+ f"The grad norm is {grad_norm}. Skipping updating the model."
+ )
+
+ # Must invoke scaler.update() if unscale_() is used in the iteration
+ # to avoid the following error:
+ # RuntimeError: unscale_() has already been called
+ # on this optimizer since the last update().
+ # Note that if the gradient has inf/nan values,
+ # scaler.step skips optimizer.step().
+ if scaler is not None:
+ for iopt, optimizer in enumerate(optimizers):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ scaler.step(optimizer)
+ scaler.update()
+
+ else:
+ all_steps_are_invalid = False
+ with reporter.measure_time("optim_step_time"):
+ for iopt, (optimizer, scheduler) in enumerate(
+ zip(optimizers, schedulers)
+ ):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ if scaler is not None:
+ # scaler.step() first unscales the gradients of
+ # the optimizer's assigned params.
+ scaler.step(optimizer)
+ # Updates the scale for next iteration.
+ scaler.update()
+ else:
+ optimizer.step()
+ if isinstance(scheduler, AbsBatchStepScheduler):
+ scheduler.step()
+ for iopt, optimizer in enumerate(optimizers):
+ if optim_idx is not None and iopt != optim_idx:
+ continue
+ optimizer.zero_grad()
+
+ # Register lr and train/load time[sec/step],
+ # where step refers to accum_grad * mini-batch
+ reporter.register(
+ dict(
+ {
+ f"optim{i}_lr{j}": pg["lr"]
+ for i, optimizer in enumerate(optimizers)
+ for j, pg in enumerate(optimizer.param_groups)
+ if "lr" in pg
+ },
+ train_time=time.perf_counter() - start_time,
+ ),
+ )
+ start_time = time.perf_counter()
+
+ # update num_updates
+ if distributed:
+ if hasattr(model.module, "num_updates"):
+ model.module.set_num_updates(model.module.get_num_updates() + 1)
+ options.num_updates = model.module.get_num_updates()
+ if model.module.get_num_updates() >= options.max_update:
+ max_update_stop = True
+ else:
+ if hasattr(model, "num_updates"):
+ model.set_num_updates(model.get_num_updates() + 1)
+ options.num_updates = model.get_num_updates()
+ if model.get_num_updates() >= options.max_update:
+ max_update_stop = True
+
+ # NOTE(kamo): Call log_message() after next()
+ reporter.next()
+ if iiter % log_interval == 0:
+ num_updates = options.num_updates if hasattr(options, "num_updates") else None
+ logging.info(reporter.log_message(-log_interval, num_updates=num_updates))
+ if summary_writer is not None:
+ reporter.tensorboard_add_scalar(summary_writer, -log_interval)
+ # if use_wandb:
+ # reporter.wandb_log()
+
+ if max_update_stop:
+ break
+
+ else:
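+ # This is the for/else branch: reached only when the iterator was exhausted
+ # without break, so notify the other ranks that this worker is done.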
+ if distributed:
+ iterator_stop.fill_(1)
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+ return all_steps_are_invalid, max_update_stop
+
+ @torch.no_grad()
+ def validate_one_epoch(
+ self,
+ model: torch.nn.Module,
+ iterator: Iterable[Dict[str, torch.Tensor]],
+ reporter: SubReporter,
+ options: TrainerOptions,
+ distributed_option: DistributedOption,
+ ) -> None:
+ assert check_argument_types()
+ ngpu = options.ngpu
+ # no_forward_run = options.no_forward_run
+ distributed = distributed_option.distributed
+
+ model.eval()
+
+ # [For distributed] Because iteration counts are not always equal between
+ # processes, send a stop-flag to the other processes once the iterator is finished
+ iterator_stop = torch.tensor(0).to("cuda" if ngpu > 0 else "cpu")
+ for (_, batch) in iterator:
+ assert isinstance(batch, dict), type(batch)
+ if distributed:
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+ if iterator_stop > 0:
+ break
+
+ batch = to_device(batch, "cuda" if ngpu > 0 else "cpu")
+ # if no_forward_run:
+ # continue
+
+ retval = model(**batch)
+ if isinstance(retval, dict):
+ stats = retval["stats"]
+ weight = retval["weight"]
+ else:
+ _, stats, weight = retval
+ if ngpu > 1 or distributed:
+ # Apply weighted averaging for stats.
+ # if distributed, this method can also apply all_reduce()
+ stats, weight = recursive_average(stats, weight, distributed)
+
+ reporter.register(stats, weight)
+ reporter.next()
+
+ else:
+ if distributed:
+ iterator_stop.fill_(1)
+ torch.distributed.all_reduce(iterator_stop, ReduceOp.SUM)
+
+
+def build_trainer(
+ args,
+ model: FunASRModel,
+ optimizers: Sequence[torch.optim.Optimizer],
+ schedulers: Sequence[Optional[AbsScheduler]],
+ train_dataloader: AbsIterFactory,
+ valid_dataloader: AbsIterFactory,
+ distributed_option: DistributedOption
+):
+ trainer = Trainer(
+ args=args,
+ model=model,
+ optimizers=optimizers,
+ schedulers=schedulers,
+ train_dataloader=train_dataloader,
+ valid_dataloader=valid_dataloader,
+ distributed_option=distributed_option
+ )
+ return trainer
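+
+# A minimal wiring sketch (names follow the sibling build_* modules in this
+# directory; the actual entry point is funasr/bin/train.py):
+#   model = build_model(args)
+#   optimizers = build_optimizer(args, model)
+#   schedulers = build_scheduler(args, optimizers)
+#   trainer = build_trainer(args, model, optimizers, schedulers,
+#                           train_dataloader, valid_dataloader, distributed_option)
+#   trainer.run()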
diff --git a/funasr/build_utils/build_vad_model.py b/funasr/build_utils/build_vad_model.py
new file mode 100644
index 0000000..76eb09b
--- /dev/null
+++ b/funasr/build_utils/build_vad_model.py
@@ -0,0 +1,77 @@
+import torch
+
+from funasr.models.e2e_vad import E2EVadModel
+from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(
+ default=DefaultFrontend,
+ sliding_window=SlidingWindow,
+ s3prl=S3prlFrontend,
+ fused=FusedFrontends,
+ wav_frontend=WavFrontend,
+ wav_frontend_online=WavFrontendOnline,
+ ),
+ default="default",
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ fsmn=FSMN,
+ ),
+ type_check=torch.nn.Module,
+ default="fsmn",
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ e2evad=E2EVadModel,
+ ),
+ default="e2evad",
+)
+
+class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --model and --model_conf
+ model_choices,
+]
+
+
+def build_vad_model(args):
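+ # The VAD model couples an FSMN encoder with rule-based post-processing:
+ # args.vad_post_conf is forwarded verbatim to E2EVadModel as vad_post_args.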
+ # frontend
+ if args.input_size is None:
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend':
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(**args.encoder_conf)
+
+ model_class = model_choices.get_class(args.model)
+ model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
+
+ # initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ return model
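
build_vad_model resolves each component (frontend, encoder, model) from a name-to-class registry before instantiating it with the matching *_conf dict. A toy illustration of that lookup pattern; the Registry and DummyFSMN classes below are hypothetical stand-ins, not FunASR's ClassChoices API:

class Registry:
    """Minimal name-to-class registry, mirroring the lookup above."""
    def __init__(self, classes: dict, default: str):
        self.classes = classes
        self.default = default

    def get_class(self, name: str = None):
        return self.classes[name if name is not None else self.default]

class DummyFSMN:
    def __init__(self, **conf):
        self.conf = conf

encoders = Registry({"fsmn": DummyFSMN}, default="fsmn")
encoder = encoders.get_class("fsmn")(input_dim=400, layers=4)  # toy conf
print(type(encoder).__name__, encoder.conf)  # DummyFSMN {'input_dim': 400, 'layers': 4}
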
diff --git a/funasr/datasets/iterable_dataset_modelscope.py b/funasr/datasets/iterable_dataset_modelscope.py
deleted file mode 100644
index 860492c..0000000
--- a/funasr/datasets/iterable_dataset_modelscope.py
+++ /dev/null
@@ -1,349 +0,0 @@
-# Copyright (c) Alibaba, Inc. and its affiliates.
-# Part of the implementation is borrowed from espnet/espnet.
-"""Iterable dataset module."""
-import copy
-from io import StringIO
-from pathlib import Path
-from typing import Callable, Collection, Dict, Iterator, Tuple, Union
-
-import kaldiio
-import numpy as np
-import soundfile
-import torch
-from funasr.datasets.dataset import ESPnetDataset
-from torch.utils.data.dataset import IterableDataset
-from typeguard import check_argument_types
-
-from funasr.utils import wav_utils
-
-
-def load_kaldi(input):
- retval = kaldiio.load_mat(input)
- if isinstance(retval, tuple):
- assert len(retval) == 2, len(retval)
- if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
- # sound scp case
- rate, array = retval
- elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
- # Extended ark format case
- array, rate = retval
- else:
- raise RuntimeError(
- f'Unexpected type: {type(retval[0])}, {type(retval[1])}')
-
- # Multichannel wave file
- # array: (NSample, Channel) or (Nsample)
-
- else:
- # Normal ark case
- assert isinstance(retval, np.ndarray), type(retval)
- array = retval
- return array
-
-
-DATA_TYPES = {
- 'sound':
- lambda x: soundfile.read(x)[0],
- 'kaldi_ark':
- load_kaldi,
- 'npy':
- np.load,
- 'text_int':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=' '),
- 'csv_int':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.long, delimiter=','),
- 'text_float':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=' '
- ),
- 'csv_float':
- lambda x: np.loadtxt(StringIO(x), ndmin=1, dtype=np.float32, delimiter=','
- ),
- 'text':
- lambda x: x,
-}
-
-
-class IterableESPnetDatasetModelScope(IterableDataset):
- """Pytorch Dataset class for ESPNet.
-
- Examples:
- >>> dataset = IterableESPnetDataset([('wav.scp', 'input', 'sound'),
- ... ('token_int', 'output', 'text_int')],
- ... )
- >>> for uid, data in dataset:
- ... data
- {'input': per_utt_array, 'output': per_utt_array}
- """
- def __init__(self,
- path_name_type_list: Collection[Tuple[any, str, str]],
- preprocess: Callable[[str, Dict[str, np.ndarray]],
- Dict[str, np.ndarray]] = None,
- float_dtype: str = 'float32',
- int_dtype: str = 'long',
- key_file: str = None,
- sample_rate: Union[dict, int] = 16000):
- assert check_argument_types()
- if len(path_name_type_list) == 0:
- raise ValueError(
- '1 or more elements are required for "path_name_type_list"')
-
- self.preprocess = preprocess
-
- self.float_dtype = float_dtype
- self.int_dtype = int_dtype
- self.key_file = key_file
- self.sample_rate = sample_rate
-
- self.debug_info = {}
- non_iterable_list = []
- self.path_name_type_list = []
-
- path_list = path_name_type_list[0]
- name = path_name_type_list[1]
- _type = path_name_type_list[2]
- if name in self.debug_info:
- raise RuntimeError(f'"{name}" is duplicated for data-key')
- self.debug_info[name] = path_list, _type
- # for path, name, _type in path_name_type_list:
- for path in path_list:
- self.path_name_type_list.append((path, name, _type))
-
- if len(non_iterable_list) != 0:
- # Some types doesn't support iterable mode
- self.non_iterable_dataset = ESPnetDataset(
- path_name_type_list=non_iterable_list,
- preprocess=preprocess,
- float_dtype=float_dtype,
- int_dtype=int_dtype,
- )
- else:
- self.non_iterable_dataset = None
-
- self.apply_utt2category = False
-
- def has_name(self, name) -> bool:
- return name in self.debug_info
-
- def names(self) -> Tuple[str, ...]:
- return tuple(self.debug_info)
-
- def __repr__(self):
- _mes = self.__class__.__name__
- _mes += '('
- for name, (path, _type) in self.debug_info.items():
- _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
- _mes += f'\n preprocess: {self.preprocess})'
- return _mes
-
- def __iter__(
- self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
- torch.set_printoptions(profile='default')
- count = len(self.path_name_type_list)
- for idx in range(count):
- # 2. Load the entry from each line and create a dict
- data = {}
- # 2.a. Load data streamingly
-
- # value: /home/fsc/code/MaaS/MaaS-lib-nls-asr/data/test/audios/asr_example.wav
- value = self.path_name_type_list[idx][0]['file']
- uid = self.path_name_type_list[idx][0]['key']
- # name: speech
- name = self.path_name_type_list[idx][1]
- _type = self.path_name_type_list[idx][2]
- func = DATA_TYPES[_type]
- array = func(value)
-
- # 2.b. audio resample
- if _type == 'sound':
- audio_sr: int = 16000
- model_sr: int = 16000
- if isinstance(self.sample_rate, int):
- model_sr = self.sample_rate
- else:
- if 'audio_sr' in self.sample_rate:
- audio_sr = self.sample_rate['audio_sr']
- if 'model_sr' in self.sample_rate:
- model_sr = self.sample_rate['model_sr']
- array = wav_utils.torch_resample(array, audio_sr, model_sr)
-
- # array: [ 1.25122070e-03 ... ]
- data[name] = array
-
- # 3. [Option] Apply preprocessing
- # e.g. espnet2.train.preprocessor:CommonPreprocessor
- if self.preprocess is not None:
- data = self.preprocess(uid, data)
- # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
-
- # 4. Force data-precision
- for name in data:
- # value is np.ndarray data
- value = data[name]
- if not isinstance(value, np.ndarray):
- raise RuntimeError(
- f'All values must be converted to np.ndarray object '
- f'by preprocessing, but "{name}" is still {type(value)}.'
- )
-
- # Cast to desired type
- if value.dtype.kind == 'f':
- value = value.astype(self.float_dtype)
- elif value.dtype.kind == 'i':
- value = value.astype(self.int_dtype)
- else:
- raise NotImplementedError(
- f'Not supported dtype: {value.dtype}')
- data[name] = value
-
- yield uid, data
-
- if count == 0:
- raise RuntimeError('No iteration')
-
-
-class IterableESPnetBytesModelScope(IterableDataset):
- """Pytorch audio bytes class for ESPNet.
-
- Examples:
- >>> dataset = IterableESPnetBytes([('audio bytes', 'input', 'sound'),
- ... ('token_int', 'output', 'text_int')],
- ... )
- >>> for uid, data in dataset:
- ... data
- {'input': per_utt_array, 'output': per_utt_array}
- """
- def __init__(self,
- path_name_type_list: Collection[Tuple[any, str, str]],
- preprocess: Callable[[str, Dict[str, np.ndarray]],
- Dict[str, np.ndarray]] = None,
- float_dtype: str = 'float32',
- int_dtype: str = 'long',
- key_file: str = None,
- sample_rate: Union[dict, int] = 16000):
- assert check_argument_types()
- if len(path_name_type_list) == 0:
- raise ValueError(
- '1 or more elements are required for "path_name_type_list"')
-
- self.preprocess = preprocess
-
- self.float_dtype = float_dtype
- self.int_dtype = int_dtype
- self.key_file = key_file
- self.sample_rate = sample_rate
-
- self.debug_info = {}
- non_iterable_list = []
- self.path_name_type_list = []
-
- audio_data = path_name_type_list[0]
- name = path_name_type_list[1]
- _type = path_name_type_list[2]
- if name in self.debug_info:
- raise RuntimeError(f'"{name}" is duplicated for data-key')
- self.debug_info[name] = audio_data, _type
- self.path_name_type_list.append((audio_data, name, _type))
-
- if len(non_iterable_list) != 0:
- # Some types doesn't support iterable mode
- self.non_iterable_dataset = ESPnetDataset(
- path_name_type_list=non_iterable_list,
- preprocess=preprocess,
- float_dtype=float_dtype,
- int_dtype=int_dtype,
- )
- else:
- self.non_iterable_dataset = None
-
- self.apply_utt2category = False
-
- if float_dtype == 'float32':
- self.np_dtype = np.float32
-
- def has_name(self, name) -> bool:
- return name in self.debug_info
-
- def names(self) -> Tuple[str, ...]:
- return tuple(self.debug_info)
-
- def __repr__(self):
- _mes = self.__class__.__name__
- _mes += '('
- for name, (path, _type) in self.debug_info.items():
- _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
- _mes += f'\n preprocess: {self.preprocess})'
- return _mes
-
- def __iter__(
- self) -> Iterator[Tuple[Union[str, int], Dict[str, np.ndarray]]]:
-
- torch.set_printoptions(profile='default')
- # 2. Load the entry from each line and create a dict
- data = {}
- # 2.a. Load data streamingly
-
- value = self.path_name_type_list[0][0]
- uid = 'pcm_data'
- # name: speech
- name = self.path_name_type_list[0][1]
- _type = self.path_name_type_list[0][2]
- func = DATA_TYPES[_type]
- # array: [ 1.25122070e-03 ... ]
- # data[name] = np.frombuffer(value, dtype=self.np_dtype)
-
- # 2.b. byte(PCM16) to float32
- middle_data = np.frombuffer(value, dtype=np.int16)
- middle_data = np.asarray(middle_data)
- if middle_data.dtype.kind not in 'iu':
- raise TypeError("'middle_data' must be an array of integers")
- dtype = np.dtype('float32')
- if dtype.kind != 'f':
- raise TypeError("'dtype' must be a floating point type")
-
- i = np.iinfo(middle_data.dtype)
- abs_max = 2**(i.bits - 1)
- offset = i.min + abs_max
- array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max,
- dtype=self.np_dtype)
-
- # 2.c. audio resample
- if _type == 'sound':
- audio_sr: int = 16000
- model_sr: int = 16000
- if isinstance(self.sample_rate, int):
- model_sr = self.sample_rate
- else:
- if 'audio_sr' in self.sample_rate:
- audio_sr = self.sample_rate['audio_sr']
- if 'model_sr' in self.sample_rate:
- model_sr = self.sample_rate['model_sr']
- array = wav_utils.torch_resample(array, audio_sr, model_sr)
-
- data[name] = array
-
- # 3. [Option] Apply preprocessing
- # e.g. espnet2.train.preprocessor:CommonPreprocessor
- if self.preprocess is not None:
- data = self.preprocess(uid, data)
- # data: {'speech': array([ 1.25122070e-03 ... 6.10351562e-03])}
-
- # 4. Force data-precision
- for name in data:
- # value is np.ndarray data
- value = data[name]
- if not isinstance(value, np.ndarray):
- raise RuntimeError(
- f'All values must be converted to np.ndarray object '
- f'by preprocessing, but "{name}" is still {type(value)}.')
-
- # Cast to desired type
- if value.dtype.kind == 'f':
- value = value.astype(self.float_dtype)
- elif value.dtype.kind == 'i':
- value = value.astype(self.int_dtype)
- else:
- raise NotImplementedError(
- f'Not supported dtype: {value.dtype}')
- data[name] = value
-
- yield uid, data
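
The bytes dataset deleted here converted raw PCM16 input to normalized float32 before feature extraction. A standalone sketch of that conversion, which generalizes to any integer PCM dtype (for signed int16 the offset is 0 and the scale is 32768):

import numpy as np

def pcm_bytes_to_float32(raw: bytes, dtype=np.int16) -> np.ndarray:
    # Shift and scale integer samples into [-1.0, 1.0).
    samples = np.frombuffer(raw, dtype=dtype)
    info = np.iinfo(samples.dtype)
    abs_max = 2 ** (info.bits - 1)
    offset = info.min + abs_max  # 0 for signed dtypes, 2**(bits-1) for unsigned
    return (samples.astype(np.float32) - offset) / abs_max

pcm = np.array([0, 16384, -32768], dtype=np.int16).tobytes()
print(pcm_bytes_to_float32(pcm))  # [ 0.   0.5 -1. ]
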
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 156f608..318ae0b 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -64,28 +64,26 @@
return self.sp.DecodePieces(list(tokens))
-class ArkDataLoader(AbsIterFactory):
- def __init__(self, data_list, dict_file, dataset_conf, frontend_conf=None, seg_dict_file=None, punc_dict_file=None,
- bpemodel_file=None, mode="train"):
- symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
- if seg_dict_file is not None:
- seg_dict = load_seg_dict(seg_dict_file)
- else:
- seg_dict = None
- if punc_dict_file is not None:
- punc_dict = read_symbol_table(punc_dict_file)
- else:
- punc_dict = None
- self.dataset_conf = dataset_conf
- self.frontend_conf = frontend_conf
+class LargeDataLoader(AbsIterFactory):
+ def __init__(self, args, mode="train"):
+ symbol_table, seg_dict, punc_dict, bpe_tokenizer = None, None, None, None
+ if hasattr(args, "token_list") and args.token_list is not None:
+ symbol_table = read_symbol_table(args.token_list)
+ if hasattr(args, "seg_dict_file") and args.seg_dict_file is not None:
+ seg_dict = load_seg_dict(args.seg_dict_file)
+ if hasattr(args, "punc_dict_file") and args.punc_dict_file is not None:
+ punc_dict = read_symbol_table(args.punc_dict_file)
+ if hasattr(args, "bpemodel_file") and args.bpemodel_file is not None:
+ bpe_tokenizer = SentencepiecesTokenizer(args.bpemodel_file)
+ self.dataset_conf = args.dataset_conf
+ self.frontend_conf = args.frontend_conf
logging.info("dataloader config: {}".format(self.dataset_conf))
batch_mode = self.dataset_conf.get("batch_mode", "padding")
- if bpemodel_file is not None:
- bpe_tokenizer = SentencepiecesTokenizer(bpemodel_file)
- else:
- bpe_tokenizer = None
+ data_list = args.train_data_file if mode == "train" else args.valid_data_file
self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict, bpe_tokenizer,
- self.dataset_conf, self.frontend_conf, mode=mode, batch_mode=batch_mode)
+ self.dataset_conf, self.frontend_conf,
+ speed_perturb=args.speed_perturb if mode == "train" else None,
+ mode=mode, batch_mode=batch_mode)
def build_iter(self, epoch, shuffle=True):
self.dataset.set_epoch(epoch)
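
The rewritten LargeDataLoader reads optional fields off args with hasattr guards. The same null-safe lookup can be written with getattr defaults; a small self-contained sketch (the Namespace contents are hypothetical, not FunASR's config):

from argparse import Namespace

args = Namespace(token_list="data/tokens.txt")        # hypothetical config object
token_list = getattr(args, "token_list", None)        # present -> value
seg_dict_file = getattr(args, "seg_dict_file", None)  # absent -> None
print(token_list, seg_dict_file)  # data/tokens.txt None
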
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 8c224d8..5df61fd 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -1,20 +1,20 @@
+import logging
import os
import random
-import numpy
from functools import partial
import torch
-import torchaudio
import torch.distributed as dist
+import torchaudio
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
from funasr.datasets.large_datasets.datapipes.batch import MaxTokenBucketizerIterDataPipe
from funasr.datasets.large_datasets.datapipes.filter import FilterIterDataPipe
from funasr.datasets.large_datasets.datapipes.map import MapperIterDataPipe
+from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.filter import filter
from funasr.datasets.large_datasets.utils.padding import padding
-from funasr.datasets.large_datasets.utils.clipping import clipping
from funasr.datasets.large_datasets.utils.tokenize import tokenize
@@ -28,7 +28,8 @@
class AudioDataset(IterableDataset):
- def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, mode="train"):
+ def __init__(self, scp_lists, data_names, data_types, frontend_conf=None, shuffle=True, speed_perturb=None,
+ mode="train"):
self.scp_lists = scp_lists
self.data_names = data_names
self.data_types = data_types
@@ -40,6 +41,9 @@
self.world_size = 1
self.worker_id = 0
self.num_workers = 1
+ self.speed_perturb = speed_perturb
+ if self.speed_perturb is not None:
+ logging.info("Using speed_perturb: {}".format(speed_perturb))
def set_epoch(self, epoch):
self.epoch = epoch
@@ -124,9 +128,15 @@
if sampling_rate != self.frontend_conf["fs"]:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
new_freq=self.frontend_conf["fs"])(waveform)
- sampling_rate = self.frontend_conf["fs"]
+ sampling_rate = self.frontend_conf["fs"]
waveform = waveform.numpy()
mat = waveform[0]
+ if self.speed_perturb is not None:
+ speed = random.choice(self.speed_perturb)
+ if speed != 1.0:
+ mat, _ = torchaudio.sox_effects.apply_effects_tensor(
+ torch.tensor(mat).view(1, -1), sampling_rate, [['speed', str(speed)], ['rate', str(sampling_rate)]])
+ mat = mat.view(-1).numpy()
sample_dict[data_name] = mat
sample_dict["sampling_rate"] = sampling_rate
if data_name == "speech":
@@ -168,6 +178,7 @@
bpe_tokenizer,
conf,
frontend_conf,
+ speed_perturb=None,
mode="train",
batch_mode="padding"):
scp_lists = read_lists(data_list_file)
@@ -196,7 +207,8 @@
data_names,
data_types,
frontend_conf=frontend_conf,
- shuffle=shuffle,
+ shuffle=shuffle,
+ speed_perturb=speed_perturb,
mode=mode,
)
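
The speed perturbation added above goes through sox's "speed" effect, with a "rate" effect appended so the output keeps the original sampling rate. A minimal sketch of that call (whether torchaudio's sox backend is available depends on the build):

import torch
import torchaudio

waveform = torch.randn(1, 16000)  # one second of dummy audio at 16 kHz
speed = 1.1                       # e.g. drawn from a [0.9, 1.0, 1.1] list
perturbed, sr = torchaudio.sox_effects.apply_effects_tensor(
    waveform, 16000, [["speed", str(speed)], ["rate", "16000"]]
)
print(perturbed.shape, sr)  # roughly 16000/1.1 samples, still 16000 Hz
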
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
index f0f0c66..cf7d255 100644
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -48,7 +48,7 @@
vad = -2
if bpe_tokenizer is not None:
- text = bpe_tokenizer.text2tokens("".join(text))
+ text = bpe_tokenizer.text2tokens(text)
if seg_dict is not None:
assert isinstance(seg_dict, dict)
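
On the tokenize.py change above: assuming text can arrive pre-split as a list of words, the old "".join(text) collapsed the word boundaries that a BPE model segments on, while passing the text through keeps them. A two-line illustration:

text = ["hello", "world"]  # hypothetical pre-split input
print("".join(text))       # helloworld  (what was tokenized before)
print(" ".join(text))      # hello world (whitespace preserved)
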
diff --git a/funasr/datasets/small_datasets/collate_fn.py b/funasr/datasets/small_datasets/collate_fn.py
new file mode 100644
index 0000000..573f581
--- /dev/null
+++ b/funasr/datasets/small_datasets/collate_fn.py
@@ -0,0 +1,93 @@
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Tuple
+from typing import Union
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.modules.nets_utils import pad_list
+
+
+class CommonCollateFn:
+ """Functor class of common_collate_fn()"""
+
+ def __init__(
+ self,
+ float_pad_value: Union[float, int] = 0.0,
+ int_pad_value: int = -32768,
+ not_sequence: Collection[str] = (),
+ max_sample_size=None
+ ):
+ assert check_argument_types()
+ self.float_pad_value = float_pad_value
+ self.int_pad_value = int_pad_value
+ self.not_sequence = set(not_sequence)
+ self.max_sample_size = max_sample_size
+
+ def __repr__(self):
+ return (
+ f"{self.__class__}(float_pad_value={self.float_pad_value}, "
+ f"int_pad_value={self.float_pad_value})"
+ )
+
+ def __call__(
+ self, data: Collection[Tuple[str, Dict[str, np.ndarray]]]
+ ) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+ return common_collate_fn(
+ data,
+ float_pad_value=self.float_pad_value,
+ int_pad_value=self.int_pad_value,
+ not_sequence=self.not_sequence,
+ )
+
+
+def common_collate_fn(
+ data: Collection[Tuple[str, Dict[str, np.ndarray]]],
+ float_pad_value: Union[float, int] = 0.0,
+ int_pad_value: int = -32768,
+ not_sequence: Collection[str] = (),
+) -> Tuple[List[str], Dict[str, torch.Tensor]]:
+ """Concatenate ndarray-list to an array and convert to torch.Tensor.
+ """
+ assert check_argument_types()
+ uttids = [u for u, _ in data]
+ data = [d for _, d in data]
+
+ assert all(set(data[0]) == set(d) for d in data), "dict-keys mismatching"
+ assert all(
+ not k.endswith("_lengths") for k in data[0]
+ ), f"*_lengths is reserved: {list(data[0])}"
+
+ output = {}
+ for key in data[0]:
+ if data[0][key].dtype.kind == "i":
+ pad_value = int_pad_value
+ else:
+ pad_value = float_pad_value
+
+ array_list = [d[key] for d in data]
+ tensor_list = [torch.from_numpy(a) for a in array_list]
+ tensor = pad_list(tensor_list, pad_value)
+ output[key] = tensor
+
+ if key not in not_sequence:
+ lens = torch.tensor([d[key].shape[0] for d in data], dtype=torch.long)
+ output[key + "_lengths"] = lens
+
+ output = (uttids, output)
+ assert check_return_type(output)
+ return output
+
+def crop_to_max_size(feature, target_size):
+ size = len(feature)
+ diff = size - target_size
+ if diff <= 0:
+ return feature
+
+ start = np.random.randint(0, diff + 1)
+ end = size - diff + start
+ return feature[start:end]
\ No newline at end of file
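
A usage sketch for common_collate_fn as defined above: variable-length arrays are padded into one batch tensor, and matching *_lengths entries are added automatically:

import numpy as np
from funasr.datasets.small_datasets.collate_fn import common_collate_fn

data = [
    ("utt1", {"speech": np.ones(3, dtype=np.float32)}),
    ("utt2", {"speech": np.ones(5, dtype=np.float32)}),
]
uttids, batch = common_collate_fn(data)
print(uttids)                   # ['utt1', 'utt2']
print(batch["speech"].shape)    # torch.Size([2, 5]) after zero-padding
print(batch["speech_lengths"])  # tensor([3, 5])
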
diff --git a/funasr/datasets/small_datasets/dataset.py b/funasr/datasets/small_datasets/dataset.py
new file mode 100644
index 0000000..e14e4f1
--- /dev/null
+++ b/funasr/datasets/small_datasets/dataset.py
@@ -0,0 +1,258 @@
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+import collections
+import copy
+import logging
+import numbers
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import Mapping
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import kaldiio
+import numpy as np
+import torch
+from torch.utils.data.dataset import Dataset
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.npy_scp import NpyScpReader
+from funasr.fileio.sound_scp import SoundScpReader
+
+
+class AdapterForSoundScpReader(collections.abc.Mapping):
+ def __init__(self, loader, dtype=None):
+ assert check_argument_types()
+ self.loader = loader
+ self.dtype = dtype
+ self.rate = None
+
+ def keys(self):
+ return self.loader.keys()
+
+ def __len__(self):
+ return len(self.loader)
+
+ def __iter__(self):
+ return iter(self.loader)
+
+ def __getitem__(self, key: str) -> np.ndarray:
+ retval = self.loader[key]
+
+ if isinstance(retval, tuple):
+ assert len(retval) == 2, len(retval)
+ if isinstance(retval[0], int) and isinstance(retval[1], np.ndarray):
+ # sound scp case
+ rate, array = retval
+ elif isinstance(retval[1], int) and isinstance(retval[0], np.ndarray):
+ # Extended ark format case
+ array, rate = retval
+ else:
+ raise RuntimeError(
+ f"Unexpected type: {type(retval[0])}, {type(retval[1])}"
+ )
+
+ if self.rate is not None and self.rate != rate:
+ raise RuntimeError(
+ f"Sampling rates are mismatched: {self.rate} != {rate}"
+ )
+ self.rate = rate
+ # Multichannel wave file
+ # array: (NSample, Channel) or (Nsample)
+ if self.dtype is not None:
+ array = array.astype(self.dtype)
+
+ else:
+ # Normal ark case
+ assert isinstance(retval, np.ndarray), type(retval)
+ array = retval
+ if self.dtype is not None:
+ array = array.astype(self.dtype)
+
+ assert isinstance(array, np.ndarray), type(array)
+ return array
+
+
+def sound_loader(path, dest_sample_rate=16000, float_dtype=None):
+ # The file is as follows:
+ # utterance_id_A /some/where/a.wav
+ # utterance_id_B /some/where/a.flac
+
+ # NOTE(kamo): SoundScpReader doesn't support pipe-fashion
+ # like Kaldi e.g. "cat a.wav |".
+ # NOTE(kamo): The audio signal is normalized to [-1,1] range.
+ loader = SoundScpReader(path, dest_sample_rate, normalize=True, always_2d=False)
+
+ # SoundScpReader.__getitem__() returns Tuple[int, ndarray],
+ # but ndarray is desired, so Adapter class is inserted here
+ return AdapterForSoundScpReader(loader, float_dtype)
+
+
+def kaldi_loader(path, float_dtype=None, max_cache_fd: int = 0):
+ loader = kaldiio.load_scp(path, max_cache_fd=max_cache_fd)
+ return AdapterForSoundScpReader(loader, float_dtype)
+
+
+class ESPnetDataset(Dataset):
+ """
+ Pytorch Dataset class for FunASR, modified from ESPnet
+ """
+
+ def __init__(
+ self,
+ path_name_type_list: Collection[Tuple[str, str, str]],
+ preprocess: Callable[
+ [str, Dict[str, np.ndarray]], Dict[str, np.ndarray]
+ ] = None,
+ float_dtype: str = "float32",
+ int_dtype: str = "long",
+ dest_sample_rate: int = 16000,
+ speed_perturb: Union[list, tuple] = None,
+ mode: str = "train",
+ ):
+ assert check_argument_types()
+ if len(path_name_type_list) == 0:
+ raise ValueError(
+ '1 or more elements are required for "path_name_type_list"'
+ )
+
+ path_name_type_list = copy.deepcopy(path_name_type_list)
+ self.preprocess = preprocess
+
+ self.float_dtype = float_dtype
+ self.int_dtype = int_dtype
+ self.dest_sample_rate = dest_sample_rate
+ self.speed_perturb = speed_perturb
+ self.mode = mode
+ if self.speed_perturb is not None:
+ logging.info("Using speed_perturb: {}".format(speed_perturb))
+
+ self.loader_dict = {}
+ self.debug_info = {}
+ for path, name, _type in path_name_type_list:
+ if name in self.loader_dict:
+ raise RuntimeError(f'"{name}" is duplicated for data-key')
+
+ loader = self._build_loader(path, _type)
+ self.loader_dict[name] = loader
+ self.debug_info[name] = path, _type
+ if len(self.loader_dict[name]) == 0:
+ raise RuntimeError(f"{path} has no samples")
+
+ def _build_loader(
+ self, path: str, loader_type: str
+ ) -> Mapping[str, Union[np.ndarray, torch.Tensor, str, numbers.Number]]:
+ """Helper function to instantiate Loader.
+
+ Args:
+ path: The file path
+ loader_type: loader_type. sound, npy, text, etc
+ """
+ if loader_type == "sound":
+ speed_perturb = self.speed_perturb if self.mode == "train" else None
+ loader = SoundScpReader(path, self.dest_sample_rate, normalize=True, always_2d=False,
+ speed_perturb=speed_perturb)
+ return AdapterForSoundScpReader(loader, self.float_dtype)
+ elif loader_type == "kaldi_ark":
+ loader = kaldiio.load_scp(path)
+ return AdapterForSoundScpReader(loader, self.float_dtype)
+ elif loader_type == "npy":
+ return NpyScpReader(path)
+ elif loader_type == "text":
+ text_loader = {}
+ with open(path, "r", encoding="utf-8") as f:
+ for linenum, line in enumerate(f, 1):
+ sps = line.rstrip().split(maxsplit=1)
+ if len(sps) == 1:
+ k, v = sps[0], ""
+ else:
+ k, v = sps
+ if k in text_loader:
+ raise RuntimeError(f"{k} is duplicated ({path}:{linenum})")
+ text_loader[k] = v
+ return text_loader
+ else:
+ raise RuntimeError(f"Not supported: loader_type={loader_type}")
+
+ def has_name(self, name) -> bool:
+ return name in self.loader_dict
+
+ def names(self) -> Tuple[str, ...]:
+ return tuple(self.loader_dict)
+
+ def __iter__(self):
+ return iter(next(iter(self.loader_dict.values())))
+
+ def __repr__(self):
+ _mes = self.__class__.__name__
+ _mes += "("
+ for name, (path, _type) in self.debug_info.items():
+ _mes += f'\n {name}: {{"path": "{path}", "type": "{_type}"}}'
+ _mes += f"\n preprocess: {self.preprocess})"
+ return _mes
+
+ def __getitem__(self, uid: Union[str, int]) -> Tuple[str, Dict[str, np.ndarray]]:
+ assert check_argument_types()
+
+ # Change integer-id to string-id
+ if isinstance(uid, int):
+ d = next(iter(self.loader_dict.values()))
+ uid = list(d)[uid]
+
+ data = {}
+ # 1. Load data from each loaders
+ for name, loader in self.loader_dict.items():
+ try:
+ value = loader[uid]
+ if isinstance(value, (list, tuple)):
+ value = np.array(value)
+ if not isinstance(
+ value, (np.ndarray, torch.Tensor, str, numbers.Number)
+ ):
+ raise TypeError(
+ f"Must be ndarray, torch.Tensor, str or Number: {type(value)}"
+ )
+ except Exception:
+ path, _type = self.debug_info[name]
+ logging.error(
+ f"Error happened with path={path}, type={_type}, id={uid}"
+ )
+ raise
+
+ # torch.Tensor is converted to ndarray
+ if isinstance(value, torch.Tensor):
+ value = value.numpy()
+ elif isinstance(value, numbers.Number):
+ value = np.array([value])
+ data[name] = value
+
+ # 2. [Option] Apply preprocessing
+ # e.g. funasr.train.preprocessor:CommonPreprocessor
+ if self.preprocess is not None:
+ data = self.preprocess(uid, data)
+
+ # 3. Force data-precision
+ for name in data:
+ value = data[name]
+ if not isinstance(value, np.ndarray):
+ raise RuntimeError(
+ f"All values must be converted to np.ndarray object "
+ f'by preprocessing, but "{name}" is still {type(value)}.'
+ )
+
+ # Cast to desired type
+ if value.dtype.kind == "f":
+ value = value.astype(self.float_dtype)
+ elif value.dtype.kind == "i":
+ value = value.astype(self.int_dtype)
+ else:
+ raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+ data[name] = value
+
+ retval = uid, data
+ assert check_return_type(retval)
+ return retval
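
A hypothetical instantiation of this dataset: each (path, name, type) triple becomes a keyed loader, and indexing by utterance id returns the loaded arrays. The manifest path and utterance id below are illustrative; real files are needed to run it:

from funasr.datasets.small_datasets.dataset import ESPnetDataset

dataset = ESPnetDataset(
    [("dump/train/wav.scp", "speech", "sound")],  # hypothetical manifest
    dest_sample_rate=16000,
)
uid, sample = dataset["utt1"]  # sample == {"speech": float32 ndarray}
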
diff --git a/funasr/datasets/small_datasets/length_batch_sampler.py b/funasr/datasets/small_datasets/length_batch_sampler.py
new file mode 100644
index 0000000..8ee8bdc
--- /dev/null
+++ b/funasr/datasets/small_datasets/length_batch_sampler.py
@@ -0,0 +1,147 @@
+from typing import Iterator
+from typing import List
+from typing import Dict
+from typing import Tuple
+from typing import Union
+
+from typeguard import check_argument_types
+
+from funasr.fileio.read_text import load_num_sequence_text
+from funasr.samplers.abs_sampler import AbsSampler
+
+
+class LengthBatchSampler(AbsSampler):
+ def __init__(
+ self,
+ batch_bins: int,
+ shape_files: Union[Tuple[str, ...], List[str], Dict],
+ min_batch_size: int = 1,
+ sort_in_batch: str = "descending",
+ sort_batch: str = "ascending",
+ drop_last: bool = False,
+ padding: bool = True,
+ ):
+ assert check_argument_types()
+ assert batch_bins > 0
+ if sort_batch != "ascending" and sort_batch != "descending":
+ raise ValueError(
+ f"sort_batch must be ascending or descending: {sort_batch}"
+ )
+ if sort_in_batch != "descending" and sort_in_batch != "ascending":
+ raise ValueError(
+ f"sort_in_batch must be ascending or descending: {sort_in_batch}"
+ )
+
+ self.batch_bins = batch_bins
+ self.shape_files = shape_files
+ self.sort_in_batch = sort_in_batch
+ self.sort_batch = sort_batch
+ self.drop_last = drop_last
+
+ # utt2shape: (Length, ...)
+ # uttA 100,...
+ # uttB 201,...
+ if isinstance(shape_files, dict):
+ utt2shapes = [shape_files]
+ else:
+ utt2shapes = [
+ load_num_sequence_text(s, loader_type="csv_int") for s in shape_files
+ ]
+
+ first_utt2shape = utt2shapes[0]
+ for s, d in zip(shape_files, utt2shapes):
+ if set(d) != set(first_utt2shape):
+ raise RuntimeError(
+ f"keys are mismatched between {s} != {shape_files[0]}"
+ )
+
+ # Sort samples in ascending order
+ # (shape order should be like (Length, Dim))
+ keys = sorted(first_utt2shape, key=lambda k: first_utt2shape[k][0])
+ if len(keys) == 0:
+ raise RuntimeError(f"0 lines found: {shape_files[0]}")
+
+ # Decide batch-sizes
+ batch_sizes = []
+ current_batch_keys = []
+ for key in keys:
+ current_batch_keys.append(key)
+ # shape: (Length, dim1, dim2, ...)
+ if padding:
+ # bins = bs x max_length
+ bins = sum(len(current_batch_keys) * sh[key][0] for sh in utt2shapes)
+ else:
+ # bins = sum of lengths
+ bins = sum(d[k][0] for k in current_batch_keys for d in utt2shapes)
+
+ if bins > batch_bins and len(current_batch_keys) >= min_batch_size:
+ batch_sizes.append(len(current_batch_keys))
+ current_batch_keys = []
+ else:
+ if len(current_batch_keys) != 0 and (
+ not self.drop_last or len(batch_sizes) == 0
+ ):
+ batch_sizes.append(len(current_batch_keys))
+
+ if len(batch_sizes) == 0:
+ # Maybe we can't reach here
+ raise RuntimeError("0 batches")
+
+ # If the last batch-size is smaller than minimum batch_size,
+ # the samples are redistributed to the other mini-batches
+ if len(batch_sizes) > 1 and batch_sizes[-1] < min_batch_size:
+ for i in range(batch_sizes.pop(-1)):
+ batch_sizes[-(i % len(batch_sizes)) - 1] += 1
+
+ if not self.drop_last:
+ # Bug check
+ assert sum(batch_sizes) == len(keys), f"{sum(batch_sizes)} != {len(keys)}"
+
+ # Set mini-batch
+ self.batch_list = []
+ iter_bs = iter(batch_sizes)
+ bs = next(iter_bs)
+ minibatch_keys = []
+ for key in keys:
+ minibatch_keys.append(key)
+ if len(minibatch_keys) == bs:
+ if sort_in_batch == "descending":
+ minibatch_keys.reverse()
+ elif sort_in_batch == "ascending":
+ # Keys are already sorted in ascending order
+ pass
+ else:
+ raise ValueError(
+ "sort_in_batch must be ascending"
+ f" or descending: {sort_in_batch}"
+ )
+ self.batch_list.append(tuple(minibatch_keys))
+ minibatch_keys = []
+ try:
+ bs = next(iter_bs)
+ except StopIteration:
+ break
+
+ if sort_batch == "ascending":
+ pass
+ elif sort_batch == "descending":
+ self.batch_list.reverse()
+ else:
+ raise ValueError(
+ f"sort_batch must be ascending or descending: {sort_batch}"
+ )
+
+ def __repr__(self):
+ return (
+ f"{self.__class__.__name__}("
+ f"N-batch={len(self)}, "
+ f"batch_bins={self.batch_bins}, "
+ f"sort_in_batch={self.sort_in_batch}, "
+ f"sort_batch={self.sort_batch})"
+ )
+
+ def __len__(self):
+ return len(self.batch_list)
+
+ def __iter__(self) -> Iterator[Tuple[str, ...]]:
+ return iter(self.batch_list)
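
A worked example of the bin packing above, using the new dict form of shape_files: with padding=True, a batch closes on the first key that pushes (batch size) x (longest length so far) past batch_bins, and sort_in_batch="descending" then reverses each batch from long to short:

from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler

utt2shape = {"a": [50], "b": [60], "c": [90], "d": [120]}  # toy lengths
sampler = LengthBatchSampler(batch_bins=200, shape_files=utt2shape)
# "a","b" cost 2 * 60 = 120 bins; adding "c" costs 3 * 90 = 270 > 200,
# so the first batch closes as (a, b, c) and "d" starts a new one.
print(list(sampler))  # [('c', 'b', 'a'), ('d',)]
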
diff --git a/funasr/datasets/small_datasets/preprocessor.py b/funasr/datasets/small_datasets/preprocessor.py
new file mode 100644
index 0000000..d80f48a
--- /dev/null
+++ b/funasr/datasets/small_datasets/preprocessor.py
@@ -0,0 +1,875 @@
+from abc import ABC
+from abc import abstractmethod
+from pathlib import Path
+from typing import Collection
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Union
+
+import numpy as np
+import scipy.signal
+import soundfile
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.cleaner import TextCleaner
+from funasr.text.token_id_converter import TokenIDConverter
+
+
+class AbsPreprocessor(ABC):
+ def __init__(self, train: bool):
+ self.train = train
+
+ @abstractmethod
+ def __call__(
+ self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ raise NotImplementedError
+
+
+def forward_segment(text, dic):
+ word_list = []
+ i = 0
+ while i < len(text):
+ longest_word = text[i]
+ for j in range(i + 1, len(text) + 1):
+ word = text[i:j]
+ if word in dic:
+ if len(word) > len(longest_word):
+ longest_word = word
+ word_list.append(longest_word)
+ i += len(longest_word)
+ return word_list
+
+
+def seg_tokenize(txt, seg_dict):
+ out_txt = ""
+ for word in txt:
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
+
+def seg_tokenize_wo_pattern(txt, seg_dict):
+ out_txt = ""
+ for word in txt:
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
+
+def framing(
+ x,
+ frame_length: int = 512,
+ frame_shift: int = 256,
+ centered: bool = True,
+ padded: bool = True,
+):
+ if x.size == 0:
+ raise ValueError("Input array size is zero")
+ if frame_length < 1:
+ raise ValueError("frame_length must be a positive integer")
+ if frame_length > x.shape[-1]:
+ raise ValueError("frame_length is greater than input length")
+ if 0 >= frame_shift:
+ raise ValueError("frame_shift must be greater than 0")
+
+ if centered:
+ pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [
+ (frame_length // 2, frame_length // 2)
+ ]
+ x = np.pad(x, pad_shape, mode="constant", constant_values=0)
+
+ if padded:
+ # Pad to integer number of windowed segments
+ # I.e. make x.shape[-1] = frame_length + (nseg-1)*nstep,
+ # with integer nseg
+ nadd = (-(x.shape[-1] - frame_length) % frame_shift) % frame_length
+ pad_shape = [(0, 0) for _ in range(x.ndim - 1)] + [(0, nadd)]
+ x = np.pad(x, pad_shape, mode="constant", constant_values=0)
+
+ # Created strided array of data segments
+ if frame_length == 1 and frame_length == frame_shift:
+ result = x[..., None]
+ else:
+ shape = x.shape[:-1] + (
+ (x.shape[-1] - frame_length) // frame_shift + 1,
+ frame_length,
+ )
+ strides = x.strides[:-1] + (frame_shift * x.strides[-1], x.strides[-1])
+ result = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
+ return result
+
+
+def detect_non_silence(
+ x: np.ndarray,
+ threshold: float = 0.01,
+ frame_length: int = 1024,
+ frame_shift: int = 512,
+ window: str = "boxcar",
+) -> np.ndarray:
+ """Power based voice activity detection.
+
+ Args:
+ x: (Channel, Time)
+ >>> x = np.random.randn(1000)
+ >>> detect = detect_non_silence(x)
+ >>> assert x.shape == detect.shape
+ >>> assert detect.dtype == bool
+ """
+ if x.shape[-1] < frame_length:
+ return np.full(x.shape, fill_value=True, dtype=bool)
+
+ if x.dtype.kind == "i":
+ x = x.astype(np.float64)
+ # framed_w: (C, T, F)
+ framed_w = framing(
+ x,
+ frame_length=frame_length,
+ frame_shift=frame_shift,
+ centered=False,
+ padded=True,
+ )
+ framed_w *= scipy.signal.get_window(window, frame_length).astype(framed_w.dtype)
+ # power: (C, T)
+ power = (framed_w ** 2).mean(axis=-1)
+ # mean_power: (C, 1)
+ mean_power = np.mean(power, axis=-1, keepdims=True)
+ if np.all(mean_power == 0):
+ return np.full(x.shape, fill_value=True, dtype=bool)
+ # detect_frames: (C, T)
+ detect_frames = power / mean_power > threshold
+ # detects: (C, T, F)
+ detects = np.broadcast_to(
+ detect_frames[..., None], detect_frames.shape + (frame_shift,)
+ )
+ # detects: (C, TF)
+ detects = detects.reshape(*detect_frames.shape[:-1], -1)
+ # detects: (C, TF)
+ return np.pad(
+ detects,
+ [(0, 0)] * (x.ndim - 1) + [(0, x.shape[-1] - detects.shape[-1])],
+ mode="edge",
+ )
+
+
+class CommonPreprocessor(AbsPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: str = None,
+ token_list: Union[Path, str, Iterable[str]] = None,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: str = "text",
+ split_with_space: bool = False,
+ seg_dict_file: str = None,
+ ):
+ super().__init__(train)
+ self.train = train
+ self.speech_name = speech_name
+ self.text_name = text_name
+ self.speech_volume_normalize = speech_volume_normalize
+ self.rir_apply_prob = rir_apply_prob
+ self.noise_apply_prob = noise_apply_prob
+ self.split_with_space = split_with_space
+ self.seg_dict = None
+ if seg_dict_file is not None:
+ self.seg_dict = {}
+ with open(seg_dict_file) as f:
+ lines = f.readlines()
+ for line in lines:
+ s = line.strip().split()
+ key = s[0]
+ value = s[1:]
+ self.seg_dict[key] = " ".join(value)
+
+ if token_type is not None:
+ if token_list is None:
+ raise ValueError("token_list is required if token_type is not None")
+ self.text_cleaner = TextCleaner(text_cleaner)
+
+ self.tokenizer = build_tokenizer(
+ token_type=token_type,
+ bpemodel=bpemodel,
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ g2p_type=g2p_type,
+ )
+ self.token_id_converter = TokenIDConverter(
+ token_list=token_list,
+ unk_symbol=unk_symbol,
+ )
+ else:
+ self.text_cleaner = None
+ self.tokenizer = None
+ self.token_id_converter = None
+
+ if train and rir_scp is not None:
+ self.rirs = []
+ with open(rir_scp, "r", encoding="utf-8") as f:
+ for line in f:
+ sps = line.strip().split(None, 1)
+ if len(sps) == 1:
+ self.rirs.append(sps[0])
+ else:
+ self.rirs.append(sps[1])
+ else:
+ self.rirs = None
+
+ if train and noise_scp is not None:
+ self.noises = []
+ with open(noise_scp, "r", encoding="utf-8") as f:
+ for line in f:
+ sps = line.strip().split(None, 1)
+ if len(sps) == 1:
+ self.noises.append(sps[0])
+ else:
+ self.noises.append(sps[1])
+ sps = noise_db_range.split("_")
+ if len(sps) == 1:
+ self.noise_db_low = self.noise_db_high = float(sps[0])
+ elif len(sps) == 2:
+ self.noise_db_low, self.noise_db_high = float(sps[0]), float(sps[1])
+ else:
+ raise ValueError(
+ "Format error: '{noise_db_range}' e.g. -3_4 -> [-3db,4db]"
+ )
+ else:
+ self.noises = None
+
+ def _speech_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, Union[str, np.ndarray]]:
+ assert check_argument_types()
+ if self.speech_name in data:
+ if self.train and (self.rirs is not None or self.noises is not None):
+ speech = data[self.speech_name]
+ nsamples = len(speech)
+
+ # speech: (Nmic, Time)
+ if speech.ndim == 1:
+ speech = speech[None, :]
+ else:
+ speech = speech.T
+ # Calc power on non-silence region
+ power = (speech[detect_non_silence(speech)] ** 2).mean()
+
+ # 1. Convolve RIR
+ if self.rirs is not None and self.rir_apply_prob >= np.random.random():
+ rir_path = np.random.choice(self.rirs)
+ if rir_path is not None:
+ rir, _ = soundfile.read(
+ rir_path, dtype=np.float64, always_2d=True
+ )
+
+ # rir: (Nmic, Time)
+ rir = rir.T
+
+ # speech: (Nmic, Time)
+ # Note that this operation doesn't change the signal length
+ speech = scipy.signal.convolve(speech, rir, mode="full")[
+ :, : speech.shape[1]
+ ]
+ # Reverse mean power to the original power
+ power2 = (speech[detect_non_silence(speech)] ** 2).mean()
+ speech = np.sqrt(power / max(power2, 1e-10)) * speech
+
+ # 2. Add Noise
+ if (
+ self.noises is not None
+ and self.noise_apply_prob >= np.random.random()
+ ):
+ noise_path = np.random.choice(self.noises)
+ if noise_path is not None:
+ noise_db = np.random.uniform(
+ self.noise_db_low, self.noise_db_high
+ )
+ with soundfile.SoundFile(noise_path) as f:
+ if f.frames == nsamples:
+ noise = f.read(dtype=np.float64, always_2d=True)
+ elif f.frames < nsamples:
+ offset = np.random.randint(0, nsamples - f.frames)
+ # noise: (Time, Nmic)
+ noise = f.read(dtype=np.float64, always_2d=True)
+ # Repeat noise
+ noise = np.pad(
+ noise,
+ [(offset, nsamples - f.frames - offset), (0, 0)],
+ mode="wrap",
+ )
+ else:
+ offset = np.random.randint(0, f.frames - nsamples)
+ f.seek(offset)
+ # noise: (Time, Nmic)
+ noise = f.read(
+ nsamples, dtype=np.float64, always_2d=True
+ )
+ if len(noise) != nsamples:
+ raise RuntimeError(f"Something wrong: {noise_path}")
+ # noise: (Nmic, Time)
+ noise = noise.T
+
+ noise_power = (noise ** 2).mean()
+ scale = (
+ 10 ** (-noise_db / 20)
+ * np.sqrt(power)
+ / np.sqrt(max(noise_power, 1e-10))
+ )
+ speech = speech + scale * noise
+
+ speech = speech.T
+ ma = np.max(np.abs(speech))
+ if ma > 1.0:
+ speech /= ma
+ data[self.speech_name] = speech
+
+ if self.speech_volume_normalize is not None:
+ speech = data[self.speech_name]
+ ma = np.max(np.abs(speech))
+ data[self.speech_name] = speech * self.speech_volume_normalize / ma
+ assert check_return_type(data)
+ return data
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ if self.text_name in data and self.tokenizer is not None:
+ text = data[self.text_name]
+ text = self.text_cleaner(text)
+ if self.split_with_space:
+ tokens = text.strip().split(" ")
+ if self.seg_dict is not None:
+ tokens = forward_segment("".join(tokens), self.seg_dict)
+ tokens = seg_tokenize(tokens, self.seg_dict)
+ else:
+ tokens = self.tokenizer.text2tokens(text)
+ text_ints = self.token_id_converter.tokens2ids(tokens)
+ data[self.text_name] = np.array(text_ints, dtype=np.int64)
+ assert check_return_type(data)
+ return data
+
+ def __call__(
+ self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ assert check_argument_types()
+
+ data = self._speech_process(data)
+ data = self._text_process(data)
+ return data
+
+
+## FIXME
+class LMPreprocessor(CommonPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: str = None,
+ token_list: Union[Path, str, Iterable[str]] = None,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: str = "text",
+ split_with_space: bool = False,
+ seg_dict_file: str = None,
+ ):
+ super().__init__(train,
+ token_type,
+ token_list,
+ bpemodel,
+ text_cleaner,
+ g2p_type,
+ unk_symbol,
+ space_symbol,
+ non_linguistic_symbols,
+ delimiter,
+ rir_scp,
+ rir_apply_prob,
+ noise_scp,
+ noise_apply_prob,
+ noise_db_range,
+ speech_volume_normalize,
+ speech_name,
+ text_name,
+ split_with_space,
+ seg_dict_file,
+ )
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ if self.text_name in data and self.tokenizer is not None:
+ text = data[self.text_name]
+ text = self.text_cleaner(text)
+ if self.split_with_space:
+ tokens = text.strip().split(" ")
+ if self.seg_dict is not None:
+ tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
+ else:
+ tokens = self.tokenizer.text2tokens(text)
+ text_ints = self.token_id_converter.tokens2ids(tokens)
+ data[self.text_name] = np.array(text_ints, dtype=np.int64)
+ assert check_return_type(data)
+ return data
+
+
+class CommonPreprocessor_multi(AbsPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: str = None,
+ token_list: Union[Path, str, Iterable[str]] = None,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ speech_name: str = "speech",
+ text_name: List[str] = ["text"],
+ ):
+ super().__init__(train)
+ self.train = train
+ self.speech_name = speech_name
+ self.text_name = text_name
+
+ if token_type is not None:
+ if token_list is None:
+ raise ValueError("token_list is required if token_type is not None")
+ self.text_cleaner = TextCleaner(text_cleaner)
+
+ self.tokenizer = build_tokenizer(
+ token_type=token_type,
+ bpemodel=bpemodel,
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ g2p_type=g2p_type,
+ )
+ self.token_id_converter = TokenIDConverter(
+ token_list=token_list,
+ unk_symbol=unk_symbol,
+ )
+ else:
+ self.text_cleaner = None
+ self.tokenizer = None
+ self.token_id_converter = None
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ for text_n in self.text_name:
+ if text_n in data and self.tokenizer is not None:
+ text = data[text_n]
+ text = self.text_cleaner(text)
+ tokens = self.tokenizer.text2tokens(text)
+ text_ints = self.token_id_converter.tokens2ids(tokens)
+ data[text_n] = np.array(text_ints, dtype=np.int64)
+ assert check_return_type(data)
+ return data
+
+ def __call__(
+ self, uid: str, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ assert check_argument_types()
+
+ if self.speech_name in data:
+ # Nothing now: candidates:
+ # - STFT
+ # - Fbank
+ # - CMVN
+ # - Data augmentation
+ pass
+
+ data = self._text_process(data)
+ return data
+
+
+class MutliTokenizerCommonPreprocessor(CommonPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: List[str] = [None],
+ token_list: List[Union[Path, str, Iterable[str]]] = [None],
+ bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: List[str] = ["text"],
+ ):
+ # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
+ super().__init__(
+ train=train,
+ token_type=token_type[0],
+ token_list=token_list[0],
+ bpemodel=bpemodel[0],
+ text_cleaner=text_cleaner,
+ g2p_type=g2p_type,
+ unk_symbol=unk_symbol,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ delimiter=delimiter,
+ speech_name=speech_name,
+ text_name=text_name[0],
+ rir_scp=rir_scp,
+ rir_apply_prob=rir_apply_prob,
+ noise_scp=noise_scp,
+ noise_apply_prob=noise_apply_prob,
+ noise_db_range=noise_db_range,
+ speech_volume_normalize=speech_volume_normalize,
+ )
+
+ assert (
+ len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+ ), "token_type, token_list, bpemodel, or processing text_name mismatched"
+ self.num_tokenizer = len(token_type)
+ self.tokenizer = []
+ self.token_id_converter = []
+
+ for i in range(self.num_tokenizer):
+ if token_type[i] is not None:
+ if token_list[i] is None:
+ raise ValueError("token_list is required if token_type is not None")
+
+ self.tokenizer.append(
+ build_tokenizer(
+ token_type=token_type[i],
+ bpemodel=bpemodel[i],
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ g2p_type=g2p_type,
+ )
+ )
+ self.token_id_converter.append(
+ TokenIDConverter(
+ token_list=token_list[i],
+ unk_symbol=unk_symbol,
+ )
+ )
+ else:
+ self.tokenizer.append(None)
+ self.token_id_converter.append(None)
+
+ self.text_cleaner = TextCleaner(text_cleaner)
+ self.text_name = text_name # override the text_name from CommonPreprocessor
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ for i in range(self.num_tokenizer):
+ text_name = self.text_name[i]
+ if text_name in data and self.tokenizer[i] is not None:
+ text = data[text_name]
+ text = self.text_cleaner(text)
+ tokens = self.tokenizer[i].text2tokens(text)
+ text_ints = self.token_id_converter[i].tokens2ids(tokens)
+ data[text_name] = np.array(text_ints, dtype=np.int64)
+ assert check_return_type(data)
+ return data
+
+
+class CodeMixTokenizerCommonPreprocessor(CommonPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: str = None,
+ token_list: Union[Path, str, Iterable[str]] = None,
+ bpemodel: Union[Path, str, Iterable[str]] = None,
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: str = "text",
+ split_text_name: str = "split_text",
+ split_with_space: bool = False,
+ seg_dict_file: str = None,
+ ):
+ super().__init__(
+ train=train,
+ # Force to use word.
+ token_type="word",
+ token_list=token_list,
+ bpemodel=bpemodel,
+ text_cleaner=text_cleaner,
+ g2p_type=g2p_type,
+ unk_symbol=unk_symbol,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ delimiter=delimiter,
+ speech_name=speech_name,
+ text_name=text_name,
+ rir_scp=rir_scp,
+ rir_apply_prob=rir_apply_prob,
+ noise_scp=noise_scp,
+ noise_apply_prob=noise_apply_prob,
+ noise_db_range=noise_db_range,
+ speech_volume_normalize=speech_volume_normalize,
+ split_with_space=split_with_space,
+ seg_dict_file=seg_dict_file,
+ )
+ # The data field name for split text.
+ self.split_text_name = split_text_name
+
+ @classmethod
+ def split_words(cls, text: str):
+ words = []
+ segs = text.split()
+ for seg in segs:
+ # There is no space in seg.
+ current_word = ""
+ for c in seg:
+ if len(c.encode()) == 1:
+ # This is an ASCII char.
+ current_word += c
+ else:
+ # This is a multi-byte (e.g. Chinese) char.
+ if len(current_word) > 0:
+ words.append(current_word)
+ current_word = ""
+ words.append(c)
+ if len(current_word) > 0:
+ words.append(current_word)
+ return words
+
+ def __call__(
+ self, uid: str, data: Dict[str, Union[list, str, np.ndarray]]
+ ) -> Dict[str, Union[list, np.ndarray]]:
+ assert check_argument_types()
+ # Split words.
+ if isinstance(data[self.text_name], str):
+ split_text = self.split_words(data[self.text_name])
+ else:
+ split_text = data[self.text_name]
+ data[self.text_name] = " ".join(split_text)
+ data = self._speech_process(data)
+ data = self._text_process(data)
+ data[self.split_text_name] = split_text
+ return data
+
+ def pop_split_text_data(self, data: Dict[str, Union[str, np.ndarray]]):
+ result = data[self.split_text_name]
+ del data[self.split_text_name]
+ return result
+
+
+class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
+ def __init__(
+ self,
+ train: bool,
+ token_type: List[str] = [None],
+ token_list: List[Union[Path, str, Iterable[str]]] = [None],
+ bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+ text_cleaner: Collection[str] = None,
+ g2p_type: str = None,
+ unk_symbol: str = "<unk>",
+ space_symbol: str = "<space>",
+ non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+ delimiter: str = None,
+ rir_scp: str = None,
+ rir_apply_prob: float = 1.0,
+ noise_scp: str = None,
+ noise_apply_prob: float = 1.0,
+ noise_db_range: str = "3_10",
+ speech_volume_normalize: float = None,
+ speech_name: str = "speech",
+ text_name: List[str] = ["text"],
+ vad_name: str = "vad_indexes",
+ ):
+ # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
+ super().__init__(
+ train=train,
+ token_type=token_type[0],
+ token_list=token_list[0],
+ bpemodel=bpemodel[0],
+ text_cleaner=text_cleaner,
+ g2p_type=g2p_type,
+ unk_symbol=unk_symbol,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ delimiter=delimiter,
+ speech_name=speech_name,
+ text_name=text_name[0],
+ rir_scp=rir_scp,
+ rir_apply_prob=rir_apply_prob,
+ noise_scp=noise_scp,
+ noise_apply_prob=noise_apply_prob,
+ noise_db_range=noise_db_range,
+ speech_volume_normalize=speech_volume_normalize,
+ )
+
+ assert (
+ len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+ ), "token_type, token_list, bpemodel, or processing text_name mismatched"
+ self.num_tokenizer = len(token_type)
+ self.tokenizer = []
+ self.token_id_converter = []
+
+ for i in range(self.num_tokenizer):
+ if token_type[i] is not None:
+ if token_list[i] is None:
+ raise ValueError("token_list is required if token_type is not None")
+
+ self.tokenizer.append(
+ build_tokenizer(
+ token_type=token_type[i],
+ bpemodel=bpemodel[i],
+ delimiter=delimiter,
+ space_symbol=space_symbol,
+ non_linguistic_symbols=non_linguistic_symbols,
+ g2p_type=g2p_type,
+ )
+ )
+ self.token_id_converter.append(
+ TokenIDConverter(
+ token_list=token_list[i],
+ unk_symbol=unk_symbol,
+ )
+ )
+ else:
+ self.tokenizer.append(None)
+ self.token_id_converter.append(None)
+
+ self.text_cleaner = TextCleaner(text_cleaner)
+ self.text_name = text_name # override the text_name from CommonPreprocessor
+ self.vad_name = vad_name
+
+ def _text_process(
+ self, data: Dict[str, Union[str, np.ndarray]]
+ ) -> Dict[str, np.ndarray]:
+ for i in range(self.num_tokenizer):
+ text_name = self.text_name[i]
+ if text_name in data and self.tokenizer[i] is not None:
+ text = data[text_name]
+ text = self.text_cleaner(text)
+ tokens = self.tokenizer[i].text2tokens(text)
+ if "vad:" in tokens[-1]:
+ vad = tokens[-1][4:]
+ tokens = tokens[:-1]
+ if len(vad) == 0:
+ vad = -1
+ else:
+ vad = int(vad)
+ data[self.vad_name] = np.array([vad], dtype=np.int64)
+ text_ints = self.token_id_converter[i].tokens2ids(tokens)
+ data[text_name] = np.array(text_ints, dtype=np.int64)
+ return data
+
+
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+ assert word_limit > 1
+ if len(words) <= word_limit:
+ return [words]
+ sentences = []
+ length = len(words)
+ sentence_len = length // word_limit
+ for i in range(sentence_len):
+ sentences.append(words[i * word_limit:(i + 1) * word_limit])
+ if length % word_limit > 0:
+ sentences.append(words[sentence_len * word_limit:])
+ return sentences
+
+
+def build_preprocess(args, train):
+ if not args.use_preprocessor:
+ return None
+ if args.task_name in ["asr", "data2vec", "diar", "sv"]:
+ retval = CommonPreprocessor(
+ train=train,
+ token_type=args.token_type,
+ token_list=args.token_list,
+ bpemodel=args.bpemodel,
+ non_linguistic_symbols=args.non_linguistic_symbols if hasattr(args, "non_linguistic_symbols") else None,
+ text_cleaner=args.cleaner,
+ g2p_type=args.g2p,
+ split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+ seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+ rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+ rir_apply_prob=args.rir_apply_prob if hasattr(args, "rir_apply_prob") else 1.0,
+ noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+ noise_apply_prob=args.noise_apply_prob if hasattr(args, "noise_apply_prob") else 1.0,
+ noise_db_range=args.noise_db_range if hasattr(args, "noise_db_range") else "13_15",
+ speech_volume_normalize=args.speech_volume_normalize if hasattr(args, "speech_volume_normalize") else None,
+ )
+ elif args.task_name == "punc":
+ token_types = [args.token_type, args.token_type]
+ token_lists = [args.token_list, args.punc_list]
+ bpemodels = [args.bpemodel, args.bpemodel]
+ text_names = ["text", "punc"]
+ retval = PuncTrainTokenizerCommonPreprocessor(
+ train=train,
+ token_type=token_types,
+ token_list=token_lists,
+ bpemodel=bpemodels,
+ text_cleaner=args.cleaner,
+ g2p_type=args.g2p,
+ text_name=text_names,
+ non_linguistic_symbols=args.non_linguistic_symbols,
+ )
+ elif args.task_name == "lm":
+ retval = LMPreprocessor(
+ train=train,
+ token_type=args.token_type,
+ token_list=args.token_list,
+ bpemodel=args.bpemodel,
+ text_cleaner=args.cleaner,
+ g2p_type=args.g2p,
+ text_name="text",
+ non_linguistic_symbols=args.non_linguistic_symbols,
+ split_with_space=args.split_with_space,
+ seg_dict_file=args.seg_dict_file
+ )
+ elif args.task_name == "vad":
+ retval = None
+ else:
+ raise ValueError(f"Not supported task={args.task_name}")
+ return retval
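
A quick demo of the segmentation helpers defined in this file: forward_segment picks the longest dictionary entry at each position, seg_tokenize expands each matched word via seg_dict (falling back to <unk>), and split_to_mini_sentence chunks a token list by a word limit. The toy lexicon is illustrative:

from funasr.datasets.small_datasets.preprocessor import (
    forward_segment, seg_tokenize, split_to_mini_sentence)

seg_dict = {"nihao": "ni hao", "ni": "ni", "hao": "hao"}  # toy lexicon
words = forward_segment("nihao", seg_dict)
print(words)                          # ['nihao'] (longest match wins)
print(seg_tokenize(words, seg_dict))  # ['ni', 'hao']
print(split_to_mini_sentence(list("abcdefg"), word_limit=3))
# [['a', 'b', 'c'], ['d', 'e', 'f'], ['g']]
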
diff --git a/funasr/datasets/small_datasets/sequence_iter_factory.py b/funasr/datasets/small_datasets/sequence_iter_factory.py
new file mode 100644
index 0000000..3ebcc5a
--- /dev/null
+++ b/funasr/datasets/small_datasets/sequence_iter_factory.py
@@ -0,0 +1,189 @@
+import logging
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
+from funasr.datasets.small_datasets.dataset import ESPnetDataset
+from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
+from funasr.datasets.small_datasets.preprocessor import build_preprocess
+from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.samplers.abs_sampler import AbsSampler
+
+
+class RawSampler(AbsSampler):
+ def __init__(self, batches):
+ self.batches = batches
+
+ def __len__(self):
+ return len(self.batches)
+
+ def __iter__(self):
+ return iter(self.batches)
+
+ def generate(self, seed):
+ return list(self.batches)
+
+
+class SequenceIterFactory(AbsIterFactory):
+ """Build an iterator for each epoch (modified from ESPnet)."""
+
+ def __init__(self, args, mode="train"):
+
+ # preprocess
+ preprocess_fn = build_preprocess(args, train=mode == "train")
+
+ # collate
+ if args.task_name in ["punc", "lm"]:
+ collate_fn = CommonCollateFn(int_pad_value=0)
+ else:
+ collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+ # dataset
+ dest_sample_rate = args.frontend_conf["fs"] if (
+ args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
+ if mode == "train":
+ data_path_and_name_and_type = args.train_data_path_and_name_and_type
+ shape_files = args.train_shape_file
+ elif mode == "valid":
+ data_path_and_name_and_type = args.valid_data_path_and_name_and_type
+ shape_files = args.valid_shape_file
+ else:
+ raise NotImplementedError(f"mode={mode}")
+ dataset = ESPnetDataset(
+ data_path_and_name_and_type,
+ preprocess=preprocess_fn,
+ dest_sample_rate=dest_sample_rate,
+ speed_perturb=args.speed_perturb if mode == "train" else None,
+ )
+
+ # sampler
+ dataset_conf = args.dataset_conf
+ batch_sampler = LengthBatchSampler(
+ batch_bins=dataset_conf["batch_conf"]["batch_size"] * args.ngpu,
+ shape_files=shape_files,
+ sort_in_batch=dataset_conf["sort_in_batch"] if "sort_in_batch" in dataset_conf else "descending",
+ sort_batch=dataset_conf["sort_batch"] if "sort_batch" in dataset_conf else "ascending",
+ drop_last=False,
+ padding=True,
+ )
+
+ batches = list(batch_sampler)
+ bs_list = [len(batch) for batch in batches]
+ logging.info(f"[{mode}] dataset:\n{dataset}")
+ logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
+ logging.info(
+ f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
+ f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
+ )
+
+ if args.scheduler == "tri_stage" and mode == "train":
+ args.max_update = len(bs_list) * args.max_epoch
+ logging.info("Max update: {}".format(args.max_update))
+
+ if args.distributed and mode == "train":
+ world_size = torch.distributed.get_world_size()
+ rank = torch.distributed.get_rank()
+ for batch in batches:
+ if len(batch) < world_size:
+ raise RuntimeError(
+ f"The batch size must be equal to or greater than world_size: "
+ f"{len(batch)} < {world_size}"
+ )
+ batches = [batch[rank::world_size] for batch in batches]
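+ # e.g. world_size=2: a batch [u0, u1, u2, u3] is split into [u0, u2] on
+ # rank 0 and [u1, u3] on rank 1; per-rank sizes may differ by one when
+ # len(batch) is not divisible by world_size.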
+
+ if not isinstance(batches, AbsSampler):
+ self.sampler = RawSampler(batches)
+ else:
+ self.sampler = batches
+
+ self.dataset = dataset
+ self.num_iters_per_epoch = None
+ self.shuffle = mode == "train"
+ self.seed = args.seed
+ self.num_workers = args.dataset_conf.get("num_workers", 8)
+ self.collate_fn = collate_fn
+ self.pin_memory = args.ngpu > 0
+
+ def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
+ if shuffle is None:
+ shuffle = self.shuffle
+
+ if self.num_iters_per_epoch is not None:
+ N = len(self.sampler)
+ # If corpus size is larger than the num_per_epoch
+ if self.num_iters_per_epoch < N:
+ real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)
+
+ if offset >= self.num_iters_per_epoch:
+ current_batches = self.sampler.generate(real_epoch + self.seed)
+ if shuffle:
+ np.random.RandomState(real_epoch + self.seed).shuffle(
+ current_batches
+ )
+ batches = current_batches[
+ offset - self.num_iters_per_epoch: offset
+ ]
+ else:
+ prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
+ current_batches = self.sampler.generate(real_epoch + self.seed)
+ if shuffle:
+ np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
+ prev_batches
+ )
+ np.random.RandomState(real_epoch + self.seed).shuffle(
+ current_batches
+ )
+ batches = (
+ prev_batches[offset - self.num_iters_per_epoch:]
+ + current_batches[:offset]
+ )
+
+ # If corpus size is less than the num_per_epoch
+ else:
+ _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
+ _remain = self.num_iters_per_epoch
+ batches = []
+ current_batches = self.sampler.generate(_epoch + self.seed)
+ if shuffle:
+ np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
+ while _remain > 0:
+
+ _batches = current_batches[_cursor: _cursor + _remain]
+ batches += _batches
+ if _cursor + _remain >= N:
+ _epoch += 1
+ _cursor = 0
+ current_batches = self.sampler.generate(_epoch + self.seed)
+ if shuffle:
+ np.random.RandomState(_epoch + self.seed).shuffle(
+ current_batches
+ )
+ else:
+ _cursor = _cursor + _remain
+ _remain -= len(_batches)
+
+ assert len(batches) == self.num_iters_per_epoch
+
+ else:
+ batches = self.sampler.generate(epoch + self.seed)
+ if shuffle:
+ np.random.RandomState(epoch + self.seed).shuffle(batches)
+
+ # For backward compatibility with the PyTorch DataLoader
+ if self.collate_fn is not None:
+ kwargs = dict(collate_fn=self.collate_fn)
+ else:
+ kwargs = {}
+
+ return DataLoader(
+ dataset=self.dataset,
+ batch_sampler=batches,
+ num_workers=self.num_workers,
+ pin_memory=self.pin_memory,
+ **kwargs,
+ )
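+
+
+# Determinism note (hedged): every shuffle goes through
+# np.random.RandomState(epoch + seed), so for a fixed seed the batch order of
+# a given epoch is reproducible across restarts and identical on every
+# distributed rank:
+#
+#   loader_a = SequenceIterFactory(args, mode="train").build_iter(epoch=3)
+#   loader_b = SequenceIterFactory(args, mode="train").build_iter(epoch=3)
+#   # both loaders iterate mini-batches in the same order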
diff --git a/funasr/fileio/sound_scp.py b/funasr/fileio/sound_scp.py
index d757f7f..c752fe6 100644
--- a/funasr/fileio/sound_scp.py
+++ b/funasr/fileio/sound_scp.py
@@ -2,10 +2,14 @@
from pathlib import Path
from typing import Union
+import random
import numpy as np
import soundfile
import librosa
from typeguard import check_argument_types
+
+import torch
+import torchaudio
from funasr.fileio.read_text import read_2column_text
@@ -32,6 +36,7 @@
always_2d: bool = False,
normalize: bool = False,
dest_sample_rate: int = 16000,
+ speed_perturb: Union[list, tuple] = None,
):
assert check_argument_types()
self.fname = fname
@@ -40,6 +45,7 @@
self.normalize = normalize
self.data = read_2column_text(fname)
self.dest_sample_rate = dest_sample_rate
+ self.speed_perturb = speed_perturb
def __getitem__(self, key):
wav = self.data[key]
@@ -53,8 +59,17 @@
wav, sr=self.dest_sample_rate, mono=self.always_2d, dtype=self.dtype
)
+ if self.speed_perturb is not None:
+ speed = random.choice(self.speed_perturb)
+ if speed != 1.0:
+ array, _ = torchaudio.sox_effects.apply_effects_tensor(
+ torch.tensor(array).view(1, -1), rate,
+ [['speed', str(speed)], ['rate', str(rate)]])
+ array = array.view(-1).numpy()
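+ # e.g. speed_perturb=[0.9, 1.0, 1.1]: a factor of 0.9 stretches the
+ # utterance to ~1/0.9 of its duration, and the trailing 'rate' effect
+ # resamples the output back to the original sample rate.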
+
if array.ndim==2:
array=array.transpose((1, 0))
+
return rate, array
def get_path(self, key):
diff --git a/funasr/layers/abs_normalize.py b/funasr/layers/abs_normalize.py
index f2be748..4e617d0 100644
--- a/funasr/layers/abs_normalize.py
+++ b/funasr/layers/abs_normalize.py
@@ -11,4 +11,4 @@
self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# return output, output_lengths
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/layers/global_mvn.py b/funasr/layers/global_mvn.py
index 5515cdd..8e43582 100644
--- a/funasr/layers/global_mvn.py
+++ b/funasr/layers/global_mvn.py
@@ -13,9 +13,7 @@
class GlobalMVN(AbsNormalize, InversibleInterface):
"""Apply global mean and variance normalization
-
TODO(kamo): Make this class portable somehow
-
Args:
stats_file: npy file
norm_means: Apply mean normalization
@@ -66,7 +64,6 @@
self, x: torch.Tensor, ilens: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward function
-
Args:
x: (B, L, ...)
ilens: (B,)
@@ -118,4 +115,4 @@
if norm_means:
x += self.mean
x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
- return x, ilens
+ return x, ilens
\ No newline at end of file
diff --git a/funasr/layers/inversible_interface.py b/funasr/layers/inversible_interface.py
index a1a5939..657ec68 100644
--- a/funasr/layers/inversible_interface.py
+++ b/funasr/layers/inversible_interface.py
@@ -11,4 +11,4 @@
self, input: torch.Tensor, input_lengths: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
# return output, output_lengths
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/lm/abs_model.py b/funasr/lm/abs_model.py
index 1f3c8a7..560879e 100644
--- a/funasr/lm/abs_model.py
+++ b/funasr/lm/abs_model.py
@@ -15,7 +15,7 @@
from funasr.modules.nets_utils import make_pad_mask
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
"""The abstract LM class
@@ -39,7 +39,7 @@
raise NotImplementedError
-class LanguageModel(AbsESPnetModel):
+class LanguageModel(FunASRModel):
def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
assert check_argument_types()
super().__init__()
diff --git a/funasr/main_funcs/calculate_all_attentions.py b/funasr/main_funcs/calculate_all_attentions.py
index 8f238c6..c3bf015 100644
--- a/funasr/main_funcs/calculate_all_attentions.py
+++ b/funasr/main_funcs/calculate_all_attentions.py
@@ -21,12 +21,12 @@
from funasr.modules.attention import MultiHeadedAttention
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
@torch.no_grad()
def calculate_all_attentions(
- model: AbsESPnetModel, batch: Dict[str, torch.Tensor]
+ model: FunASRModel, batch: Dict[str, torch.Tensor]
) -> Dict[str, List[torch.Tensor]]:
"""Derive the outputs from the all attention layers
diff --git a/funasr/main_funcs/collect_stats.py b/funasr/main_funcs/collect_stats.py
index bacda8f..584b85a 100644
--- a/funasr/main_funcs/collect_stats.py
+++ b/funasr/main_funcs/collect_stats.py
@@ -17,12 +17,12 @@
from funasr.fileio.npy_scp import NpyScpWriter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
@torch.no_grad()
def collect_stats(
- model: AbsESPnetModel,
+ model: FunASRModel,
train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
output_dir: Path,
diff --git a/funasr/models/base_model.py b/funasr/models/base_model.py
new file mode 100644
index 0000000..80b3bbd
--- /dev/null
+++ b/funasr/models/base_model.py
@@ -0,0 +1,17 @@
+import torch
+
+
+class FunASRModel(torch.nn.Module):
+ """Common base class for FunASR models; tracks the number of parameter updates."""
+
+ def __init__(self):
+ super().__init__()
+ self.num_updates = 0
+
+ def set_num_updates(self, num_updates):
+ self.num_updates = num_updates
+
+ def get_num_updates(self):
+ return self.num_updates
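+
+
+# Hedged usage sketch: task models subclass FunASRModel so the trainer can
+# tick the shared update counter after each optimizer step:
+#
+#   class ToyModel(FunASRModel):
+#       def forward(self, x: torch.Tensor) -> torch.Tensor:
+#           return x
+#
+#   model = ToyModel()
+#   model.set_num_updates(model.get_num_updates() + 1)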
diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
index fcd6bd2..e5bd640 100644
--- a/funasr/models/data2vec.py
+++ b/funasr/models/data2vec.py
@@ -18,7 +18,7 @@
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -29,7 +29,7 @@
yield
-class Data2VecPretrainModel(AbsESPnetModel):
+class Data2VecPretrainModel(FunASRModel):
"""Data2Vec Pretrain model"""
def __init__(
@@ -57,7 +57,6 @@
speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -106,7 +105,6 @@
speech_lengths: torch.Tensor,
):
"""Frontend + Encoder.
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py
index f64ea3d..e6e6a52 100644
--- a/funasr/models/e2e_asr.py
+++ b/funasr/models/e2e_asr.py
@@ -28,7 +28,7 @@
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -39,7 +39,7 @@
yield
-class ESPnetASRModel(AbsESPnetModel):
+class ASRModel(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
@@ -49,9 +49,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -64,6 +62,8 @@
sym_space: str = "<space>",
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -133,7 +133,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -249,7 +248,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -331,9 +329,7 @@
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
-
Normally, this function is called in batchify_nll.
-
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -370,7 +366,6 @@
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
-
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
diff --git a/funasr/models/e2e_asr_mfcca.py b/funasr/models/e2e_asr_mfcca.py
index f22f12a..fbf0d11 100644
--- a/funasr/models/e2e_asr_mfcca.py
+++ b/funasr/models/e2e_asr_mfcca.py
@@ -23,7 +23,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -35,7 +35,8 @@
import pdb
import random
import math
-class MFCCA(AbsESPnetModel):
+
+class MFCCA(FunASRModel):
"""
Author: Audio, Speech and Language Processing Group (ASLP@NPU), Northwestern Polytechnical University
MFCCA: Multi-Frame Cross-Channel Attention for Multi-Speaker ASR in Multi-Party Meeting Scenario
@@ -43,26 +44,26 @@
"""
def __init__(
- self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
- encoder: AbsEncoder,
- decoder: AbsDecoder,
- ctc: CTC,
- rnnt_decoder: None,
- ctc_weight: float = 0.5,
- ignore_id: int = -1,
- lsm_weight: float = 0.0,
- mask_ratio: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ encoder: AbsEncoder,
+ decoder: AbsDecoder,
+ ctc: CTC,
+ rnnt_decoder: None,
+ ctc_weight: float = 0.5,
+ ignore_id: int = -1,
+ lsm_weight: float = 0.0,
+ mask_ratio: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ preencoder: Optional[AbsPreEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -76,10 +77,9 @@
self.ignore_id = ignore_id
self.ctc_weight = ctc_weight
self.token_list = token_list.copy()
-
+
self.mask_ratio = mask_ratio
-
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
@@ -113,14 +113,13 @@
self.error_calculator = None
def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -130,22 +129,22 @@
assert text_lengths.dim() == 1, text_lengths.shape
# Check that batch_size is unified
assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- == text.shape[0]
- == text_lengths.shape[0]
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
- #pdb.set_trace()
- if(speech.dim()==3 and speech.size(2)==8 and self.mask_ratio !=0):
+ # pdb.set_trace()
+ if (speech.dim() == 3 and speech.size(2) == 8 and self.mask_ratio != 0):
rate_num = random.random()
- #rate_num = 0.1
- if(rate_num<=self.mask_ratio):
- retain_channel = math.ceil(random.random() *8)
- if(retain_channel>1):
- speech = speech[:,:,torch.randperm(8)[0:retain_channel].sort().values]
+ # rate_num = 0.1
+ if (rate_num <= self.mask_ratio):
+ retain_channel = math.ceil(random.random() * 8)
+ if (retain_channel > 1):
+ speech = speech[:, :, torch.randperm(8)[0:retain_channel].sort().values]
else:
- speech = speech[:,:,torch.randperm(8)[0]]
- #pdb.set_trace()
+ speech = speech[:, :, torch.randperm(8)[0]]
+ # pdb.set_trace()
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
@@ -195,20 +194,19 @@
return loss, stats, weight
def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
) -> Dict[str, torch.Tensor]:
feats, feats_lengths, channel_size = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -227,14 +225,14 @@
# Pre-encoder, e.g. used for raw input data
if self.preencoder is not None:
feats, feats_lengths = self.preencoder(feats, feats_lengths)
- #pdb.set_trace()
+ # pdb.set_trace()
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths, channel_size)
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
speech.size(0),
)
- if(encoder_out.dim()==4):
+ if (encoder_out.dim() == 4):
assert encoder_out.size(2) <= encoder_out_lens.max(), (
encoder_out.size(),
encoder_out_lens.max(),
@@ -248,7 +246,7 @@
return encoder_out, encoder_out_lens
def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
# for data-parallel
@@ -266,11 +264,11 @@
return feats, feats_lengths, channel_size
def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
):
ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_in_lens = ys_pad_lens + 1
@@ -298,14 +296,14 @@
return loss_att, acc_att, cer_att, wer_att
def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
):
# Calc CTC loss
- if(encoder_out.dim()==4):
+ if (encoder_out.dim() == 4):
encoder_out = encoder_out.mean(1)
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
@@ -317,10 +315,10 @@
return loss_ctc, cer_ctc
def _calc_rnnt_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
):
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index d02783f..00e08b1 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -29,9 +29,8 @@
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.models.predictor.cif import CifPredictorV3
-
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -42,7 +41,7 @@
yield
-class Paraformer(AbsESPnetModel):
+class Paraformer(FunASRModel):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
@@ -56,9 +55,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -79,6 +76,8 @@
predictor_bias: int = 0,
sampling_ratio: float = 0.2,
share_embedding: bool = False,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -153,7 +152,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -270,7 +268,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -368,9 +365,7 @@
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
-
Normally, this function is called in batchify_nll.
-
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -407,7 +402,6 @@
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
-
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
@@ -664,7 +658,10 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor, cache: dict = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -738,9 +735,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -763,6 +758,8 @@
embeds_id: int = 2,
embeds_loss_weight: float = 0.0,
embed_dims: int = 768,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -894,7 +891,6 @@
embed_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -913,9 +909,9 @@
self.step_cur += 1
# for data-parallel
text = text[:, : text_lengths.max()]
- speech = speech[:, :speech_lengths.max(), :]
+ speech = speech[:, :speech_lengths.max()]
if embed is not None:
- embed = embed[:, :embed_lengths.max(), :]
+ embed = embed[:, :embed_lengths.max()]
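+ # Hedged note: dropping the trailing ", :" keeps the trim valid for both
+ # 2-D (batch, samples) waveforms and 3-D (batch, frames, dim) features.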
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
@@ -1003,74 +999,73 @@
class BiCifParaformer(Paraformer):
-
"""
Paraformer model with an extra CIF predictor
for accurate timestamp prediction
"""
def __init__(
- self,
- vocab_size: int,
- token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
- encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
- decoder: AbsDecoder,
- ctc: CTC,
- ctc_weight: float = 0.5,
- interctc_weight: float = 0.0,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- extract_feats_in_collect_stats: bool = True,
- predictor = None,
- predictor_weight: float = 0.0,
- predictor_bias: int = 0,
- sampling_ratio: float = 0.2,
+ self,
+ vocab_size: int,
+ token_list: Union[Tuple[str, ...], List[str]],
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ encoder: AbsEncoder,
+ decoder: AbsDecoder,
+ ctc: CTC,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ extract_feats_in_collect_stats: bool = True,
+ predictor=None,
+ predictor_weight: float = 0.0,
+ predictor_bias: int = 0,
+ sampling_ratio: float = 0.2,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
super().__init__(
- vocab_size=vocab_size,
- token_list=token_list,
- frontend=frontend,
- specaug=specaug,
- normalize=normalize,
- preencoder=preencoder,
- encoder=encoder,
- postencoder=postencoder,
- decoder=decoder,
- ctc=ctc,
- ctc_weight=ctc_weight,
- interctc_weight=interctc_weight,
- ignore_id=ignore_id,
- blank_id=blank_id,
- sos=sos,
- eos=eos,
- lsm_weight=lsm_weight,
- length_normalized_loss=length_normalized_loss,
- report_cer=report_cer,
- report_wer=report_wer,
- sym_space=sym_space,
- sym_blank=sym_blank,
- extract_feats_in_collect_stats=extract_feats_in_collect_stats,
- predictor=predictor,
- predictor_weight=predictor_weight,
- predictor_bias=predictor_bias,
- sampling_ratio=sampling_ratio,
+ vocab_size=vocab_size,
+ token_list=token_list,
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ postencoder=postencoder,
+ decoder=decoder,
+ ctc=ctc,
+ ctc_weight=ctc_weight,
+ interctc_weight=interctc_weight,
+ ignore_id=ignore_id,
+ blank_id=blank_id,
+ sos=sos,
+ eos=eos,
+ lsm_weight=lsm_weight,
+ length_normalized_loss=length_normalized_loss,
+ report_cer=report_cer,
+ report_wer=report_wer,
+ sym_space=sym_space,
+ sym_blank=sym_blank,
+ extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+ predictor=predictor,
+ predictor_weight=predictor_weight,
+ predictor_bias=predictor_bias,
+ sampling_ratio=sampling_ratio,
)
assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer requires CifPredictorV3"
@@ -1145,21 +1140,23 @@
cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
return loss_att, acc_att, cer_att, wer_att, loss_pre
-
+
def calc_predictor(self, encoder_out, encoder_out_lens):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, None, encoder_out_mask,
- ignore_id=self.ignore_id)
+ pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(encoder_out,
+ None,
+ encoder_out_mask,
+ ignore_id=self.ignore_id)
return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
-
+
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
- encoder_out_mask,
- token_num)
+ encoder_out_mask,
+ token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def forward(
@@ -1170,7 +1167,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -1253,7 +1249,8 @@
elif self.ctc_weight == 1.0:
loss = loss_ctc
else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
+ loss = self.ctc_weight * loss_ctc + (
+ 1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight * 0.5
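+ # e.g. ctc_weight=0.3, predictor_weight=1.0 gives
+ # loss = 0.3 * loss_ctc + 0.7 * loss_att + loss_pre + 0.5 * loss_pre2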
# Collect Attn branch stats
stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1282,9 +1279,7 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
encoder: AbsEncoder,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
ctc_weight: float = 0.5,
@@ -1314,6 +1309,8 @@
bias_encoder_type: str = 'lstm',
label_bracket: bool = False,
use_decoder_embedding: bool = False,
+ preencoder: Optional[AbsPreEncoder] = None,
+ postencoder: Optional[AbsPostEncoder] = None,
):
assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
@@ -1377,7 +1374,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -1761,4 +1757,4 @@
"torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
var_dict_tf[name_tf].shape))
- return var_dict_torch_update
+ return var_dict_torch_update
\ No newline at end of file
diff --git a/funasr/models/e2e_diar_eend_ola.py b/funasr/models/e2e_diar_eend_ola.py
index 097b23a..da7c674 100644
--- a/funasr/models/e2e_diar_eend_ola.py
+++ b/funasr/models/e2e_diar_eend_ola.py
@@ -16,7 +16,7 @@
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.modules.eend_ola.utils.power import generate_mapping_dict
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
pass
@@ -34,7 +34,7 @@
return att
-class DiarEENDOLAModel(AbsESPnetModel):
+class DiarEENDOLAModel(FunASRModel):
"""EEND-OLA diarization model"""
def __init__(
@@ -91,7 +91,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index 3f7011d..9c3fb92 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -22,7 +22,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss, SequenceBinaryCrossEntropy
from funasr.utils.misc import int2vec
@@ -35,7 +35,7 @@
yield
-class DiarSondModel(AbsESPnetModel):
+class DiarSondModel(FunASRModel):
"""
Author: Speech Lab, Alibaba Group, China
SOND: Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis
@@ -115,7 +115,6 @@
binary_labels_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
-
Args:
speech: (Batch, samples) or (Batch, frames, input_size)
speech_lengths: (Batch,) default None for chunk iterator,
@@ -391,7 +390,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch,)
@@ -491,4 +489,4 @@
speaker_miss,
speaker_falarm,
speaker_error,
- )
+ )
\ No newline at end of file
diff --git a/funasr/models/e2e_sv.py b/funasr/models/e2e_sv.py
index 5b21277..bd5178e 100644
--- a/funasr/models/e2e_sv.py
+++ b/funasr/models/e2e_sv.py
@@ -29,7 +29,7 @@
from funasr.modules.e2e_asr_common import ErrorCalculator
from funasr.modules.nets_utils import th_accuracy
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -40,7 +40,7 @@
yield
-class ESPnetSVModel(AbsESPnetModel):
+class ESPnetSVModel(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
@@ -80,7 +80,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -221,7 +220,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -271,4 +269,4 @@
else:
# No frontend and no feature extract
feats, feats_lengths = speech, speech_lengths
- return feats, feats_lengths
+ return feats, feats_lengths
\ No newline at end of file
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index d1367ab..33948f9 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -17,9 +17,8 @@
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.models.predictor.cif import CifPredictorV3
-
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -30,7 +29,7 @@
yield
-class TimestampPredictor(AbsESPnetModel):
+class TimestampPredictor(FunASRModel):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
@@ -56,7 +55,7 @@
self.predictor_bias = predictor_bias
self.criterion_pre = mae_loss()
self.token_list = token_list
-
+
def forward(
self,
speech: torch.Tensor,
@@ -65,7 +64,6 @@
text_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -113,7 +111,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -128,7 +125,7 @@
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
return encoder_out, encoder_out_lens
-
+
def _extract_feats(
self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -151,8 +148,8 @@
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
- encoder_out_mask,
- token_num)
+ encoder_out_mask,
+ token_num)
return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def collect_feats(
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index ca76244..d08ea37 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -25,7 +25,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.modules.streaming_utils.chunk_utilis import sequence_mask
from funasr.models.predictor.cif import mae_loss
@@ -38,7 +38,7 @@
yield
-class UniASR(AbsESPnetModel):
+class UniASR(FunASRModel):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
"""
@@ -179,7 +179,6 @@
decoding_ind: int = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Decoder + Calc loss
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -469,7 +468,6 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor, ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -533,7 +531,6 @@
ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Frontend + Encoder. Note that this method is used by asr_inference.py
-
Args:
speech: (Batch, Length, ...)
speech_lengths: (Batch, )
@@ -627,9 +624,7 @@
ys_pad_lens: torch.Tensor,
) -> torch.Tensor:
"""Compute negative log likelihood(nll) from transformer-decoder
-
Normally, this function is called in batchify_nll.
-
Args:
encoder_out: (Batch, Length, Dim)
encoder_out_lens: (Batch,)
@@ -666,7 +661,6 @@
batch_size: int = 100,
):
"""Compute negative log likelihood(nll) from transformer-decoder
-
To avoid OOM, this fuction seperate the input into batches.
Then call nll for each batch and combine and return results.
Args:
@@ -1072,4 +1066,3 @@
ys_hat = self.ctc2.argmax(encoder_out).data
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
return loss_ctc, cer_ctc
-
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index d72c635..82d8422 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -469,7 +469,7 @@
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
-
+
def forward(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
@@ -499,11 +499,11 @@
return segments, in_cache
def forward_online(self, feats: torch.Tensor, waveform: torch.Tensor, in_cache: Dict[str, torch.Tensor] = dict(),
- is_final: bool = False, max_end_sil: int = 800
- ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+ is_final: bool = False, max_end_sil: int = 800
+ ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
-
+
self.ComputeScores(feats, in_cache)
self.ComputeDecibel()
if not is_final:
diff --git a/funasr/models/encoder/abs_encoder.py b/funasr/models/encoder/abs_encoder.py
index 1fb7c97..034bc1f 100644
--- a/funasr/models/encoder/abs_encoder.py
+++ b/funasr/models/encoder/abs_encoder.py
@@ -18,4 +18,4 @@
ilens: torch.Tensor,
prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 434f2a4..aa3b67e 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -15,13 +15,13 @@
from typeguard import check_argument_types
from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttentionChunk,
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
+from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
diff --git a/funasr/models/encoder/data2vec_encoder.py b/funasr/models/encoder/data2vec_encoder.py
index fd1796c..64c2144 100644
--- a/funasr/models/encoder/data2vec_encoder.py
+++ b/funasr/models/encoder/data2vec_encoder.py
@@ -574,4 +574,4 @@
)
def output_size(self) -> int:
- return self.encoder_embed_dim
+ return self.encoder_embed_dim
\ No newline at end of file
diff --git a/funasr/models/encoder/mfcca_encoder.py b/funasr/models/encoder/mfcca_encoder.py
index 83d0b0e..95ccf07 100644
--- a/funasr/models/encoder/mfcca_encoder.py
+++ b/funasr/models/encoder/mfcca_encoder.py
@@ -38,13 +38,12 @@
import pdb
import math
+
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
-
Args:
channels (int): The number of channels of conv layers.
kernel_size (int): Kernel size of conv layers.
-
"""
def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
@@ -83,13 +82,10 @@
def forward(self, x):
"""Compute convolution module.
-
Args:
x (torch.Tensor): Input tensor (#batch, time, channels).
-
Returns:
torch.Tensor: Output tensor (#batch, time, channels).
-
"""
# exchange the temporal dimension and the feature dimension
x = x.transpose(1, 2)
@@ -107,10 +103,8 @@
return x.transpose(1, 2)
-
class MFCCAEncoder(AbsEncoder):
"""Conformer encoder module.
-
Args:
input_size (int): Input dimension.
output_size (int): Dimension of attention.
@@ -140,33 +134,32 @@
zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
cnn_module_kernel (int): Kernel size of convolution module.
padding_idx (int): Padding idx for input_layer=embed.
-
"""
def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: str = "conv2d",
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 3,
- macaron_style: bool = False,
- rel_pos_type: str = "legacy",
- pos_enc_layer_type: str = "rel_pos",
- selfattention_layer_type: str = "rel_selfattn",
- activation_type: str = "swish",
- use_cnn_module: bool = True,
- zero_triu: bool = False,
- cnn_module_kernel: int = 31,
- padding_idx: int = -1,
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = "conv2d",
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 3,
+ macaron_style: bool = False,
+ rel_pos_type: str = "legacy",
+ pos_enc_layer_type: str = "rel_pos",
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ zero_triu: bool = False,
+ cnn_module_kernel: int = 31,
+ padding_idx: int = -1,
):
assert check_argument_types()
super().__init__()
@@ -199,7 +192,7 @@
)
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
-
+
if input_layer == "linear":
self.embed = torch.nn.Sequential(
torch.nn.Linear(input_size, output_size),
@@ -283,7 +276,7 @@
assert pos_enc_layer_type == "legacy_rel_pos"
encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
encoder_selfattn_layer_args = (
- attention_heads,
+ attention_heads,
output_size,
attention_dropout_rate,
)
@@ -326,42 +319,39 @@
)
if self.normalize_before:
self.after_norm = LayerNorm(output_size)
- self.conv1 = torch.nn.Conv2d(8, 16, [5,7], stride=[1,1], padding=(2,3))
+ self.conv1 = torch.nn.Conv2d(8, 16, [5, 7], stride=[1, 1], padding=(2, 3))
- self.conv2 = torch.nn.Conv2d(16, 32, [5,7], stride=[1,1], padding=(2,3))
+ self.conv2 = torch.nn.Conv2d(16, 32, [5, 7], stride=[1, 1], padding=(2, 3))
- self.conv3 = torch.nn.Conv2d(32, 16, [5,7], stride=[1,1], padding=(2,3))
+ self.conv3 = torch.nn.Conv2d(32, 16, [5, 7], stride=[1, 1], padding=(2, 3))
- self.conv4 = torch.nn.Conv2d(16, 1, [5,7], stride=[1,1], padding=(2,3))
+ self.conv4 = torch.nn.Conv2d(16, 1, [5, 7], stride=[1, 1], padding=(2, 3))
def output_size(self) -> int:
return self._output_size
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- channel_size: torch.Tensor,
- prev_states: torch.Tensor = None,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ channel_size: torch.Tensor,
+ prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
-
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
-
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
torch.Tensor: Not to be used now.
-
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -380,48 +370,46 @@
t_leng = xs_pad.size(1)
d_dim = xs_pad.size(2)
- xs_pad = xs_pad.reshape(-1,channel_size,t_leng,d_dim)
- #pdb.set_trace()
- if(channel_size<8):
- repeat_num = math.ceil(8/channel_size)
- xs_pad = xs_pad.repeat(1,repeat_num,1,1)[:,0:8,:,:]
+ xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
+ # pdb.set_trace()
+ if (channel_size < 8):
+ repeat_num = math.ceil(8 / channel_size)
+ xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
xs_pad = self.conv1(xs_pad)
xs_pad = self.conv2(xs_pad)
xs_pad = self.conv3(xs_pad)
xs_pad = self.conv4(xs_pad)
- xs_pad = xs_pad.squeeze().reshape(-1,t_leng,d_dim)
+ xs_pad = xs_pad.squeeze().reshape(-1, t_leng, d_dim)
mask_tmp = masks.size(1)
- masks = masks.reshape(-1,channel_size,mask_tmp,t_leng)[:,0,:,:]
+ masks = masks.reshape(-1, channel_size, mask_tmp, t_leng)[:, 0, :, :]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
olens = masks.squeeze(1).sum(1)
return xs_pad, olens, None
+
def forward_hidden(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Calculate forward propagation.
-
Args:
xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
ilens (torch.Tensor): Input length (#batch).
prev_states (torch.Tensor): Not to be used now.
-
Returns:
torch.Tensor: Output tensor (#batch, L, output_size).
torch.Tensor: Output length (#batch).
torch.Tensor: Not to be used now.
-
"""
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
if (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -447,4 +435,4 @@
self.hidden_feature = self.after_norm(hidden_feature)
olens = masks.squeeze(1).sum(1)
- return xs_pad, olens, None
+ return xs_pad, olens, None
\ No newline at end of file
diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py
index 93695c8..8445feb 100644
--- a/funasr/models/encoder/resnet34_encoder.py
+++ b/funasr/models/encoder/resnet34_encoder.py
@@ -850,4 +850,4 @@
else:
logging.warning("{} is missed from tf checkpoint".format(name))
- return var_dict_torch_update
+ return var_dict_torch_update
\ No newline at end of file
diff --git a/funasr/models/encoder/rnn_encoder.py b/funasr/models/encoder/rnn_encoder.py
index 7a3b053..59730da 100644
--- a/funasr/models/encoder/rnn_encoder.py
+++ b/funasr/models/encoder/rnn_encoder.py
@@ -1,3 +1,4 @@
+
from typing import Optional
from typing import Sequence
from typing import Tuple
@@ -14,7 +15,6 @@
class RNNEncoder(AbsEncoder):
"""RNNEncoder class.
-
Args:
input_size: The number of expected features in the input
output_size: The number of output features
@@ -23,7 +23,6 @@
use_projection: Use projection layer or not
num_layers: Number of recurrent layers
dropout: dropout probability
-
"""
def __init__(
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a68011..e071e57 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -27,9 +27,10 @@
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
+from funasr.modules.mask import subsequent_mask, vad_mask
+
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
diff --git a/funasr/models/frontend/abs_frontend.py b/funasr/models/frontend/abs_frontend.py
index 538236f..6049a01 100644
--- a/funasr/models/frontend/abs_frontend.py
+++ b/funasr/models/frontend/abs_frontend.py
@@ -14,4 +14,4 @@
def forward(
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
- raise NotImplementedError
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/frontend/default.py b/funasr/models/frontend/default.py
index 2e1b0c4..71f8ec4 100644
--- a/funasr/models/frontend/default.py
+++ b/funasr/models/frontend/default.py
@@ -18,7 +18,6 @@
class DefaultFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
-
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
@@ -142,7 +141,6 @@
class MultiChannelFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
-
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
"""
@@ -260,4 +258,4 @@
# Change torch.Tensor to ComplexTensor
# input_stft: (..., F, 2) -> (..., F)
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
- return input_stft, feats_lens
+ return input_stft, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/fused.py b/funasr/models/frontend/fused.py
index 8b5e56e..857486d 100644
--- a/funasr/models/frontend/fused.py
+++ b/funasr/models/frontend/fused.py
@@ -143,4 +143,4 @@
else:
raise NotImplementedError
- return input_feats, feats_lens
+ return input_feats, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/s3prl.py b/funasr/models/frontend/s3prl.py
index f2b6107..b03d2c9 100644
--- a/funasr/models/frontend/s3prl.py
+++ b/funasr/models/frontend/s3prl.py
@@ -100,7 +100,6 @@
def _tile_representations(self, feature):
"""Tile up the representations by `tile_factor`.
-
Input - sequence of representations
shape: (batch_size, seq_len, feature_dim)
Output - sequence of tiled representations
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 1dbf490..35fab57 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -500,4 +500,4 @@
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
- return feats_pad, feats_lens
+ return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/models/frontend/wav_frontend_kaldifeat.py b/funasr/models/frontend/wav_frontend_kaldifeat.py
index b91ac63..5372de3 100644
--- a/funasr/models/frontend/wav_frontend_kaldifeat.py
+++ b/funasr/models/frontend/wav_frontend_kaldifeat.py
@@ -1,15 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
-from typing import Tuple
-
import numpy as np
import torch
-import torchaudio.compliance.kaldi as kaldi
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from typeguard import check_argument_types
-from torch.nn.utils.rnn import pad_sequence
-# import kaldifeat
+
def load_cmvn(cmvn_file):
with open(cmvn_file, 'r', encoding='utf-8') as f:
@@ -33,9 +27,9 @@
means = np.array(means_list).astype(np.float64)
vars = np.array(vars_list).astype(np.float64)
cmvn = np.array([means, vars])
- cmvn = torch.as_tensor(cmvn)
- return cmvn
-
+ cmvn = torch.as_tensor(cmvn)
+ return cmvn
+
def apply_cmvn(inputs, cmvn_file): # noqa
"""
@@ -73,108 +67,3 @@
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
-
-
-# class WavFrontend_kaldifeat(AbsFrontend):
-# """Conventional frontend structure for ASR.
-# """
-#
-# def __init__(
-# self,
-# cmvn_file: str = None,
-# fs: int = 16000,
-# window: str = 'hamming',
-# n_mels: int = 80,
-# frame_length: int = 25,
-# frame_shift: int = 10,
-# lfr_m: int = 1,
-# lfr_n: int = 1,
-# dither: float = 1.0,
-# snip_edges: bool = True,
-# upsacle_samples: bool = True,
-# device: str = 'cpu',
-# **kwargs,
-# ):
-# super().__init__()
-#
-# opts = kaldifeat.FbankOptions()
-# opts.device = device
-# opts.frame_opts.samp_freq = fs
-# opts.frame_opts.dither = dither
-# opts.frame_opts.window_type = window
-# opts.frame_opts.frame_shift_ms = float(frame_shift)
-# opts.frame_opts.frame_length_ms = float(frame_length)
-# opts.mel_opts.num_bins = n_mels
-# opts.energy_floor = 0
-# opts.frame_opts.snip_edges = snip_edges
-# opts.mel_opts.debug_mel = False
-# self.opts = opts
-# self.fbank_fn = None
-# self.fbank_beg_idx = 0
-# self.reset_fbank_status()
-#
-# self.lfr_m = lfr_m
-# self.lfr_n = lfr_n
-# self.cmvn_file = cmvn_file
-# self.upsacle_samples = upsacle_samples
-#
-# def output_size(self) -> int:
-# return self.n_mels * self.lfr_m
-#
-# def forward_fbank(
-# self,
-# input: torch.Tensor,
-# input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-# batch_size = input.size(0)
-# feats = []
-# feats_lens = []
-# for i in range(batch_size):
-# waveform_length = input_lengths[i]
-# waveform = input[i][:waveform_length]
-# waveform = waveform * (1 << 15)
-#
-# self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
-# frames = self.fbank_fn.num_frames_ready
-# frames_cur = frames - self.fbank_beg_idx
-# mat = torch.empty([frames_cur, self.opts.mel_opts.num_bins], dtype=torch.float32).to(
-# device=self.opts.device)
-# for i in range(self.fbank_beg_idx, frames):
-# mat[i, :] = self.fbank_fn.get_frame(i)
-# self.fbank_beg_idx += frames_cur
-#
-# feat_length = mat.size(0)
-# feats.append(mat)
-# feats_lens.append(feat_length)
-#
-# feats_lens = torch.as_tensor(feats_lens)
-# feats_pad = pad_sequence(feats,
-# batch_first=True,
-# padding_value=0.0)
-# return feats_pad, feats_lens
-#
-# def reset_fbank_status(self):
-# self.fbank_fn = kaldifeat.OnlineFbank(self.opts)
-# self.fbank_beg_idx = 0
-#
-# def forward_lfr_cmvn(
-# self,
-# input: torch.Tensor,
-# input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-# batch_size = input.size(0)
-# feats = []
-# feats_lens = []
-# for i in range(batch_size):
-# mat = input[i, :input_lengths[i], :]
-# if self.lfr_m != 1 or self.lfr_n != 1:
-# mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
-# if self.cmvn_file is not None:
-# mat = apply_cmvn(mat, self.cmvn_file)
-# feat_length = mat.size(0)
-# feats.append(mat)
-# feats_lens.append(feat_length)
-#
-# feats_lens = torch.as_tensor(feats_lens)
-# feats_pad = pad_sequence(feats,
-# batch_first=True,
-# padding_value=0.0)
-# return feats_pad, feats_lens
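
Note on the apply_lfr tail kept above: the function stacks lfr_m consecutive fbank frames into one super-frame and advances by lfr_n frames. A minimal standalone sketch of that stacking, with illustrative lfr_m/lfr_n values and without the edge padding the real function performs:

    import torch

    def lfr_stack(feats: torch.Tensor, lfr_m: int = 5, lfr_n: int = 3) -> torch.Tensor:
        # feats: (T, D) filterbank frames -> (T', lfr_m * D) stacked frames
        T, _ = feats.shape
        stacked = []
        for i in range(0, T - lfr_m + 1, lfr_n):
            # flatten lfr_m consecutive frames into a single super-frame
            stacked.append(feats[i:i + lfr_m].reshape(-1))
        return torch.vstack(stacked).type(torch.float32)

    frames = torch.randn(100, 80)   # 100 frames of 80-dim fbank
    out = lfr_stack(frames)         # -> torch.Size([32, 400])
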
diff --git a/funasr/models/frontend/windowing.py b/funasr/models/frontend/windowing.py
index 7c4c568..a526758 100644
--- a/funasr/models/frontend/windowing.py
+++ b/funasr/models/frontend/windowing.py
@@ -12,12 +12,10 @@
class SlidingWindow(AbsFrontend):
"""Sliding Window.
-
Provides a sliding window over a batched continuous raw audio tensor.
Optionally, provides padding (Currently not implemented).
Combine this module with a pre-encoder compatible with raw audio data,
for example Sinc convolutions.
-
Known issues:
Output length is calculated incorrectly if audio shorter than win_length.
WARNING: trailing values are discarded - padding not implemented yet.
@@ -33,7 +31,6 @@
fs=None,
):
"""Initialize.
-
Args:
win_length: Length of frame.
hop_length: Relative starting point of next frame.
@@ -53,11 +50,9 @@
self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply a sliding window on the input.
-
Args:
input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
input_lengths: Input lengths within batch.
-
Returns:
Tensor: Output with dimensions (B, T, C, D), with D=win_length.
Tensor: Output lengths within batch.
@@ -78,4 +73,4 @@
def output_size(self) -> int:
"""Return output length of feature dimension D, i.e. the window length."""
- return self.win_length
+ return self.win_length
\ No newline at end of file
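
Note on SlidingWindow above: the forward pass reshapes batched raw audio into overlapping frames of width win_length. A hedged sketch of the equivalent framing via torch.Tensor.unfold, with illustrative window sizes; like the class, it drops trailing samples that do not fill a window:

    import torch

    win_length, hop_length = 400, 160                 # illustrative values
    audio = torch.randn(2, 16000)                     # (B, T*C*D) with C = D = 1
    frames = audio.unfold(1, win_length, hop_length)  # (B, num_frames, win_length)
    frames = frames.unsqueeze(2)                      # (B, T, C, D), C = 1, D = win_length
    num_frames = (16000 - win_length) // hop_length + 1  # 98, matches frames.shape[1]
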
diff --git a/funasr/models/specaug/abs_specaug.py b/funasr/models/specaug/abs_specaug.py
index 3cbac41..da6637e 100644
--- a/funasr/models/specaug/abs_specaug.py
+++ b/funasr/models/specaug/abs_specaug.py
@@ -6,9 +6,7 @@
class AbsSpecAug(torch.nn.Module):
"""Abstract class for the augmentation of spectrogram
-
The process-flow:
-
Frontend -> SpecAug -> Normalization -> Encoder -> Decoder
"""
diff --git a/funasr/models/target_delay_transformer.py b/funasr/models/target_delay_transformer.py
index e893c65..19e5c7c 100644
--- a/funasr/models/target_delay_transformer.py
+++ b/funasr/models/target_delay_transformer.py
@@ -6,13 +6,10 @@
import torch.nn as nn
from funasr.modules.embedding import SinusoidalPositionEncoder
-#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
-#from funasr.modules.mask import subsequent_n_mask
-from funasr.train.abs_model import AbsPunctuation
-class TargetDelayTransformer(AbsPunctuation):
+class TargetDelayTransformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
diff --git a/funasr/models/vad_realtime_transformer.py b/funasr/models/vad_realtime_transformer.py
index fe298ce..e2d13f9 100644
--- a/funasr/models/vad_realtime_transformer.py
+++ b/funasr/models/vad_realtime_transformer.py
@@ -7,10 +7,9 @@
from funasr.modules.embedding import SinusoidalPositionEncoder
from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.train.abs_model import AbsPunctuation
-class VadRealtimeTransformer(AbsPunctuation):
+class VadRealtimeTransformer(torch.nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 55a5d79..361ff89 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -30,6 +30,7 @@
import torch.nn
import torch.optim
import yaml
+from funasr.models.base_model import FunASRModel
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from typeguard import check_return_type
@@ -44,19 +45,18 @@
from funasr.iterators.multiple_iter_factory import MultipleIterFactory
from funasr.iterators.sequence_iter_factory import SequenceIterFactory
from funasr.main_funcs.collect_stats import collect_stats
-from funasr.optimizers.sgd import SGD
from funasr.optimizers.fairseq_adam import FairseqAdam
+from funasr.optimizers.sgd import SGD
from funasr.samplers.build_batch_sampler import BATCH_TYPES
from funasr.samplers.build_batch_sampler import build_batch_sampler
from funasr.samplers.unsorted_batch_sampler import UnsortedBatchSampler
from funasr.schedulers.noam_lr import NoamLR
-from funasr.schedulers.warmup_lr import WarmupLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
from funasr.torch_utils.load_pretrained_model import load_pretrained_model
from funasr.torch_utils.model_summary import model_summary
from funasr.torch_utils.pytorch_version import pytorch_cudnn_version
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.distributed_utils import DistributedOption
from funasr.train.trainer import Trainer
@@ -230,8 +230,8 @@
>>> cls.check_task_requirements()
If your model is defined as follows,
- >>> from funasr.train.abs_espnet_model import AbsESPnetModel
- >>> class Model(AbsESPnetModel):
+ >>> from funasr.models.base_model import FunASRModel
+ >>> class Model(FunASRModel):
... def forward(self, input, output, opt=None): pass
then "required_data_names" should be as
@@ -251,8 +251,8 @@
>>> cls.check_task_requirements()
If your model is defined as follows,
- >>> from funasr.train.abs_espnet_model import AbsESPnetModel
- >>> class Model(AbsESPnetModel):
+ >>> from funasr.models.base_model import FunASRModel
+ >>> class Model(FunASRModel):
... def forward(self, input, output, opt=None): pass
then "optional_data_names" should be as
@@ -263,7 +263,7 @@
@classmethod
@abstractmethod
- def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
+ def build_model(cls, args: argparse.Namespace) -> FunASRModel:
raise NotImplementedError
@classmethod
@@ -1172,7 +1172,8 @@
args.batch_bins = args.batch_bins * args.ngpu
# filter samples if wav.scp and text are mismatch
- if (args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
+ if (args.train_shape_file is None and args.dataset_type == "small") or \
+ (args.train_data_file is None and args.dataset_type == "large"):
if not args.simple_ddp or distributed_option.dist_rank == 0:
filter_wav_text(args.data_dir, args.train_set)
filter_wav_text(args.data_dir, args.dev_set)
@@ -1181,8 +1182,10 @@
if args.train_shape_file is None and args.dataset_type == "small":
if not args.simple_ddp or distributed_option.dist_rank == 0:
- calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
- calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
+ calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min,
+ args.speech_length_max)
+ calc_shape(args.data_dir, args.dev_set, args.frontend_conf, args.speech_length_min,
+ args.speech_length_max)
if args.simple_ddp:
dist.barrier()
args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
@@ -1244,9 +1247,9 @@
# 2. Build model
model = cls.build_model(args=args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model = model.to(
dtype=getattr(torch, args.train_dtype),
@@ -1374,15 +1377,21 @@
if args.dataset_type == "large":
from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
- frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
- seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
- punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
+ frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
+ seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+ punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
mode="train")
- valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
- frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
- seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
- punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
+ valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
+ frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
+ seg_dict_file=args.seg_dict_file if hasattr(args, "seg_dict_file") else None,
+ punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
mode="eval")
elif args.dataset_type == "small":
@@ -1929,7 +1938,7 @@
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: str = "cpu",
- ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
+ ) -> Tuple[FunASRModel, argparse.Namespace]:
"""Build model from the files.
This method is used for inference or fine-tuning.
@@ -1956,9 +1965,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
if model_file is not None:
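
Note on build_model_from_file above: a hedged usage sketch for inference-time loading. ASRTask, the exp/ paths, and the assumption that a config_file keyword precedes the arguments shown are all illustrative; only the model_file/cmvn_file/device keywords and the FunASRModel isinstance check come from the code above:

    from funasr.tasks.asr import ASRTask  # one concrete AbsTask subclass

    model, args = ASRTask.build_model_from_file(
        config_file="exp/asr_train/config.yaml",        # hypothetical path
        model_file="exp/asr_train/valid.acc.best.pth",  # hypothetical path
        cmvn_file="exp/asr_train/am.mvn",               # hypothetical path
        device="cpu",
    )
    model.eval()  # build_model_from_file has already verified it is a FunASRModel
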
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 43ea5ab..b01aa6f 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -38,9 +38,9 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.joint_net.joint_network import JointNetwork
-from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
@@ -76,7 +76,7 @@
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -122,7 +122,7 @@
model_choices = ClassChoices(
"model",
classes=dict(
- asr=ESPnetASRModel,
+ asr=ASRModel,
uniasr=UniASR,
paraformer=Paraformer,
paraformer_online=ParaformerOnline,
@@ -133,7 +133,7 @@
mfcca=MFCCA,
timestamp_prediction=TimestampPredictor,
),
- type_check=AbsESPnetModel,
+ type_check=FunASRModel,
default="asr",
)
preencoder_choices = ClassChoices(
@@ -457,7 +457,7 @@
token_type=args.token_type,
token_list=args.token_list,
bpemodel=args.bpemodel,
- non_linguistic_symbols=args.non_linguistic_symbols,
+ non_linguistic_symbols=args.non_linguistic_symbols if hasattr(args, "non_linguistic_symbols") else None,
text_cleaner=args.cleaner,
g2p_type=args.g2p,
split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
@@ -827,9 +827,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
@@ -1074,9 +1074,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
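
Note on the model_choices table above: a hedged sketch of how such a ClassChoices table resolves a configured name into a class at build time (the same get_class pattern appears in the vad.py changes below). The two-entry table here is illustrative:

    from funasr.train.class_choices import ClassChoices
    from funasr.models.base_model import FunASRModel
    from funasr.models.e2e_asr import ASRModel
    from funasr.models.e2e_asr_paraformer import Paraformer

    demo_choices = ClassChoices(
        "model",
        classes=dict(asr=ASRModel, paraformer=Paraformer),
        type_check=FunASRModel,  # every registered class must inherit FunASRModel
        default="asr",
    )
    model_class = demo_choices.get_class("paraformer")  # name -> class
    # model = model_class(**model_conf)  # constructor kwargs come from the parsed config
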
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 45e4ce7..2625fec 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -58,7 +58,7 @@
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@@ -114,7 +114,7 @@
sond=DiarSondModel,
eend_ola=DiarEENDOLAModel,
),
- type_check=AbsESPnetModel,
+ type_check=FunASRModel,
default="sond",
)
encoder_choices = ClassChoices(
@@ -544,9 +544,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
@@ -902,9 +902,9 @@
args = yaml.safe_load(f)
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
if model_file is not None:
if device == "cuda":
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 80d66d5..1e48655 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -206,6 +206,3 @@
# 3. Initialize
if args.init is not None:
initialize(model, args.init)
-
- assert check_return_type(model)
- return model
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index 0170f28..a63bbe4 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -14,7 +14,6 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
-from funasr.train.abs_model import AbsPunctuation
from funasr.train.abs_model import PunctuationModel
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
@@ -31,7 +30,6 @@
punc_choices = ClassChoices(
"punctuation",
classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
- type_check=AbsPunctuation,
default="target_delay",
)
diff --git a/funasr/tasks/sv.py b/funasr/tasks/sv.py
index 9710447..e4815da 100644
--- a/funasr/tasks/sv.py
+++ b/funasr/tasks/sv.py
@@ -25,7 +25,7 @@
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.models.e2e_asr import ESPnetASRModel
+from funasr.models.e2e_asr import ASRModel
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
@@ -49,7 +49,7 @@
from funasr.models.specaug.specaug import SpecAug
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@@ -94,7 +94,7 @@
classes=dict(
espnet=ESPnetSVModel,
),
- type_check=AbsESPnetModel,
+ type_check=FunASRModel,
default="espnet",
)
preencoder_choices = ClassChoices(
@@ -488,9 +488,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
diff --git a/funasr/tasks/vad.py b/funasr/tasks/vad.py
index d07acf1..ec95596 100644
--- a/funasr/tasks/vad.py
+++ b/funasr/tasks/vad.py
@@ -1,77 +1,42 @@
import argparse
import logging
+import os
+from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
-import os
-from pathlib import Path
-from typing import Tuple
from typing import Union
-import yaml
+
import numpy as np
import torch
+import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
-from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.rnn_decoder import RNNDecoder
-from funasr.models.decoder.transformer_decoder import (
- DynamicConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
-from funasr.models.decoder.transformer_decoder import (
- LightweightConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import (
- LightweightConvolutionTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import TransformerDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.transformer_encoder import TransformerEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.postencoder.hugging_face_transformers_postencoder import (
- HuggingFaceTransformersPostEncoder, # noqa: H301
-)
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.preencoder.linear import LinearProjection
-from funasr.models.preencoder.sinc import LightweightSincConvs
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.specaug.specaug import SpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
-from funasr.train.abs_espnet_model import AbsESPnetModel
-from funasr.train.class_choices import ClassChoices
-from funasr.train.trainer import Trainer
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import float_or_none
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2
-from funasr.modules.subsampling import Conv1dSubsampling
from funasr.models.e2e_vad import E2EVadModel
from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.tasks.abs_task import AbsTask
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str_or_none
frontend_choices = ClassChoices(
name="frontend",
@@ -292,7 +257,7 @@
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("e2evad")
-
+
# 1. frontend
if args.input_size is None:
# Extract features in the model
@@ -308,7 +273,7 @@
args.frontend_conf = {}
frontend = None
input_size = args.input_size
-
+
model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
return model
@@ -344,7 +309,7 @@
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
- #if cmvn_file is not None:
+ # if cmvn_file is not None:
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
diff --git a/funasr/train/abs_espnet_model.py b/funasr/train/abs_espnet_model.py
deleted file mode 100644
index cc6a5a2..0000000
--- a/funasr/train/abs_espnet_model.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-from abc import ABC
-from abc import abstractmethod
-from typing import Dict
-from typing import Tuple
-
-import torch
-
-
-class AbsESPnetModel(torch.nn.Module, ABC):
- """The common abstract class among each tasks
-
- "ESPnetModel" is referred to a class which inherits torch.nn.Module,
- and makes the dnn-models forward as its member field,
- a.k.a delegate pattern,
- and defines "loss", "stats", and "weight" for the task.
-
- If you intend to implement new task in ESPNet,
- the model must inherit this class.
- In other words, the "mediator" objects between
- our training system and the your task class are
- just only these three values, loss, stats, and weight.
-
- Example:
- >>> from funasr.tasks.abs_task import AbsTask
- >>> class YourESPnetModel(AbsESPnetModel):
- ... def forward(self, input, input_lengths):
- ... ...
- ... return loss, stats, weight
- >>> class YourTask(AbsTask):
- ... @classmethod
- ... def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
- """
-
- def __init__(self):
- super().__init__()
- self.num_updates = 0
-
- @abstractmethod
- def forward(
- self, **batch: torch.Tensor
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- raise NotImplementedError
-
- @abstractmethod
- def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
- raise NotImplementedError
-
- def set_num_updates(self, num_updates):
- self.num_updates = num_updates
-
- def get_num_updates(self):
- return self.num_updates
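
Note on the deletion above: the (loss, stats, weight) contract documented in AbsESPnetModel carries over to its replacement. A hedged sketch of a minimal subclass, assuming funasr.models.base_model.FunASRModel keeps the same abstract forward/collect_feats interface:

    from typing import Dict, Tuple

    import torch

    from funasr.models.base_model import FunASRModel
    from funasr.torch_utils.device_funcs import force_gatherable

    class ToyModel(FunASRModel):  # illustrative only
        def __init__(self, input_dim: int = 80):
            super().__init__()
            self.proj = torch.nn.Linear(input_dim, 1)

        def forward(self, speech: torch.Tensor, speech_lengths: torch.Tensor
                    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
            loss = self.proj(speech).mean()
            stats = dict(loss=loss.detach())
            # force_gatherable makes (loss, stats, weight) safe under DataParallel
            return force_gatherable((loss, stats, speech.size(0)), loss.device)

        def collect_feats(self, speech: torch.Tensor, speech_lengths: torch.Tensor
                          ) -> Dict[str, torch.Tensor]:
            return dict(feats=speech, feats_lengths=speech_lengths)
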
diff --git a/funasr/train/abs_model.py b/funasr/train/abs_model.py
index 1c7ff3d..026140b 100644
--- a/funasr/train/abs_model.py
+++ b/funasr/train/abs_model.py
@@ -12,34 +12,12 @@
from funasr.modules.nets_utils import make_pad_mask
from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
-
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from funasr.models.base_model import FunASRModel
-class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
- """The abstract class
-
- To share the loss calculation way among different models,
- We uses delegate pattern here:
- The instance of this class should be passed to "LanguageModel"
-
- This "model" is one of mediator objects for "Task" class.
-
- """
-
- @abstractmethod
- def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
- raise NotImplementedError
-
- @abstractmethod
- def with_vad(self) -> bool:
- raise NotImplementedError
-
-
-class PunctuationModel(AbsESPnetModel):
+class PunctuationModel(FunASRModel):
- def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
+ def __init__(self, punc_model: torch.nn.Module, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
assert check_argument_types()
super().__init__()
self.punc_model = punc_model
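
Note on the changed signature above: with the AbsPunctuation type dropped, any torch.nn.Module can be passed as punc_model. A hedged sketch with a purely illustrative stand-in module (the vocab_size value is arbitrary):

    import torch

    from funasr.train.abs_model import PunctuationModel

    class DummyPunc(torch.nn.Module):  # illustrative stand-in, not a real model
        def forward(self, text: torch.Tensor, lengths: torch.Tensor):
            return text, lengths

    model = PunctuationModel(punc_model=DummyPunc(), vocab_size=5000, ignore_id=0)
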
diff --git a/funasr/train/trainer.py b/funasr/train/trainer.py
index a40f031..4052448 100644
--- a/funasr/train/trainer.py
+++ b/funasr/train/trainer.py
@@ -39,7 +39,7 @@
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.recursive_op import recursive_average
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.distributed_utils import DistributedOption
from funasr.train.reporter import Reporter
from funasr.train.reporter import SubReporter
@@ -166,7 +166,7 @@
@classmethod
def run(
cls,
- model: AbsESPnetModel,
+ model: FunASRModel,
optimizers: Sequence[torch.optim.Optimizer],
schedulers: Sequence[Optional[AbsScheduler]],
train_iter_factory: AbsIterFactory,
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
new file mode 100644
index 0000000..3f55170
--- /dev/null
+++ b/funasr/utils/prepare_data.py
@@ -0,0 +1,209 @@
+import logging
+import os
+import shutil
+from multiprocessing import Pool
+
+import numpy as np
+import torch.distributed as dist
+import torchaudio
+
+
+def filter_wav_text(data_dir, dataset):
+ wav_file = os.path.join(data_dir, dataset, "wav.scp")
+ text_file = os.path.join(data_dir, dataset, "text")
+ with open(wav_file) as f_wav, open(text_file) as f_text:
+ wav_lines = f_wav.readlines()
+ text_lines = f_text.readlines()
+ os.rename(wav_file, "{}.bak".format(wav_file))
+ os.rename(text_file, "{}.bak".format(text_file))
+ wav_dict = {}
+ for line in wav_lines:
+ parts = line.strip().split()
+ if len(parts) < 2:
+ continue
+ wav_dict[parts[0]] = parts[1]
+ text_dict = {}
+ for line in text_lines:
+ parts = line.strip().split()
+ if len(parts) < 2:
+ continue
+ text_dict[parts[0]] = " ".join(parts[1:])
+ filter_count = 0
+ with open(wav_file, "w") as f_wav, open(text_file, "w") as f_text:
+ for sample_name, wav_path in wav_dict.items():
+ if sample_name in text_dict.keys():
+ f_wav.write(sample_name + " " + wav_path + "\n")
+ f_text.write(sample_name + " " + text_dict[sample_name] + "\n")
+ else:
+ filter_count += 1
+ logging.info("{}/{} samples in {} are filtered because of the mismatch between wav.scp and text".
+ format(filter_count, len(wav_lines), dataset))
+
+
+def wav2num_frame(wav_path, frontend_conf):
+ waveform, sampling_rate = torchaudio.load(wav_path)
+ n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
+ feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
+ return n_frames, feature_dim
+
+
+def calc_shape_core(root_path, args, idx):
+ wav_scp_file = os.path.join(root_path, "wav.scp.{}".format(idx))
+ shape_file = os.path.join(root_path, "speech_shape.{}".format(idx))
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ frontend_conf = args.frontend_conf
+ dataset_conf = args.dataset_conf
+ speech_length_min = dataset_conf.get("speech_length_min", -1)
+ speech_length_max = dataset_conf.get("speech_length_max", -1)
+ with open(shape_file, "w") as f:
+ for line in lines:
+ sample_name, wav_path = line.strip().split()
+ n_frames, feature_dim = wav2num_frame(wav_path, frontend_conf)
+ write_flag = True
+ if n_frames > 0 and speech_length_min > 0:
+ write_flag = n_frames >= speech_length_min
+ if n_frames > 0 and speech_length_max > 0:
+ write_flag = write_flag and n_frames <= speech_length_max
+ if write_flag:
+ f.write("{} {},{}\n".format(sample_name, str(int(np.ceil(n_frames))), str(int(feature_dim))))
+ f.flush()
+
+
+def calc_shape(args, dataset, nj=64):
+ shape_path = os.path.join(args.data_dir, dataset, "speech_shape")
+ if os.path.exists(shape_path):
+ logging.info('Shape file for small dataset already exists.')
+ return
+
+ split_shape_path = os.path.join(args.data_dir, dataset, "shape_files")
+ if os.path.exists(split_shape_path):
+ shutil.rmtree(split_shape_path)
+ os.mkdir(split_shape_path)
+
+ # split
+ wav_scp_file = os.path.join(args.data_dir, dataset, "wav.scp")
+ with open(wav_scp_file) as f:
+ lines = f.readlines()
+ num_lines = len(lines)
+ num_job_lines = num_lines // nj
+ start = 0
+ for i in range(nj):
+ end = start + num_job_lines
+ file = os.path.join(split_shape_path, "wav.scp.{}".format(str(i + 1)))
+ with open(file, "w") as f:
+ if i == nj - 1:
+ f.writelines(lines[start:])
+ else:
+ f.writelines(lines[start:end])
+ start = end
+
+ p = Pool(nj)
+ for i in range(nj):
+ p.apply_async(calc_shape_core, args=(split_shape_path, args, str(i + 1)))
+ logging.info("Generating shape files, please wait a few minutes...")
+ p.close()
+ p.join()
+
+ # combine
+ with open(shape_path, "w") as f:
+ for i in range(nj):
+ job_file = os.path.join(split_shape_path, "speech_shape.{}".format(str(i + 1)))
+ with open(job_file) as job_f:
+ lines = job_f.readlines()
+ f.writelines(lines)
+ logging.info('Generating shape files done.')
+
+
+def generate_data_list(data_dir, dataset, nj=64):
+ list_file = os.path.join(data_dir, dataset, "data.list")
+ if os.path.exists(list_file):
+ logging.info('Data list for large dataset already exists.')
+ return
+ split_path = os.path.join(data_dir, dataset, "split")
+ if os.path.exists(split_path):
+ shutil.rmtree(split_path)
+ os.mkdir(split_path)
+
+ with open(os.path.join(data_dir, dataset, "wav.scp")) as f_wav:
+ wav_lines = f_wav.readlines()
+ with open(os.path.join(data_dir, dataset, "text")) as f_text:
+ text_lines = f_text.readlines()
+ num_lines = len(wav_lines)
+ num_job_lines = num_lines // nj
+ start = 0
+ for i in range(nj):
+ end = start + num_job_lines
+ split_path_nj = os.path.join(split_path, str(i + 1))
+ os.mkdir(split_path_nj)
+ wav_file = os.path.join(split_path_nj, "wav.scp")
+ text_file = os.path.join(split_path_nj, "text")
+ with open(wav_file, "w") as fw, open(text_file, "w") as ft:
+ if i == nj - 1:
+ fw.writelines(wav_lines[start:])
+ ft.writelines(text_lines[start:])
+ else:
+ fw.writelines(wav_lines[start:end])
+ ft.writelines(text_lines[start:end])
+ start = end
+
+ with open(list_file, "w") as f_data:
+ for i in range(nj):
+ wav_path = os.path.join(split_path, str(i + 1), "wav.scp")
+ text_path = os.path.join(split_path, str(i + 1), "text")
+ f_data.write(wav_path + " " + text_path + "\n")
+
+
+def prepare_data(args, distributed_option):
+ distributed = distributed_option.distributed
+ if not distributed or distributed_option.dist_rank == 0:
+ filter_wav_text(args.data_dir, args.train_set)
+ filter_wav_text(args.data_dir, args.valid_set)
+
+ if args.dataset_type == "small":
+ calc_shape(args, args.train_set)
+ calc_shape(args, args.valid_set)
+
+ if args.dataset_type == "large":
+ generate_data_list(args.data_dir, args.train_set)
+ generate_data_list(args.data_dir, args.valid_set)
+
+ if args.dataset_type == "small":
+ args.train_shape_file = [os.path.join(args.data_dir, args.train_set, "speech_shape")]
+ args.valid_shape_file = [os.path.join(args.data_dir, args.valid_set, "speech_shape")]
+ data_names = args.dataset_conf.get("data_names", "speech,text").split(",")
+ data_types = args.dataset_conf.get("data_types", "sound,text").split(",")
+ args.train_data_path_and_name_and_type = [
+ ["{}/{}/wav.scp".format(args.data_dir, args.train_set), data_names[0], data_types[0]],
+ ["{}/{}/text".format(args.data_dir, args.train_set), data_names[1], data_types[1]]
+ ]
+ args.valid_data_path_and_name_and_type = [
+ ["{}/{}/wav.scp".format(args.data_dir, args.valid_set), data_names[0], data_types[0]],
+ ["{}/{}/text".format(args.data_dir, args.valid_set), data_names[1], data_types[1]]
+ ]
+ if args.embed_path is not None:
+ args.train_data_path_and_name_and_type.append(
+ [os.path.join(args.embed_path, "embeds", args.train_set, "embeds.scp"), "embed", "kaldi_ark"])
+ args.valid_data_path_and_name_and_type.append(
+ [os.path.join(args.embed_path, "embeds", args.valid_set, "embeds.scp"), "embed", "kaldi_ark"])
+ else:
+ args.train_data_file = os.path.join(args.data_dir, args.train_set, "data.list")
+ args.valid_data_file = os.path.join(args.data_dir, args.valid_set, "data.list")
+ if args.embed_path is not None:
+ if not distributed or distributed_option.dist_rank == 0:
+ for d in [args.train_set, args.valid_set]:
+ file = os.path.join(args.data_dir, d, "data.list")
+ with open(file) as f:
+ lines = f.readlines()
+ out_file = os.path.join(args.data_dir, d, "data_with_embed.list")
+ with open(out_file, "w") as out_f:
+ for line in lines:
+ parts = line.strip().split()
+ idx = parts[0].split("/")[-2]
+ embed_file = os.path.join(args.embed_path, "embeds", d, "ark",
+ "embeds.{}.ark".format(idx))
+ out_f.write(parts[0] + " " + parts[1] + " " + embed_file + "\n")
+ args.train_data_file = os.path.join(args.data_dir, args.train_set, "data_with_embed.list")
+ args.valid_data_file = os.path.join(args.data_dir, args.valid_set, "data_with_embed.list")
+ if distributed:
+ dist.barrier()
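
Note on wav2num_frame above: a worked example of the frame-count arithmetic, using illustrative frontend settings (frame_shift in ms; lfr_m/lfr_n as in the LFR stacking earlier in this patch):

    samples, sampling_rate = 160000, 16000   # a 10 s utterance
    frame_shift, n_mels = 10, 80             # 10 ms shift, 80 mel bins
    lfr_m, lfr_n = 7, 6                      # illustrative LFR settings

    n_frames = (samples * 1000.0) / (sampling_rate * frame_shift * lfr_n)
    feature_dim = n_mels * lfr_m
    print(n_frames, feature_dim)  # 166.666..., 560; calc_shape_core writes "167,560"
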
--
Gitblit v1.9.1