From cf8e000a84e888495dcf30c4dbfecea1ee7ab4e2 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 07 八月 2023 16:13:37 +0800
Subject: [PATCH] Merge pull request #807 from alibaba-damo-academy/dev_wjm
---
egs/callhome/diarization/sond/finetune.sh | 104 ++++++++++++++++++++++++++++++++++-----------------
1 files changed, 69 insertions(+), 35 deletions(-)
diff --git a/egs/callhome/diarization/sond/finetune.sh b/egs/callhome/diarization/sond/finetune.sh
index f2428ad..1b53690 100644
--- a/egs/callhome/diarization/sond/finetune.sh
+++ b/egs/callhome/diarization/sond/finetune.sh
@@ -30,20 +30,24 @@
ln -s ${kaldi_root}/egs/callhome_diarization/v2/utils ./utils
fi
+# callhome data root like path/to/NIST/LDC2001S97
+callhome_root=
+if [ -z "${kaldi_root}" ]; then
+ echo "We need callhome corpus to prepare data."
+ exit;
+fi
+
# machines configuration
gpu_devices="0,1,2,3" # for V100-16G, need 4 gpus.
gpu_num=4
count=1
# general configuration
-stage=1
-stop_stage=1
+stage=0
+stop_stage=10
# number of jobs for data process
nj=16
sr=8000
-
-# dataset related
-callhome_root=path/to/NIST/LDC2001S97
# experiment configuration
lang=en
@@ -92,8 +96,24 @@
# Prepare datasets
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
- echp "Stage 0: Prepare callhome data."
+ echo "Stage 0: Prepare callhome data."
local/make_callhome.sh ${callhome_root} ${datadir}/
+
+ # split ref.rttm
+ for dset in callhome1 callhome2; do
+ rm -rf ${datadir}/${dset}/ref.rttm
+ for name in `awk '{print $1}' ${datadir}/${dset}/wav.scp`; do
+ grep ${name} ${datadir}/callhome/fullref.rttm >> ${datadir}/${dset}/ref.rttm;
+ done
+
+ # filter out records which don't have rttm labels.
+ awk '{print $2}' ${datadir}/${dset}/ref.rttm | sort | uniq > ${datadir}/${dset}/uttid
+ mv ${datadir}/${dset}/wav.scp ${datadir}/${dset}/wav.scp.bak
+ awk '{if (NR==FNR){a[$1]=1}else{if (a[$1]==1){print $0}}}' ${datadir}/${dset}/uttid ${datadir}/${dset}/wav.scp.bak > ${datadir}/${dset}/wav.scp
+ mkdir ${datadir}/${dset}/raw
+ mv ${datadir}/${dset}/{reco2num_spk,segments,spk2utt,utt2spk,uttid,wav.scp.bak} ${datadir}/${dset}/raw/
+ awk '{print $1,$1}' ${datadir}/${dset}/wav.scp > ${datadir}/${dset}/utt2spk
+ done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
@@ -123,10 +143,10 @@
mkdir -p ${dumpdir}/${dset}/nonoverlap_0s
python -Wignore script/extract_nonoverlap_segments.py \
${datadir}/${dset}/wav.scp ${datadir}/${dset}/ref.rttm ${dumpdir}/${dset}/nonoverlap_0s \
- --min_dur 0 --max_spk_num 8 --sr ${sr} --no_pbar --nj ${nj}
+ --min_dur 0.1 --max_spk_num 8 --sr ${sr} --no_pbar --nj ${nj}
mkdir -p ${datadir}/${dset}/nonoverlap_0s
- find `pwd`/${dumpdir}/${dset}/nonoverlap_0s | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/${dset}/nonoverlap_0s/wav.scp
+ find ${dumpdir}/${dset}/nonoverlap_0s/ -iname "*.wav" | sort | awk -F'[/.]' '{print $(NF-1),$0}' > ${datadir}/${dset}/nonoverlap_0s/wav.scp
awk -F'[/.]' '{print $(NF-1),$(NF-2)}' ${datadir}/${dset}/nonoverlap_0s/wav.scp > ${datadir}/${dset}/nonoverlap_0s/utt2spk
echo "Done."
done
@@ -134,37 +154,39 @@
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Stage 3: Generate fbank features"
- home_path=`pwd`
+ home_path=$(pwd)
cd ${kaldi_root}/egs/callhome_diarization/v2 || exit
- . ./cmd.sh
+ export train_cmd="run.pl"
+ export cmd="run.pl"
. ./path.sh
+ cd $home_path || exit
+ ln -s ${kaldi_root}/egs/callhome_diarization/v2/steps ./
for dset in callhome1 callhome2; do
+ utils/fix_data_dir.sh ${datadir}/${dset}
steps/make_fbank.sh --write-utt2num-frames true --fbank-config conf/fbank.conf --nj ${nj} --cmd "$train_cmd" \
${datadir}/${dset} ${expdir}/make_fbank/${dset} ${dumpdir}/${dset}/fbank
- utils/fix_data_dir.sh ${datadir}/${dset}
done
+ rm -f steps
- for dset in callhome1/nonoverlap_0s callhome2/nonoverlap_0s; do
- steps/make_fbank.sh --write-utt2num-frames true --fbank-config conf/fbank.conf --nj ${nj} --cmd "$train_cmd" \
- ${datadir}/${dset} ${expdir}/make_fbank/${dset} ${dumpdir}/${dset}/fbank
- utils/fix_data_dir.sh ${datadir}/${dset}
- done
-
- cd ${home_path} || exit
fi
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
echo "Stage 4: Extract speaker embeddings."
- git lfs install
- git clone https://www.modelscope.cn/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch.git
- mv speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ${expdir}/
-
sv_exp_dir=exp/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch
- sed "s/input_size: null/input_size: 80/g" ${sv_exp_dir}/sv.yaml > ${sv_exp_dir}/sv_fbank.yaml
+
+ if [ ! -e ${sv_exp_dir} ]; then
+ echo "start to download sv models"
+ git lfs install
+ git clone https://www.modelscope.cn/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch.git
+ mv speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ${expdir}/
+ echo "Done."
+ fi
+
for dset in callhome1/nonoverlap_0s callhome2/nonoverlap_0s; do
- key_file=${datadir}/${dset}/feats.scp
+ echo "Start to extract speaker embeddings for ${dset}"
+ key_file=${datadir}/${dset}/wav.scp
num_scp_file="$(<${key_file} wc -l)"
_nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
_logdir=${dumpdir}/${dset}/xvecs
@@ -182,12 +204,15 @@
--njob ${njob} \
--ngpu "${_ngpu}" \
--gpuid_list ${gpuid_list} \
- --data_path_and_name_and_type "${key_file},speech,kaldi_ark" \
+ --data_path_and_name_and_type "${key_file},speech,sound" \
--key_file "${_logdir}"/keys.JOB.scp \
- --sv_train_config ${sv_exp_dir}/sv_fbank.yaml \
+ --sv_train_config ${sv_exp_dir}/sv.yaml \
--sv_model_file ${sv_exp_dir}/sv.pth \
--output_dir "${_logdir}"/output.JOB
cat ${_logdir}/output.*/xvector.scp | sort > ${datadir}/${dset}/utt2xvec
+
+ python script/calc_num_frames.py ${key_file} ${datadir}/${dset}/utt2num_frames
+ echo "Done."
done
fi
@@ -200,7 +225,7 @@
python -Wignore script/calc_real_meeting_frame_labels.py \
${datadir}/${dset} ${dumpdir}/${dset}/labels \
--n_spk 8 --frame_shift 0.01 --nj 16 --sr 8000
- find `pwd`/${dumpdir}/${dset}/labels -iname "*.lbl.mat" | awk -F'[/.]' '{print $(NF-2),$0}' | sort > ${datadir}/${dset}/labels.scp
+ find `pwd`/${dumpdir}/${dset}/labels/ -iname "*.lbl.mat" | awk -F'[/.]' '{print $(NF-2),$0}' | sort > ${datadir}/${dset}/labels.scp
done
fi
@@ -218,7 +243,7 @@
echo "Stage 6: start to dump for callhome1."
python -Wignore script/dump_meeting_chunks.py --dir ${data_dir} \
- --out ${dumpdir}/callhome1/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode test \
+ --out ${dumpdir}/callhome1/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode train \
--chunk_size 1600 --chunk_shift 400 --add_mid_to_speaker true
mkdir -p ${datadir}/callhome1/dumped_files
@@ -255,6 +280,15 @@
# Finetune model on callhome1, this will take about 1.5 hours.
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
echo "Stage 7: Finetune pretrained model on callhome1."
+
+ if [ ! -e ${expdir}/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch ]; then
+ echo "start to download pretrained models"
+ git lfs install
+ git clone https://www.modelscope.cn/damo/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch.git
+ mv speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch ${expdir}/
+ echo "Done."
+ fi
+
world_size=$gpu_num # run on one machine
mkdir -p ${expdir}/${model_dir}
mkdir -p ${expdir}/${model_dir}/log
@@ -295,7 +329,7 @@
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
--valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
--valid_shape_file ${expdir}/${valid_set}_states/speech_shape \
- --init_param exp/pretrained_models/phase2.pth \
+ --init_param ${expdir}/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch/sond.pth \
--unused_parameters true \
${init_opt} \
${freeze_opt} \
@@ -364,7 +398,7 @@
# Scoring for finetuned model, you may get a DER like:
# oracle_vad | system_vad
-# 7.28 | 8.06
+# 7.32 | 8.14
if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
echo "stage 9: Scoring finetuned models"
if [ ! -e dscore ]; then
@@ -403,11 +437,11 @@
# Then find the wav files to construct wav.scp and put it at data/callhome2/wav.scp.
# After iteratively perform SOAP, you will get DER results like:
# iters : oracle_vad | system_vad
-# iter_0: 9.63 | 10.43
-# iter_1: 9.17 | 10.03
-# iter_2: 9.11 | 9.98
-# iter_3: 9.08 | 9.96
-# iter_4: 9.07 | 9.95
+# iter_0: 9.58 | 10.46
+# iter_1: 9.22 | 10.15
+# iter_2: 9.21 | 10.14
+# iter_3: 9.30 | 10.24
+# iter_4: 9.29 | 10.23
if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
if [ ! -e ${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ]; then
git lfs install
--
Gitblit v1.9.1