From d76aea23d9f5daac4df7ee1985d07f7428abc719 Mon Sep 17 00:00:00 2001
From: smohan-speech <smohan@mail.ustc.edu.cn>
Date: 星期日, 07 五月 2023 02:21:58 +0800
Subject: [PATCH] add speaker-attributed ASR task for alimeeting
---
egs/alimeeting/sa-asr/asr_local.sh | 33
funasr/bin/asr_train.py | 8
funasr/models/e2e_sa_asr.py | 3
egs/alimeeting/sa-asr/asr_local_infer.sh | 3
egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh | 10
egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh | 116 ++++
egs/alimeeting/sa-asr/local/data/get_reco2dur.sh | 0
funasr/bin/sa_asr_inference.py | 24
egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh | 14
egs/alimeeting/sa-asr/path.sh | 3
funasr/tasks/sa_asr.py | 2
egs/alimeeting/sa-asr/local/data/get_utt2dur.sh | 2
egs/alimeeting/sa-asr/local/format_wav_scp.sh | 142 +++++
egs/alimeeting/sa-asr/local/validate_data_dir.sh | 4
funasr/models/decoder/transformer_decoder.py | 428 +++++++++++++++++
egs/alimeeting/sa-asr/local/data/split_data.sh | 6
funasr/bin/sa_asr_train.py | 8
egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml | 1
egs/alimeeting/sa-asr/local/combine_data.sh | 6
egs/alimeeting/sa-asr/local/format_wav_scp.py | 243 ++++++++++
egs/alimeeting/sa-asr/local/validate_text.pl | 0
egs/alimeeting/sa-asr/local/fix_data_dir.sh | 6
funasr/losses/label_smoothing_loss.py | 46 +
egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml | 1
egs/alimeeting/sa-asr/utils | 1
egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh | 2
egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl | 0
/dev/null | 246 ----------
egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl | 0
funasr/bin/asr_inference.py | 28 +
egs/alimeeting/sa-asr/local/copy_data_dir.sh | 28
egs/alimeeting/sa-asr/local/apply_map.pl | 0
funasr/bin/asr_inference_launch.py | 8
33 files changed, 1,090 insertions(+), 332 deletions(-)
diff --git a/egs/alimeeting/sa-asr/asr_local.sh b/egs/alimeeting/sa-asr/asr_local.sh
index c0359eb..419e341 100755
--- a/egs/alimeeting/sa-asr/asr_local.sh
+++ b/egs/alimeeting/sa-asr/asr_local.sh
@@ -434,14 +434,14 @@
log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp"
for factor in ${speed_perturb_factors}; do
if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then
- scripts/utils/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
+ local/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
_dirs+="data/${train_set}_sp${factor} "
else
# If speed factor is 1, same as the original
_dirs+="data/${train_set} "
fi
done
- utils/combine_data.sh "data/${train_set}_sp" ${_dirs}
+ local/combine_data.sh "data/${train_set}_sp" ${_dirs}
else
log "Skip stage 2: Speed perturbation"
fi
@@ -473,7 +473,7 @@
_suf=""
fi
fi
- utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
+ local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
@@ -488,7 +488,7 @@
_opts+="--segments data/${dset}/segments "
fi
# shellcheck disable=SC2086
- scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
+ local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
--audio-format "${audio_format}" --fs "${fs}" ${_opts} \
"data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
@@ -515,7 +515,7 @@
for dset in $rm_dset; do
# Copy data dir
- utils/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
+ local/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type"
# Remove short utterances
@@ -564,7 +564,7 @@
awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text"
# fix_data_dir.sh leaves only utts which exist in all files
- utils/fix_data_dir.sh "${data_feats}/${dset}"
+ local/fix_data_dir.sh "${data_feats}/${dset}"
# generate uttid
cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
@@ -1283,6 +1283,7 @@
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
+ --mc True \
--nbest 1 \
--ngpu "${_ngpu}" \
--njob ${njob_infer} \
@@ -1312,10 +1313,10 @@
_data="${data_feats}/${dset}"
_dir="${asr_exp}/${inference_tag}/${dset}"
- python local/proce_text.py ${_data}/text ${_data}/text.proc
- python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+ python utils/proce_text.py ${_data}/text ${_data}/text.proc
+ python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
- python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
@@ -1390,6 +1391,7 @@
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
+ --mc True \
--nbest 1 \
--ngpu "${_ngpu}" \
--njob ${njob_infer} \
@@ -1421,10 +1423,10 @@
_data="${data_feats}/${dset}"
_dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
- python local/proce_text.py ${_data}/text ${_data}/text.proc
- python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+ python utils/proce_text.py ${_data}/text ${_data}/text.proc
+ python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
- python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
@@ -1506,6 +1508,7 @@
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
+ --mc True \
--nbest 1 \
--ngpu "${_ngpu}" \
--njob ${njob_infer} \
@@ -1536,10 +1539,10 @@
_data="${data_feats}/${dset}"
_dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
- python local/proce_text.py ${_data}/text ${_data}/text.proc
- python local/proce_text.py ${_dir}/text ${_dir}/text.proc
+ python utils/proce_text.py ${_data}/text ${_data}/text.proc
+ python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
- python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
cat ${_dir}/text.cer.txt
diff --git a/egs/alimeeting/sa-asr/asr_local_infer.sh b/egs/alimeeting/sa-asr/asr_local_infer.sh
index 8e8148f..b7a9289 100755
--- a/egs/alimeeting/sa-asr/asr_local_infer.sh
+++ b/egs/alimeeting/sa-asr/asr_local_infer.sh
@@ -436,7 +436,7 @@
_suf=""
- utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
+ local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
_opts=
@@ -548,6 +548,7 @@
${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
python -m funasr.bin.asr_inference_launch \
--batch_size 1 \
+ --mc True \
--nbest 1 \
--ngpu "${_ngpu}" \
--njob ${njob_infer} \
diff --git a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
index a8c9968..7865763 100644
--- a/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
+++ b/egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
@@ -4,7 +4,6 @@
n_fft: 400
win_length: 400
hop_length: 160
- use_channel: 0
# encoder related
encoder: conformer
diff --git a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
index e91db18..421d7df 100644
--- a/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
+++ b/egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
@@ -4,7 +4,6 @@
n_fft: 400
win_length: 400
hop_length: 160
- use_channel: 0
# encoder related
asr_encoder: conformer
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
index 8151bae..7d39cdc 100755
--- a/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
+++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
@@ -78,7 +78,7 @@
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1
#sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
- utils/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
+ local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
sed -e 's/锛�//g' $near_dir/tmp1> $near_dir/tmp2
@@ -109,7 +109,7 @@
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
- utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
+ local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
@@ -121,8 +121,8 @@
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
log "stage 3: finali data process"
- utils/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
- utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+ local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
+ local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
@@ -146,10 +146,10 @@
python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
- utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
+ local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
- ./utils/fix_data_dir.sh $far_single_speaker_dir
- utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+ ./local/fix_data_dir.sh $far_single_speaker_dir
+ local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
# remove space in text
for x in ${tgt}_Ali_far_single_speaker; do
diff --git a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
index 382a056..e3ce934 100755
--- a/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
+++ b/egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
@@ -77,7 +77,7 @@
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
#sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk
- utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
+ local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
@@ -89,7 +89,7 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
log "stage 2: finali data process"
- utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
+ local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
@@ -113,10 +113,10 @@
python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir
cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text
- utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
+ local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
- ./utils/fix_data_dir.sh $far_single_speaker_dir
- utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
+ ./local/fix_data_dir.sh $far_single_speaker_dir
+ local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
# remove space in text
for x in ${tgt}_Ali_far_single_speaker; do
diff --git a/egs/alimeeting/sa-asr/utils/apply_map.pl b/egs/alimeeting/sa-asr/local/apply_map.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/utils/apply_map.pl
rename to egs/alimeeting/sa-asr/local/apply_map.pl
diff --git a/egs/alimeeting/sa-asr/utils/combine_data.sh b/egs/alimeeting/sa-asr/local/combine_data.sh
similarity index 96%
rename from egs/alimeeting/sa-asr/utils/combine_data.sh
rename to egs/alimeeting/sa-asr/local/combine_data.sh
index e1eba85..a3436b5 100755
--- a/egs/alimeeting/sa-asr/utils/combine_data.sh
+++ b/egs/alimeeting/sa-asr/local/combine_data.sh
@@ -98,7 +98,7 @@
for in_dir in $*; do
if [ ! -f $in_dir/segments ]; then
echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
- utils/data/get_segments_for_data.sh $in_dir
+ local/data/get_segments_for_data.sh $in_dir
else
cat $in_dir/segments
fi
@@ -133,14 +133,14 @@
fi
done
-utils/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
+local/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
if [[ $dir_with_frame_shift ]]; then
cp $dir_with_frame_shift/frame_shift $dest
fi
if ! $skip_fix ; then
- utils/fix_data_dir.sh $dest || exit 1;
+ local/fix_data_dir.sh $dest || exit 1;
fi
exit 0
diff --git a/egs/alimeeting/sa-asr/utils/copy_data_dir.sh b/egs/alimeeting/sa-asr/local/copy_data_dir.sh
similarity index 80%
rename from egs/alimeeting/sa-asr/utils/copy_data_dir.sh
rename to egs/alimeeting/sa-asr/local/copy_data_dir.sh
index 9fd420c..6e748dd 100755
--- a/egs/alimeeting/sa-asr/utils/copy_data_dir.sh
+++ b/egs/alimeeting/sa-asr/local/copy_data_dir.sh
@@ -71,25 +71,25 @@
cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq
fi
-cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map | \
- utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
+cat $srcdir/utt2spk | local/apply_map.pl -f 1 $destdir/utt_map | \
+ local/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
-utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
+local/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
if [ -f $srcdir/feats.scp ]; then
- utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
fi
if [ -f $srcdir/vad.scp ]; then
- utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
fi
if [ -f $srcdir/segments ]; then
- utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
cp $srcdir/wav.scp $destdir
else # no segments->wav indexed by utt.
if [ -f $srcdir/wav.scp ]; then
- utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
fi
fi
@@ -98,26 +98,26 @@
fi
if [ -f $srcdir/text ]; then
- utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
fi
if [ -f $srcdir/utt2dur ]; then
- utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
fi
if [ -f $srcdir/utt2num_frames ]; then
- utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
fi
if [ -f $srcdir/reco2dur ]; then
if [ -f $srcdir/segments ]; then
cp $srcdir/reco2dur $destdir/reco2dur
else
- utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
+ local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
fi
fi
if [ -f $srcdir/spk2gender ]; then
- utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
+ local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
fi
if [ -f $srcdir/cmvn.scp ]; then
- utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
+ local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
fi
for f in frame_shift stm glm ctm; do
if [ -f $srcdir/$f ]; then
@@ -142,4 +142,4 @@
[ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats"
[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text"
-utils/validate_data_dir.sh $validate_opts $destdir
+local/validate_data_dir.sh $validate_opts $destdir
diff --git a/egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh b/egs/alimeeting/sa-asr/local/data/get_reco2dur.sh
similarity index 100%
rename from egs/alimeeting/sa-asr/utils/data/get_reco2dur.sh
rename to egs/alimeeting/sa-asr/local/data/get_reco2dur.sh
diff --git a/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh b/egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh
similarity index 93%
rename from egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh
rename to egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh
index 6b161b3..9310715 100755
--- a/egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh
+++ b/egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh
@@ -20,7 +20,7 @@
data=$1
if [ ! -s $data/utt2dur ]; then
- utils/data/get_utt2dur.sh $data 1>&2 || exit 1;
+ local/data/get_utt2dur.sh $data 1>&2 || exit 1;
fi
# <utt-id> <utt-id> 0 <utt-dur>
diff --git a/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh b/egs/alimeeting/sa-asr/local/data/get_utt2dur.sh
similarity index 98%
rename from egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh
rename to egs/alimeeting/sa-asr/local/data/get_utt2dur.sh
index 5ee7ea3..833a7fc 100755
--- a/egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh
+++ b/egs/alimeeting/sa-asr/local/data/get_utt2dur.sh
@@ -94,7 +94,7 @@
nj=$num_utts
fi
- utils/data/split_data.sh --per-utt $data $nj
+ local/data/split_data.sh --per-utt $data $nj
sdata=$data/split${nj}utt
$cmd JOB=1:$nj $data/log/get_durations.JOB.log \
diff --git a/egs/alimeeting/sa-asr/utils/data/split_data.sh b/egs/alimeeting/sa-asr/local/data/split_data.sh
similarity index 96%
rename from egs/alimeeting/sa-asr/utils/data/split_data.sh
rename to egs/alimeeting/sa-asr/local/data/split_data.sh
index 8aa71a1..97ad8c5 100755
--- a/egs/alimeeting/sa-asr/utils/data/split_data.sh
+++ b/egs/alimeeting/sa-asr/local/data/split_data.sh
@@ -60,11 +60,11 @@
nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then
echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can "
- echo "** use utils/fix_data_dir.sh $data to fix this."
+ echo "** use local/fix_data_dir.sh $data to fix this."
fi
if [ -f $data/text ] && [ $nu -ne $nt ]; then
echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can "
- echo "** use utils/fix_data_dir.sh to fix this."
+ echo "** use local/fix_data_dir.sh to fix this."
fi
@@ -112,7 +112,7 @@
for n in `seq $numsplit`; do
dsn=$data/split${numsplit}${utt}/$n
- utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
+ local/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
done
maybe_wav_scp=
diff --git a/egs/alimeeting/sa-asr/utils/fix_data_dir.sh b/egs/alimeeting/sa-asr/local/fix_data_dir.sh
similarity index 97%
rename from egs/alimeeting/sa-asr/utils/fix_data_dir.sh
rename to egs/alimeeting/sa-asr/local/fix_data_dir.sh
index ed4710d..3abd465 100755
--- a/egs/alimeeting/sa-asr/utils/fix_data_dir.sh
+++ b/egs/alimeeting/sa-asr/local/fix_data_dir.sh
@@ -112,7 +112,7 @@
function filter_speakers {
# throughout this program, we regard utt2spk as primary and spk2utt as derived, so...
- utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
+ local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
for s in cmvn.scp spk2gender; do
@@ -123,7 +123,7 @@
done
filter_file $tmpdir/speakers $data/spk2utt
- utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
+ local/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
for s in cmvn.scp spk2gender $spk_extra_files; do
f=$data/$s
@@ -210,6 +210,6 @@
filter_speakers
filter_recordings
-utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
+local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
echo "fix_data_dir.sh: old files are kept in $data/.backup"
diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.py b/egs/alimeeting/sa-asr/local/format_wav_scp.py
new file mode 100755
index 0000000..1fd63d6
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/format_wav_scp.py
@@ -0,0 +1,243 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from io import BytesIO
+from pathlib import Path
+from typing import Tuple, Optional
+
+import kaldiio
+import humanfriendly
+import numpy as np
+import resampy
+import soundfile
+from tqdm import tqdm
+from typeguard import check_argument_types
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.read_text import read_2column_text
+from funasr.fileio.sound_scp import SoundScpWriter
+
+
+def humanfriendly_or_none(value: str):
+ if value in ("none", "None", "NONE"):
+ return None
+ return humanfriendly.parse_size(value)
+
+
+def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
+ """
+
+ >>> str2int_tuple('3,4,5')
+ (3, 4, 5)
+
+ """
+ assert check_argument_types()
+ if integers.strip() in ("none", "None", "NONE", "null", "Null", "NULL"):
+ return None
+ return tuple(map(int, integers.strip().split(",")))
+
+
+def main():
+ logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
+ logging.basicConfig(level=logging.INFO, format=logfmt)
+ logging.info(get_commandline_args())
+
+ parser = argparse.ArgumentParser(
+ description='Create waves list from "wav.scp"',
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument("scp")
+ parser.add_argument("outdir")
+ parser.add_argument(
+ "--name",
+ default="wav",
+ help="Specify the prefix word of output file name " 'such as "wav.scp"',
+ )
+ parser.add_argument("--segments", default=None)
+ parser.add_argument(
+ "--fs",
+ type=humanfriendly_or_none,
+ default=None,
+ help="If the sampling rate specified, " "Change the sampling rate.",
+ )
+ parser.add_argument("--audio-format", default="wav")
+ group = parser.add_mutually_exclusive_group()
+ group.add_argument("--ref-channels", default=None, type=str2int_tuple)
+ group.add_argument("--utt2ref-channels", default=None, type=str)
+ args = parser.parse_args()
+
+ out_num_samples = Path(args.outdir) / f"utt2num_samples"
+
+ if args.ref_channels is not None:
+
+ def utt2ref_channels(x) -> Tuple[int, ...]:
+ return args.ref_channels
+
+ elif args.utt2ref_channels is not None:
+ utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)
+
+ def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
+ chs_str = d[x]
+ return tuple(map(int, chs_str.split()))
+
+ else:
+ utt2ref_channels = None
+
+ Path(args.outdir).mkdir(parents=True, exist_ok=True)
+ out_wavscp = Path(args.outdir) / f"{args.name}.scp"
+ if args.segments is not None:
+ # Note: kaldiio supports only wav-pcm-int16le file.
+ loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
+ if args.audio_format.endswith("ark"):
+ fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
+ fscp = out_wavscp.open("w")
+ else:
+ writer = SoundScpWriter(
+ args.outdir,
+ out_wavscp,
+ format=args.audio_format,
+ )
+
+ with out_num_samples.open("w") as fnum_samples:
+ for uttid, (rate, wave) in tqdm(loader):
+ # wave: (Time,) or (Time, Nmic)
+ if wave.ndim == 2 and utt2ref_channels is not None:
+ wave = wave[:, utt2ref_channels(uttid)]
+
+ if args.fs is not None and args.fs != rate:
+ # FIXME(kamo): To use sox?
+ wave = resampy.resample(
+ wave.astype(np.float64), rate, args.fs, axis=0
+ )
+ wave = wave.astype(np.int16)
+ rate = args.fs
+ if args.audio_format.endswith("ark"):
+ if "flac" in args.audio_format:
+ suf = "flac"
+ elif "wav" in args.audio_format:
+ suf = "wav"
+ else:
+ raise RuntimeError("wav.ark or flac")
+
+ # NOTE(kamo): Using extended ark format style here.
+ # This format is incompatible with Kaldi
+ kaldiio.save_ark(
+ fark,
+ {uttid: (wave, rate)},
+ scp=fscp,
+ append=True,
+ write_function=f"soundfile_{suf}",
+ )
+
+ else:
+ writer[uttid] = rate, wave
+ fnum_samples.write(f"{uttid} {len(wave)}\n")
+ else:
+ if args.audio_format.endswith("ark"):
+ fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
+ else:
+ wavdir = Path(args.outdir) / f"data_{args.name}"
+ wavdir.mkdir(parents=True, exist_ok=True)
+
+ with Path(args.scp).open("r") as fscp, out_wavscp.open(
+ "w"
+ ) as fout, out_num_samples.open("w") as fnum_samples:
+ for line in tqdm(fscp):
+ uttid, wavpath = line.strip().split(None, 1)
+
+ if wavpath.endswith("|"):
+ # Streaming input e.g. cat a.wav |
+ with kaldiio.open_like_kaldi(wavpath, "rb") as f:
+ with BytesIO(f.read()) as g:
+ wave, rate = soundfile.read(g, dtype=np.int16)
+ if wave.ndim == 2 and utt2ref_channels is not None:
+ wave = wave[:, utt2ref_channels(uttid)]
+
+ if args.fs is not None and args.fs != rate:
+ # FIXME(kamo): To use sox?
+ wave = resampy.resample(
+ wave.astype(np.float64), rate, args.fs, axis=0
+ )
+ wave = wave.astype(np.int16)
+ rate = args.fs
+
+ if args.audio_format.endswith("ark"):
+ if "flac" in args.audio_format:
+ suf = "flac"
+ elif "wav" in args.audio_format:
+ suf = "wav"
+ else:
+ raise RuntimeError("wav.ark or flac")
+
+ # NOTE(kamo): Using extended ark format style here.
+ # This format is incompatible with Kaldi
+ kaldiio.save_ark(
+ fark,
+ {uttid: (wave, rate)},
+ scp=fout,
+ append=True,
+ write_function=f"soundfile_{suf}",
+ )
+ else:
+ owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
+ soundfile.write(owavpath, wave, rate)
+ fout.write(f"{uttid} {owavpath}\n")
+ else:
+ wave, rate = soundfile.read(wavpath, dtype=np.int16)
+ if wave.ndim == 2 and utt2ref_channels is not None:
+ wave = wave[:, utt2ref_channels(uttid)]
+ save_asis = False
+
+ elif args.audio_format.endswith("ark"):
+ save_asis = False
+
+ elif Path(wavpath).suffix == "." + args.audio_format and (
+ args.fs is None or args.fs == rate
+ ):
+ save_asis = True
+
+ else:
+ save_asis = False
+
+ if save_asis:
+ # Neither --segments nor --fs are specified and
+ # the line doesn't end with "|",
+ # i.e. not using unix-pipe,
+ # only in this case,
+ # just using the original file as is.
+ fout.write(f"{uttid} {wavpath}\n")
+ else:
+ if args.fs is not None and args.fs != rate:
+ # FIXME(kamo): To use sox?
+ wave = resampy.resample(
+ wave.astype(np.float64), rate, args.fs, axis=0
+ )
+ wave = wave.astype(np.int16)
+ rate = args.fs
+
+ if args.audio_format.endswith("ark"):
+ if "flac" in args.audio_format:
+ suf = "flac"
+ elif "wav" in args.audio_format:
+ suf = "wav"
+ else:
+ raise RuntimeError("wav.ark or flac")
+
+ # NOTE(kamo): Using extended ark format style here.
+ # This format is not supported in Kaldi.
+ kaldiio.save_ark(
+ fark,
+ {uttid: (wave, rate)},
+ scp=fout,
+ append=True,
+ write_function=f"soundfile_{suf}",
+ )
+ else:
+ owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
+ soundfile.write(owavpath, wave, rate)
+ fout.write(f"{uttid} {owavpath}\n")
+ fnum_samples.write(f"{uttid} {len(wave)}\n")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/alimeeting/sa-asr/local/format_wav_scp.sh b/egs/alimeeting/sa-asr/local/format_wav_scp.sh
new file mode 100755
index 0000000..04fc4a5
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/format_wav_scp.sh
@@ -0,0 +1,142 @@
+#!/usr/bin/env bash
+set -euo pipefail
+SECONDS=0
+log() {
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+help_message=$(cat << EOF
+Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
+e.g.
+$0 data/test/wav.scp data/test_format/
+
+Format 'wav.scp': In short words,
+changing "kaldi-datadir" to "modified-kaldi-datadir"
+
+The 'wav.scp' format in kaldi is very flexible,
+e.g. It can use unix-pipe as describing that wav file,
+but it sometime looks confusing and make scripts more complex.
+This tools creates actual wav files from 'wav.scp'
+and also segments wav files using 'segments'.
+
+Options
+ --fs <fs>
+ --segments <segments>
+ --nj <nj>
+ --cmd <cmd>
+EOF
+)
+
+out_filename=wav.scp
+cmd=utils/run.pl
+nj=30
+fs=none
+segments=
+
+ref_channels=
+utt2ref_channels=
+
+audio_format=wav
+write_utt2num_samples=true
+
+log "$0 $*"
+. utils/parse_options.sh
+
+if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
+ log "${help_message}"
+ log "Error: invalid command line arguments"
+ exit 1
+fi
+
+. ./path.sh # Setup the environment
+
+scp=$1
+if [ ! -f "${scp}" ]; then
+ log "${help_message}"
+ echo "$0: Error: No such file: ${scp}"
+ exit 1
+fi
+dir=$2
+
+
+if [ $# -eq 2 ]; then
+ logdir=${dir}/logs
+ outdir=${dir}/data
+
+elif [ $# -eq 3 ]; then
+ logdir=$3
+ outdir=${dir}/data
+
+elif [ $# -eq 4 ]; then
+ logdir=$3
+ outdir=$4
+fi
+
+
+mkdir -p ${logdir}
+
+rm -f "${dir}/${out_filename}"
+
+
+opts=
+if [ -n "${utt2ref_channels}" ]; then
+ opts="--utt2ref-channels ${utt2ref_channels} "
+elif [ -n "${ref_channels}" ]; then
+ opts="--ref-channels ${ref_channels} "
+fi
+
+
+if [ -n "${segments}" ]; then
+ log "[info]: using ${segments}"
+ nutt=$(<${segments} wc -l)
+ nj=$((nj<nutt?nj:nutt))
+
+ split_segments=""
+ for n in $(seq ${nj}); do
+ split_segments="${split_segments} ${logdir}/segments.${n}"
+ done
+
+ utils/split_scp.pl "${segments}" ${split_segments}
+
+ ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
+ local/format_wav_scp.py \
+ ${opts} \
+ --fs ${fs} \
+ --audio-format "${audio_format}" \
+ "--segment=${logdir}/segments.JOB" \
+ "${scp}" "${outdir}/format.JOB"
+
+else
+ log "[info]: without segments"
+ nutt=$(<${scp} wc -l)
+ nj=$((nj<nutt?nj:nutt))
+
+ split_scps=""
+ for n in $(seq ${nj}); do
+ split_scps="${split_scps} ${logdir}/wav.${n}.scp"
+ done
+
+ utils/split_scp.pl "${scp}" ${split_scps}
+ ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
+ local/format_wav_scp.py \
+ ${opts} \
+ --fs "${fs}" \
+ --audio-format "${audio_format}" \
+ "${logdir}/wav.JOB.scp" ${outdir}/format.JOB""
+fi
+
+# Workaround for the NFS problem
+ls ${outdir}/format.* > /dev/null
+
+# concatenate the .scp files together.
+for n in $(seq ${nj}); do
+ cat "${outdir}/format.${n}/wav.scp" || exit 1;
+done > "${dir}/${out_filename}" || exit 1
+
+if "${write_utt2num_samples}"; then
+ for n in $(seq ${nj}); do
+ cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
+ done > "${dir}/utt2num_samples" || exit 1
+fi
+
+log "Successfully finished. [elapsed=${SECONDS}s]"
diff --git a/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh b/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
new file mode 100755
index 0000000..9e08dba
--- /dev/null
+++ b/egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
@@ -0,0 +1,116 @@
+#!/usr/bin/env bash
+
+# 2020 @kamo-naoyuki
+# This file was copied from Kaldi and
+# I deleted parts related to wav duration
+# because we shouldn't use kaldi's command here
+# and we don't need the files actually.
+
+# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
+# 2014 Tom Ko
+# 2018 Emotech LTD (author: Pawel Swietojanski)
+# Apache 2.0
+
+# This script operates on a directory, such as in data/train/,
+# that contains some subset of the following files:
+# wav.scp
+# spk2utt
+# utt2spk
+# text
+#
+# It generates the files which are used for perturbing the speed of the original data.
+
+export LC_ALL=C
+set -euo pipefail
+
+if [[ $# != 3 ]]; then
+ echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
+ echo "e.g.:"
+ echo " $0 0.9 data/train_si284 data/train_si284p"
+ exit 1
+fi
+
+factor=$1
+srcdir=$2
+destdir=$3
+label="sp"
+spk_prefix="${label}${factor}-"
+utt_prefix="${label}${factor}-"
+
+# check that sox is on the path
+
+! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;
+
+if [[ ! -f ${srcdir}/utt2spk ]]; then
+ echo "$0: no such file ${srcdir}/utt2spk"
+ exit 1;
+fi
+
+if [[ ${destdir} == "${srcdir}" ]]; then
+ echo "$0: this script requires <srcdir> and <destdir> to be different."
+ exit 1
+fi
+
+mkdir -p "${destdir}"
+
+<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
+<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
+<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
+if [[ ! -f ${srcdir}/utt2uniq ]]; then
+ <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
+else
+ <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
+fi
+
+
+<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
+ utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
+
+utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
+
+if [[ -f ${srcdir}/segments ]]; then
+
+ utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
+ utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
+ awk -v factor="${factor}" \
+ '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
+ >"${destdir}"/segments
+
+ utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+ # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
+ awk -v factor="${factor}" \
+ '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
+ else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
+ else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
+ > "${destdir}"/wav.scp
+ if [[ -f ${srcdir}/reco2file_and_channel ]]; then
+ utils/apply_map.pl -f 1 "${destdir}"/reco_map \
+ <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
+ fi
+
+else # no segments->wav indexed by utterance.
+ if [[ -f ${srcdir}/wav.scp ]]; then
+ utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
+ # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
+ awk -v factor="${factor}" \
+ '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
+ else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
+ else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
+ > "${destdir}"/wav.scp
+ fi
+fi
+
+if [[ -f ${srcdir}/text ]]; then
+ utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
+fi
+if [[ -f ${srcdir}/spk2gender ]]; then
+ utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
+fi
+if [[ -f ${srcdir}/utt2lang ]]; then
+ utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
+fi
+
+rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
+echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
+
+utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
diff --git a/egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl b/egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/utils/spk2utt_to_utt2spk.pl
rename to egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl
diff --git a/egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl b/egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/utils/utt2spk_to_spk2utt.pl
rename to egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl
diff --git a/egs/alimeeting/sa-asr/utils/validate_data_dir.sh b/egs/alimeeting/sa-asr/local/validate_data_dir.sh
similarity index 98%
rename from egs/alimeeting/sa-asr/utils/validate_data_dir.sh
rename to egs/alimeeting/sa-asr/local/validate_data_dir.sh
index 3eec443..37c99ae 100755
--- a/egs/alimeeting/sa-asr/utils/validate_data_dir.sh
+++ b/egs/alimeeting/sa-asr/local/validate_data_dir.sh
@@ -113,7 +113,7 @@
check_sorted_and_uniq $data/spk2utt
! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \
- <(utils/spk2utt_to_utt2spk.pl $data/spk2utt) && \
+ <(local/spk2utt_to_utt2spk.pl $data/spk2utt) && \
echo "$0: spk2utt and utt2spk do not seem to match" && exit 1;
cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts
@@ -135,7 +135,7 @@
echo "$0: text contains $n_non_print lines with non-printable characters" &&\
exit 1;
fi
- utils/validate_text.pl $data/text || exit 1;
+ local/validate_text.pl $data/text || exit 1;
check_sorted_and_uniq $data/text
text_len=`cat $data/text | wc -l`
illegal_sym_list="<s> </s> #0"
diff --git a/egs/alimeeting/sa-asr/utils/validate_text.pl b/egs/alimeeting/sa-asr/local/validate_text.pl
similarity index 100%
rename from egs/alimeeting/sa-asr/utils/validate_text.pl
rename to egs/alimeeting/sa-asr/local/validate_text.pl
diff --git a/egs/alimeeting/sa-asr/path.sh b/egs/alimeeting/sa-asr/path.sh
index 3aa13d0..5721f3f 100755
--- a/egs/alimeeting/sa-asr/path.sh
+++ b/egs/alimeeting/sa-asr/path.sh
@@ -2,5 +2,4 @@
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
-export PATH=$FUNASR_DIR/funasr/bin:$PATH
-export PATH=$PWD/utils/:$PATH
\ No newline at end of file
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/utils b/egs/alimeeting/sa-asr/utils
new file mode 120000
index 0000000..fe070dd
--- /dev/null
+++ b/egs/alimeeting/sa-asr/utils
@@ -0,0 +1 @@
+../../aishell/transformer/utils
\ No newline at end of file
diff --git a/egs/alimeeting/sa-asr/utils/filter_scp.pl b/egs/alimeeting/sa-asr/utils/filter_scp.pl
deleted file mode 100755
index b76d37f..0000000
--- a/egs/alimeeting/sa-asr/utils/filter_scp.pl
+++ /dev/null
@@ -1,87 +0,0 @@
-#!/usr/bin/env perl
-# Copyright 2010-2012 Microsoft Corporation
-# Johns Hopkins University (author: Daniel Povey)
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
-# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
-# MERCHANTABLITY OR NON-INFRINGEMENT.
-# See the Apache 2 License for the specific language governing permissions and
-# limitations under the License.
-
-
-# This script takes a list of utterance-ids or any file whose first field
-# of each line is an utterance-id, and filters an scp
-# file (or any file whose "n-th" field is an utterance id), printing
-# out only those lines whose "n-th" field is in id_list. The index of
-# the "n-th" field is 1, by default, but can be changed by using
-# the -f <n> switch
-
-$exclude = 0;
-$field = 1;
-$shifted = 0;
-
-do {
- $shifted=0;
- if ($ARGV[0] eq "--exclude") {
- $exclude = 1;
- shift @ARGV;
- $shifted=1;
- }
- if ($ARGV[0] eq "-f") {
- $field = $ARGV[1];
- shift @ARGV; shift @ARGV;
- $shifted=1
- }
-} while ($shifted);
-
-if(@ARGV < 1 || @ARGV > 2) {
- die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
- "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
- "Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
- "only the lines that were *not* in id_list.\n" .
- "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
- "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
- "-f option, add 1 to the argument.\n" .
- "See also: utils/filter_scp.pl .\n";
-}
-
-
-$idlist = shift @ARGV;
-open(F, "<$idlist") || die "Could not open id-list file $idlist";
-while(<F>) {
- @A = split;
- @A>=1 || die "Invalid id-list file line $_";
- $seen{$A[0]} = 1;
-}
-
-if ($field == 1) { # Treat this as special case, since it is common.
- while(<>) {
- $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
- # $1 is what we filter on.
- if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
- print $_;
- }
- }
-} else {
- while(<>) {
- @A = split;
- @A > 0 || die "Invalid scp file line $_";
- @A >= $field || die "Invalid scp file line $_";
- if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
- print $_;
- }
- }
-}
-
-# tests:
-# the following should print "foo 1"
-# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo)
-# the following should print "bar 2".
-# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2)
diff --git a/egs/alimeeting/sa-asr/utils/parse_options.sh b/egs/alimeeting/sa-asr/utils/parse_options.sh
deleted file mode 100755
index 71fb9e5..0000000
--- a/egs/alimeeting/sa-asr/utils/parse_options.sh
+++ /dev/null
@@ -1,97 +0,0 @@
-#!/usr/bin/env bash
-
-# Copyright 2012 Johns Hopkins University (Author: Daniel Povey);
-# Arnab Ghoshal, Karel Vesely
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
-# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
-# MERCHANTABLITY OR NON-INFRINGEMENT.
-# See the Apache 2 License for the specific language governing permissions and
-# limitations under the License.
-
-
-# Parse command-line options.
-# To be sourced by another script (as in ". parse_options.sh").
-# Option format is: --option-name arg
-# and shell variable "option_name" gets set to value "arg."
-# The exception is --help, which takes no arguments, but prints the
-# $help_message variable (if defined).
-
-
-###
-### The --config file options have lower priority to command line
-### options, so we need to import them first...
-###
-
-# Now import all the configs specified by command-line, in left-to-right order
-for ((argpos=1; argpos<$#; argpos++)); do
- if [ "${!argpos}" == "--config" ]; then
- argpos_plus1=$((argpos+1))
- config=${!argpos_plus1}
- [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
- . $config # source the config file.
- fi
-done
-
-
-###
-### Now we process the command line options
-###
-while true; do
- [ -z "${1:-}" ] && break; # break if there are no arguments
- case "$1" in
- # If the enclosing script is called with --help option, print the help
- # message and exit. Scripts should put help messages in $help_message
- --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
- else printf "$help_message\n" 1>&2 ; fi;
- exit 0 ;;
- --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
- exit 1 ;;
- # If the first command-line argument begins with "--" (e.g. --foo-bar),
- # then work out the variable name as $name, which will equal "foo_bar".
- --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
- # Next we test whether the variable in question is undefned-- if so it's
- # an invalid option and we die. Note: $0 evaluates to the name of the
- # enclosing script.
- # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
- # is undefined. We then have to wrap this test inside "eval" because
- # foo_bar is itself inside a variable ($name).
- eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
-
- oldval="`eval echo \\$$name`";
- # Work out whether we seem to be expecting a Boolean argument.
- if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
- was_bool=true;
- else
- was_bool=false;
- fi
-
- # Set the variable to the right value-- the escaped quotes make it work if
- # the option had spaces, like --cmd "queue.pl -sync y"
- eval $name=\"$2\";
-
- # Check that Boolean-valued arguments are really Boolean.
- if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
- echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
- exit 1;
- fi
- shift 2;
- ;;
- *) break;
- esac
-done
-
-
-# Check for an empty argument to the --cmd option, which can easily occur as a
-# result of scripting errors.
-[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
-
-
-true; # so this script returns exit code 0.
diff --git a/egs/alimeeting/sa-asr/utils/split_scp.pl b/egs/alimeeting/sa-asr/utils/split_scp.pl
deleted file mode 100755
index 0876dcb..0000000
--- a/egs/alimeeting/sa-asr/utils/split_scp.pl
+++ /dev/null
@@ -1,246 +0,0 @@
-#!/usr/bin/env perl
-
-# Copyright 2010-2011 Microsoft Corporation
-
-# See ../../COPYING for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
-# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
-# MERCHANTABLITY OR NON-INFRINGEMENT.
-# See the Apache 2 License for the specific language governing permissions and
-# limitations under the License.
-
-
-# This program splits up any kind of .scp or archive-type file.
-# If there is no utt2spk option it will work on any text file and
-# will split it up with an approximately equal number of lines in
-# each but.
-# With the --utt2spk option it will work on anything that has the
-# utterance-id as the first entry on each line; the utt2spk file is
-# of the form "utterance speaker" (on each line).
-# It splits it into equal size chunks as far as it can. If you use the utt2spk
-# option it will make sure these chunks coincide with speaker boundaries. In
-# this case, if there are more chunks than speakers (and in some other
-# circumstances), some of the resulting chunks will be empty and it will print
-# an error message and exit with nonzero status.
-# You will normally call this like:
-# split_scp.pl scp scp.1 scp.2 scp.3 ...
-# or
-# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
-# Note that you can use this script to split the utt2spk file itself,
-# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
-
-# You can also call the scripts like:
-# split_scp.pl -j 3 0 scp scp.0
-# [note: with this option, it assumes zero-based indexing of the split parts,
-# i.e. the second number must be 0 <= n < num-jobs.]
-
-use warnings;
-
-$num_jobs = 0;
-$job_id = 0;
-$utt2spk_file = "";
-$one_based = 0;
-
-for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
- if ($ARGV[0] eq "-j") {
- shift @ARGV;
- $num_jobs = shift @ARGV;
- $job_id = shift @ARGV;
- }
- if ($ARGV[0] =~ /--utt2spk=(.+)/) {
- $utt2spk_file=$1;
- shift;
- }
- if ($ARGV[0] eq '--one-based') {
- $one_based = 1;
- shift @ARGV;
- }
-}
-
-if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
- $job_id - $one_based >= $num_jobs)) {
- die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
- ($one_based ? " --one-based" : "") . "'\n"
-}
-
-$one_based
- and $job_id--;
-
-if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
- die
-"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
- or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
- ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n";
-}
-
-$error = 0;
-$inscp = shift @ARGV;
-if ($num_jobs == 0) { # without -j option
- @OUTPUTS = @ARGV;
-} else {
- for ($j = 0; $j < $num_jobs; $j++) {
- if ($j == $job_id) {
- if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
- else { push @OUTPUTS, "-"; }
- } else {
- push @OUTPUTS, "/dev/null";
- }
- }
-}
-
-if ($utt2spk_file ne "") { # We have the --utt2spk option...
- open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
- while(<$u_fh>) {
- @A = split;
- @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
- ($u,$s) = @A;
- $utt2spk{$u} = $s;
- }
- close $u_fh;
- open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
- @spkrs = ();
- while(<$i_fh>) {
- @A = split;
- if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
- $u = $A[0];
- $s = $utt2spk{$u};
- defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
- if(!defined $spk_count{$s}) {
- push @spkrs, $s;
- $spk_count{$s} = 0;
- $spk_data{$s} = []; # ref to new empty array.
- }
- $spk_count{$s}++;
- push @{$spk_data{$s}}, $_;
- }
- # Now split as equally as possible ..
- # First allocate spks to files by allocating an approximately
- # equal number of speakers.
- $numspks = @spkrs; # number of speakers.
- $numscps = @OUTPUTS; # number of output files.
- if ($numspks < $numscps) {
- die "$0: Refusing to split data because number of speakers $numspks " .
- "is less than the number of output .scp files $numscps\n";
- }
- for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
- $scparray[$scpidx] = []; # [] is array reference.
- }
- for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
- $scpidx = int(($spkidx*$numscps) / $numspks);
- $spk = $spkrs[$spkidx];
- push @{$scparray[$scpidx]}, $spk;
- $scpcount[$scpidx] += $spk_count{$spk};
- }
-
- # Now will try to reassign beginning + ending speakers
- # to different scp's and see if it gets more balanced.
- # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
- # We can show that if considering changing just 2 scp's, we minimize
- # this by minimizing the squared difference in sizes. This is
- # equivalent to minimizing the absolute difference in sizes. This
- # shows this method is bound to converge.
-
- $changed = 1;
- while($changed) {
- $changed = 0;
- for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
- # First try to reassign ending spk of this scp.
- if($scpidx < $numscps-1) {
- $sz = @{$scparray[$scpidx]};
- if($sz > 0) {
- $spk = $scparray[$scpidx]->[$sz-1];
- $count = $spk_count{$spk};
- $nutt1 = $scpcount[$scpidx];
- $nutt2 = $scpcount[$scpidx+1];
- if( abs( ($nutt2+$count) - ($nutt1-$count))
- < abs($nutt2 - $nutt1)) { # Would decrease
- # size-diff by reassigning spk...
- $scpcount[$scpidx+1] += $count;
- $scpcount[$scpidx] -= $count;
- pop @{$scparray[$scpidx]};
- unshift @{$scparray[$scpidx+1]}, $spk;
- $changed = 1;
- }
- }
- }
- if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
- $spk = $scparray[$scpidx]->[0];
- $count = $spk_count{$spk};
- $nutt1 = $scpcount[$scpidx-1];
- $nutt2 = $scpcount[$scpidx];
- if( abs( ($nutt2-$count) - ($nutt1+$count))
- < abs($nutt2 - $nutt1)) { # Would decrease
- # size-diff by reassigning spk...
- $scpcount[$scpidx-1] += $count;
- $scpcount[$scpidx] -= $count;
- shift @{$scparray[$scpidx]};
- push @{$scparray[$scpidx-1]}, $spk;
- $changed = 1;
- }
- }
- }
- }
- # Now print out the files...
- for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
- $scpfile = $OUTPUTS[$scpidx];
- ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
- : open($f_fh, '>&', \*STDOUT)) ||
- die "$0: Could not open scp file $scpfile for writing: $!\n";
- $count = 0;
- if(@{$scparray[$scpidx]} == 0) {
- print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
- "$scpfile (too many splits and too few speakers?)\n";
- $error = 1;
- } else {
- foreach $spk ( @{$scparray[$scpidx]} ) {
- print $f_fh @{$spk_data{$spk}};
- $count += $spk_count{$spk};
- }
- $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
- }
- close($f_fh);
- }
-} else {
- # This block is the "normal" case where there is no --utt2spk
- # option and we just break into equal size chunks.
-
- open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
-
- $numscps = @OUTPUTS; # size of array.
- @F = ();
- while(<$i_fh>) {
- push @F, $_;
- }
- $numlines = @F;
- if($numlines == 0) {
- print STDERR "$0: error: empty input scp file $inscp\n";
- $error = 1;
- }
- $linesperscp = int( $numlines / $numscps); # the "whole part"..
- $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
- $remainder = $numlines - ($linesperscp * $numscps);
- ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
- # [just doing int() rounds down].
- $n = 0;
- for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
- $scpfile = $OUTPUTS[$scpidx];
- ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
- : open($o_fh, '>&', \*STDOUT)) ||
- die "$0: Could not open scp file $scpfile for writing: $!\n";
- for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
- print $o_fh $F[$n++];
- }
- close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n";
- }
- $n == $numlines || die "$n != $numlines [code error]";
-}
-
-exit ($error);
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index c18472f..18f0add 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -40,6 +40,8 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
@@ -90,6 +92,12 @@
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ if asr_train_args.frontend=='wav_frontend':
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+ else:
+ frontend_class=frontend_choices.get_class(asr_train_args.frontend)
+ frontend = frontend_class(**asr_train_args.frontend_conf).eval()
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
@@ -197,12 +205,21 @@
"""
assert check_argument_types()
-
+
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
- batch = {"speech": speech, "speech_lengths": speech_lengths}
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
@@ -275,6 +292,7 @@
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
+ mc: bool = False,
**kwargs,
):
inference_pipeline = inference_modelscope(
@@ -305,6 +323,7 @@
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
+ mc=mc,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -337,6 +356,7 @@
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
+ mc: bool = False,
param_dict: dict = None,
**kwargs,
):
@@ -406,7 +426,7 @@
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
- mc=True,
+ mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
@@ -415,7 +435,7 @@
allow_variable_data_keys=allow_variable_data_keys,
inference=True,
)
-
+
finish_count = 0
file_count = 1
# 7 .Start for-loop
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index e165531..9a1ffe5 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -71,7 +71,13 @@
)
group.add_argument("--key_file", type=str_or_none)
group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-
+ group.add_argument(
+ "--mc",
+        type=str2bool,
+ default=False,
+ help="MultiChannel input",
+ )
+
group = parser.add_argument_group("The model configuration related")
group.add_argument(
"--vad_infer_config",
diff --git a/funasr/bin/asr_train.py b/funasr/bin/asr_train.py
index c1e2cb2..a43472c 100755
--- a/funasr/bin/asr_train.py
+++ b/funasr/bin/asr_train.py
@@ -2,14 +2,6 @@
import os
-import logging
-
-logging.basicConfig(
- level='INFO',
- format=f"[{os.uname()[1].split('.')[0]}]"
- f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-)
-
from funasr.tasks.asr import ASRTask
diff --git a/funasr/bin/sa_asr_inference.py b/funasr/bin/sa_asr_inference.py
index be63af1..ec575df 100644
--- a/funasr/bin/sa_asr_inference.py
+++ b/funasr/bin/sa_asr_inference.py
@@ -35,6 +35,8 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
@@ -85,6 +87,12 @@
asr_train_config, asr_model_file, cmvn_file, device
)
frontend = None
+ if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+ if asr_train_args.frontend=='wav_frontend':
+ frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
+ else:
+ frontend_class=frontend_choices.get_class(asr_train_args.frontend)
+ frontend = frontend_class(**asr_train_args.frontend_conf).eval()
logging.info("asr_model: {}".format(asr_model))
logging.info("asr_train_args: {}".format(asr_train_args))
@@ -201,7 +209,16 @@
if isinstance(profile, np.ndarray):
profile = torch.tensor(profile)
- batch = {"speech": speech, "speech_lengths": speech_lengths}
+ if self.frontend is not None:
+ feats, feats_len = self.frontend.forward(speech, speech_lengths)
+ feats = to_device(feats, device=self.device)
+ feats_len = feats_len.int()
+ self.asr_model.frontend = None
+ else:
+ feats = speech
+ feats_len = speech_lengths
+ lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+ batch = {"speech": feats, "speech_lengths": feats_len}
# a. To device
batch = to_device(batch, device=self.device)
@@ -308,6 +325,7 @@
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
+ mc: bool = False,
**kwargs,
):
inference_pipeline = inference_modelscope(
@@ -338,6 +356,7 @@
ngram_weight=ngram_weight,
nbest=nbest,
num_workers=num_workers,
+ mc=mc,
**kwargs,
)
return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -370,6 +389,7 @@
ngram_weight: float = 0.9,
nbest: int = 1,
num_workers: int = 1,
+ mc: bool = False,
param_dict: dict = None,
**kwargs,
):
@@ -437,7 +457,7 @@
data_path_and_name_and_type,
dtype=dtype,
fs=fs,
- mc=True,
+ mc=mc,
batch_size=batch_size,
key_file=key_file,
num_workers=num_workers,
diff --git a/funasr/bin/sa_asr_train.py b/funasr/bin/sa_asr_train.py
index c7c7c42..07b9b19 100755
--- a/funasr/bin/sa_asr_train.py
+++ b/funasr/bin/sa_asr_train.py
@@ -2,14 +2,6 @@
import os
-import logging
-
-logging.basicConfig(
- level='INFO',
- format=f"[{os.uname()[1].split('.')[0]}]"
- f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-)
-
from funasr.tasks.sa_asr import ASRTask
diff --git a/funasr/losses/label_smoothing_loss.py b/funasr/losses/label_smoothing_loss.py
index 28df73f..8f63df9 100644
--- a/funasr/losses/label_smoothing_loss.py
+++ b/funasr/losses/label_smoothing_loss.py
@@ -79,3 +79,49 @@
loss = self.criterion(pred, label)
denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
return loss.masked_fill(pad_mask, 0).sum() / denom
+
+
+class NllLoss(nn.Module):
+ """Nll loss.
+
+ :param int size: the number of class
+ :param int padding_idx: ignored class id
+ :param bool normalize_length: normalize loss by sequence length if True
+ :param torch.nn.Module criterion: loss function
+ """
+
+ def __init__(
+ self,
+ size,
+ padding_idx,
+ normalize_length=False,
+ criterion=nn.NLLLoss(reduction='none'),
+ ):
+ """Construct an LabelSmoothingLoss object."""
+ super(NllLoss, self).__init__()
+ self.criterion = criterion
+ self.padding_idx = padding_idx
+ self.size = size
+ self.true_dist = None
+ self.normalize_length = normalize_length
+
+ def forward(self, x, target):
+ """Compute loss between x and target.
+
+ :param torch.Tensor x: prediction (batch, seqlen, class)
+ :param torch.Tensor target:
+ target signal masked with self.padding_id (batch, seqlen)
+ :return: scalar float value
+ :rtype torch.Tensor
+ """
+ assert x.size(2) == self.size
+ batch_size = x.size(0)
+ x = x.view(-1, self.size)
+ target = target.view(-1)
+ with torch.no_grad():
+ ignore = target == self.padding_idx # (B,)
+ total = len(target) - ignore.sum().item()
+ target = target.masked_fill(ignore, 0) # avoid -1 index
+ kl = self.criterion(x , target)
+ denom = total if self.normalize_length else batch_size
+ return kl.masked_fill(ignore, 0).sum() / denom
diff --git a/funasr/models/decoder/transformer_decoder.py b/funasr/models/decoder/transformer_decoder.py
index aed7f20..45fdda8 100644
--- a/funasr/models/decoder/transformer_decoder.py
+++ b/funasr/models/decoder/transformer_decoder.py
@@ -13,6 +13,7 @@
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.modules.attention import MultiHeadedAttention
+from funasr.modules.attention import CosineDistanceAttention
from funasr.modules.dynamic_conv import DynamicConvolution
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
from funasr.modules.embedding import PositionalEncoding
@@ -763,4 +764,429 @@
normalize_before,
concat_after,
),
- )
\ No newline at end of file
+ )
+
+class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
+    """Base class for speaker-attributed (SA) ASR transformer decoders.
+
+    Runs two coupled branches over the target sequence: a speaker branch
+    (``decoder1`` -> ``decoder2``) whose output is matched against the
+    enrolled speaker profiles with cosine-distance attention, and an ASR
+    branch (``decoder3`` -> ``decoder4``) that consumes the shared
+    first-layer state together with the attended speaker profile.
+    Concrete subclasses must assign ``decoder1`` .. ``decoder4``.
+
+    :param int vocab_size: output vocabulary size
+    :param int encoder_output_size: encoder output (attention) dimension
+    :param int spker_embedding_dim: speaker profile embedding dimension
+    :param float dropout_rate: dropout rate of the 'linear' input layer
+    :param float positional_dropout_rate: dropout rate of positional encoding
+    :param str input_layer: target input layer type, 'embed' or 'linear'
+    :param bool use_asr_output_layer: project final ASR state to vocab_size
+    :param bool use_spk_output_layer: project speaker state to spker_embedding_dim
+    :param pos_enc_class: positional encoding module class
+    :param bool normalize_before: apply LayerNorm before the output layers
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        spker_embedding_dim: int = 256,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        input_layer: str = "embed",
+        use_asr_output_layer: bool = True,
+        use_spk_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        attention_dim = encoder_output_size
+
+        if input_layer == "embed":
+            # Token-id input: embedding + positional encoding.
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(vocab_size, attention_dim),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        elif input_layer == "linear":
+            # Dense-feature input: linear projection + norm + dropout + ReLU.
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(vocab_size, attention_dim),
+                torch.nn.LayerNorm(attention_dim),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        else:
+            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            # Final LayerNorm; applied in both the speaker and ASR branches.
+            self.after_norm = LayerNorm(attention_dim)
+        if use_asr_output_layer:
+            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.asr_output_layer = None
+
+        if use_spk_output_layer:
+            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
+        else:
+            self.spk_output_layer = None
+
+        # Matches decoder speaker states against the profile embeddings.
+        self.cos_distance_att = CosineDistanceAttention()
+
+        # Filled in by concrete subclasses (see SAAsrTransformerDecoder).
+        self.decoder1 = None
+        self.decoder2 = None
+        self.decoder3 = None
+        self.decoder4 = None
+
+    def forward(
+        self,
+        asr_hs_pad: torch.Tensor,
+        spk_hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        profile: torch.Tensor,
+        profile_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Forward decoder (teacher forcing).
+
+        :param torch.Tensor asr_hs_pad: padded ASR encoder output (B, T, D)
+        :param torch.Tensor spk_hs_pad: padded speaker encoder output
+            (assumed time-aligned with asr_hs_pad -- TODO confirm)
+        :param torch.Tensor hlens: encoder output lengths (B,)
+        :param torch.Tensor ys_in_pad: padded input token ids (B, L)
+        :param torch.Tensor ys_in_lens: input lengths (B,)
+        :param torch.Tensor profile: speaker profile embeddings
+        :param torch.Tensor profile_lens: valid profile counts per batch
+        :return: ASR output scores, speaker attention weights, output lengths
+        """
+
+        tgt = ys_in_pad
+        # tgt_mask: (B, 1, L)
+        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+        # m: (1, L, L)
+        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
+        # tgt_mask: (B, L, L)
+        tgt_mask = tgt_mask & m
+
+        asr_memory = asr_hs_pad
+        spk_memory = spk_hs_pad
+        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
+        # Spk decoder
+        x = self.embed(tgt)
+
+        # decoder1 also returns z, the shared post-self-attention state
+        # that seeds the ASR branch (decoder3) below.
+        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
+            x, tgt_mask, asr_memory, spk_memory, memory_mask
+        )
+        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
+            x, tgt_mask, spk_memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.spk_output_layer is not None:
+            x = self.spk_output_layer(x)
+        # dn: profile embedding attended per output position; weights are
+        # the attention scores over profiles, returned to the caller.
+        dn, weights = self.cos_distance_att(x, profile, profile_lens)
+        # Asr decoder
+        # The ASR branch restarts from z, not from the speaker-branch x.
+        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
+            z, tgt_mask, asr_memory, memory_mask, dn
+        )
+        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
+            x, tgt_mask, asr_memory, memory_mask
+        )
+
+        if self.normalize_before:
+            # NOTE(review): after_norm is shared between the speaker and ASR
+            # branches -- confirm this weight sharing is intended.
+            x = self.after_norm(x)
+        if self.asr_output_layer is not None:
+            x = self.asr_output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        return x, weights, olens
+
+
+    def forward_one_step(
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        asr_memory: torch.Tensor,
+        spk_memory: torch.Tensor,
+        profile: torch.Tensor,
+        cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
+        """Forward one incremental decoding step with per-layer cache.
+
+        :param torch.Tensor tgt: input token ids up to the current step (B, L)
+        :param torch.Tensor tgt_mask: subsequent mask for tgt
+        :param torch.Tensor asr_memory: ASR encoder output
+        :param torch.Tensor spk_memory: speaker encoder output
+        :param torch.Tensor profile: speaker profile embeddings
+        :param List[torch.Tensor] cache: cached states, one slot per layer
+            (decoder1, each decoder2 layer, decoder3, each decoder4 layer)
+        :return: next-token log-probabilities, speaker attention weights,
+            and the updated cache
+        """
+
+        x = self.embed(tgt)
+
+        if cache is None:
+            # One slot per layer: decoder1 + decoder2[*] + decoder3 + decoder4[*].
+            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
+        new_cache = []
+        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
+            x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
+        )
+        new_cache.append(x)
+        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
+            x, tgt_mask, spk_memory, _ = decoder(
+                x, tgt_mask, spk_memory, None, cache=c
+            )
+            new_cache.append(x)
+        if self.normalize_before:
+            x = self.after_norm(x)
+        else:
+            x = x  # no-op, kept for symmetry with the branch above
+        if self.spk_output_layer is not None:
+            x = self.spk_output_layer(x)
+        dn, weights = self.cos_distance_att(x, profile, None)
+
+        # The ASR branch restarts from z (decoder1's shared state), not x.
+        x, tgt_mask, asr_memory, _ = self.decoder3(
+            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
+        )
+        new_cache.append(x)
+
+        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
+            x, tgt_mask, asr_memory, _ = decoder(
+                x, tgt_mask, asr_memory, None, cache=c
+            )
+            new_cache.append(x)
+
+        if self.normalize_before:
+            # Only the last position is scored.
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.asr_output_layer is not None:
+            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
+
+        return y, weights, new_cache
+
+    def score(self, ys, state, asr_enc, spk_enc, profile):
+        """Score a hypothesis prefix (BatchScorerInterface entry point).
+
+        :param torch.Tensor ys: prefix token ids (L,)
+        :param state: per-layer cache from the previous call, or None
+        :param torch.Tensor asr_enc: ASR encoder output for one utterance
+        :param torch.Tensor spk_enc: speaker encoder output for one utterance
+        :param torch.Tensor profile: speaker profile embeddings
+        :return: next-token log-probabilities, speaker weights, new state
+        """
+        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
+        logp, weights, state = self.forward_one_step(
+            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
+        )
+        return logp.squeeze(0), weights.squeeze(), state
+
+class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
+    """Concrete SA-ASR transformer decoder.
+
+    Wires the four sub-decoders used by BaseSAAsrTransformerDecoder:
+    ``decoder1`` (shared first layer attending both memories), ``decoder2``
+    (remaining speaker-branch layers), ``decoder3`` (first ASR-branch layer
+    injecting the attended speaker profile) and ``decoder4`` (remaining
+    ASR-branch layers).
+
+    :param int vocab_size: output vocabulary size
+    :param int encoder_output_size: encoder output (attention) dimension
+    :param int spker_embedding_dim: speaker profile embedding dimension
+    :param int attention_heads: number of attention heads
+    :param int linear_units: feed-forward hidden units
+    :param int asr_num_blocks: total ASR-branch layers (decoder3 + decoder4)
+    :param int spk_num_blocks: total speaker-branch layers (decoder1 + decoder2)
+    :param float dropout_rate: dropout rate
+    :param float positional_dropout_rate: positional encoding dropout rate
+    :param float self_attention_dropout_rate: self-attention dropout rate
+    :param float src_attention_dropout_rate: source-attention dropout rate
+    :param str input_layer: 'embed' or 'linear'
+    :param bool use_asr_output_layer: project final ASR state to vocab_size
+    :param bool use_spk_output_layer: project speaker state to spker_embedding_dim
+    :param pos_enc_class: positional encoding module class
+    :param bool normalize_before: pre-norm if True
+    :param bool concat_after: concat attention input/output then project
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        spker_embedding_dim: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        asr_num_blocks: int = 6,
+        spk_num_blocks: int = 3,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_asr_output_layer: bool = True,
+        use_spk_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+    ):
+        assert check_argument_types()
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            spker_embedding_dim=spker_embedding_dim,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_asr_output_layer=use_asr_output_layer,
+            use_spk_output_layer=use_spk_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+
+        # Shared first layer: self-attention plus cross-attention over
+        # both the ASR and speaker encoder memories.
+        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
+            attention_dim,
+            MultiHeadedAttention(
+                attention_heads, attention_dim, self_attention_dropout_rate
+            ),
+            MultiHeadedAttention(
+                attention_heads, attention_dim, src_attention_dropout_rate
+            ),
+            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+            dropout_rate,
+            normalize_before,
+            concat_after,
+        )
+        # Remaining speaker-branch layers: standard decoder layers over
+        # the speaker memory (spk_num_blocks - 1 of them).
+        self.decoder2 = repeat(
+            spk_num_blocks - 1,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+
+        # First ASR-branch layer: source attention over the ASR memory plus
+        # injection of the projected speaker profile dn.
+        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
+            attention_dim,
+            spker_embedding_dim,
+            MultiHeadedAttention(
+                attention_heads, attention_dim, src_attention_dropout_rate
+            ),
+            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+            dropout_rate,
+            normalize_before,
+            concat_after,
+        )
+        # Remaining ASR-branch layers: standard decoder layers over the
+        # ASR memory (asr_num_blocks - 1 of them).
+        self.decoder4 = repeat(
+            asr_num_blocks - 1,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, self_attention_dropout_rate
+                ),
+                MultiHeadedAttention(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+
+class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
+    """Shared first decoder layer of the SA-ASR speaker branch.
+
+    Besides the usual self-attention / source-attention / feed-forward
+    stack, it returns ``z`` -- the post-self-attention state, later fed
+    to the ASR branch's first layer -- and its source attention uses the
+    ASR encoder output as keys and the speaker encoder output as values.
+
+    :param int size: model (attention) dimension
+    :param torch.nn.Module self_attn: self-attention module
+    :param torch.nn.Module src_attn: source-attention module
+    :param torch.nn.Module feed_forward: position-wise feed-forward module
+    :param float dropout_rate: dropout rate
+    :param bool normalize_before: apply LayerNorm before each sub-block
+    :param bool concat_after: concat attention input/output then project,
+        instead of the residual + dropout path
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct a SpeakerAttributeSpkDecoderFirstLayer object."""
+        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        # NOTE(review): only two norms exist; norm1 is applied before both
+        # the self-attention and source-attention sub-blocks in forward()
+        # -- confirm this sharing is intended.
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
+        """Compute decoded features.
+
+        :param torch.Tensor tgt: decoder input (B, L, size)
+        :param torch.Tensor tgt_mask: mask for tgt, or None
+        :param torch.Tensor asr_memory: ASR encoder output (keys)
+        :param torch.Tensor spk_memory: speaker encoder output (values)
+        :param torch.Tensor memory_mask: encoder memory mask, or None
+        :param torch.Tensor cache: cached output (B, L-1, size), or None
+        :return: (x, tgt_mask, asr_memory, spk_memory, memory_mask, z)
+            where z is the post-self-attention state for the ASR branch
+        """
+
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        if cache is None:
+            tgt_q = tgt
+            tgt_q_mask = tgt_mask
+        else:
+            # compute only the last frame query keeping dim: max_time_out -> 1
+            assert cache.shape == (
+                tgt.shape[0],
+                tgt.shape[1] - 1,
+                self.size,
+            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+            tgt_q = tgt[:, -1:, :]
+            residual = residual[:, -1:, :]
+            tgt_q_mask = None
+            if tgt_mask is not None:
+                tgt_q_mask = tgt_mask[:, -1:, :]
+
+        if self.concat_after:
+            tgt_concat = torch.cat(
+                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
+            )
+            x = residual + self.concat_linear1(tgt_concat)
+        else:
+            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+        if not self.normalize_before:
+            x = self.norm1(x)
+        # Shared state consumed by the ASR branch's first layer (decoder3).
+        z = x
+
+        residual = x
+        if self.normalize_before:
+            # NOTE(review): norm1 reused here (pre-src_attn) -- see __init__.
+            x = self.norm1(x)
+
+        # Dual-memory attention: query = decoder state, keys = ASR memory,
+        # values = speaker memory.
+        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
+
+        if self.concat_after:
+            x_concat = torch.cat(
+                (x, skip), dim=-1
+            )
+            x = residual + self.concat_linear2(x_concat)
+        else:
+            x = residual + self.dropout(skip)
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        if cache is not None:
+            # Re-append the cached prefix so the output covers all positions.
+            x = torch.cat([cache, x], dim=1)
+
+        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
+
+
+class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
+
+ def __init__(
+ self,
+ size,
+ d_size,
+ src_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+ """Construct an DecoderLayer object."""
+ super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
+ self.size = size
+ self.src_attn = src_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(size)
+ self.norm2 = LayerNorm(size)
+ self.norm3 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ self.spk_linear = nn.Linear(d_size, size, bias=False)
+ if self.concat_after:
+ self.concat_linear1 = nn.Linear(size + size, size)
+ self.concat_linear2 = nn.Linear(size + size, size)
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
+
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+
+ if cache is None:
+ tgt_q = tgt
+ tgt_q_mask = tgt_mask
+ else:
+
+ tgt_q = tgt[:, -1:, :]
+ residual = residual[:, -1:, :]
+ tgt_q_mask = None
+ if tgt_mask is not None:
+ tgt_q_mask = tgt_mask[:, -1:, :]
+
+ x = tgt_q
+ if self.normalize_before:
+ x = self.norm2(x)
+ if self.concat_after:
+ x_concat = torch.cat(
+ (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
+ )
+ x = residual + self.concat_linear2(x_concat)
+ else:
+ x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
+ if not self.normalize_before:
+ x = self.norm2(x)
+ residual = x
+
+ if dn!=None:
+ x = x + self.spk_linear(dn)
+ if self.normalize_before:
+ x = self.norm3(x)
+
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm3(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, tgt_mask, memory, memory_mask
\ No newline at end of file
diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py
index 0d4097e..f694cc2 100644
--- a/funasr/models/e2e_sa_asr.py
+++ b/funasr/models/e2e_sa_asr.py
@@ -16,9 +16,8 @@
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
- LabelSmoothingLoss, # noqa: H301
+ LabelSmoothingLoss, NllLoss # noqa: H301
)
-from funasr.losses.nll_loss import NllLoss
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
diff --git a/funasr/tasks/sa_asr.py b/funasr/tasks/sa_asr.py
index 738ec52..7cfcbd0 100644
--- a/funasr/tasks/sa_asr.py
+++ b/funasr/tasks/sa_asr.py
@@ -28,7 +28,7 @@
from funasr.models.decoder.transformer_decoder import (
DynamicConvolution2DTransformerDecoder, # noqa: H301
)
-from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder
+from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
LightweightConvolution2DTransformerDecoder, # noqa: H301
--
Gitblit v1.9.1