From bee8346c4b0fd9eb4acb8910620be6173f31cf92 Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期三, 02 八月 2023 10:59:31 +0800
Subject: [PATCH] TOLD/SOND: update finetune and train recipe

---
 egs/callhome/diarization/sond/finetune.sh |   25 ++++++++----
 egs/callhome/diarization/sond/run.sh      |   78 +++++++--------------------------------
 2 files changed, 31 insertions(+), 72 deletions(-)

diff --git a/egs/callhome/diarization/sond/finetune.sh b/egs/callhome/diarization/sond/finetune.sh
index 8e161f9..84ec103 100644
--- a/egs/callhome/diarization/sond/finetune.sh
+++ b/egs/callhome/diarization/sond/finetune.sh
@@ -8,13 +8,18 @@
 # [2] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis, EMNLP 2022
 # We recommend you run this script stage by stage.
 
+# This recipe includes:
+# 1. downloading a pretrained model on the simulated data from switchboard and NIST,
+# 2. finetuning the pretrained model on Callhome1.
+# Finally, you will get a slightly better DER result 9.95% on Callhome2 than that in the paper 10.14%.
+
 # environment configuration
 if [ ! -e utils ]; then
   ln -s ../../../aishell/transformer/utils ./utils
 fi
 
 # machines configuration
-gpu_devices="0,1,2,3"
+gpu_devices="0,1,2,3"  # for V100-16G, need 4 gpus.
 gpu_num=4
 count=1
 
@@ -76,10 +81,14 @@
 # Download required resources
 if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
   echo "Stage 0: Download required resources."
-  wget told_finetune_resources.zip
+  if [ ! -e told_finetune_resources.tar.gz ]; then
+    # MD5SUM: abc7424e4e86ce6f040e9cba4178123b
+    wget --no-check-certificate https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/Speaker_Diar/told_finetune_resources.tar.gz
+    tar zxf told_finetune_resources.tar.gz
+  fi
 fi
 
-# Finetune model on callhome1
+# Finetune model on callhome1, this will take about 1.5 hours.
 if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
   echo "Stage 1: Finetune pretrained model on callhome1."
   world_size=$gpu_num  # run on one machine
@@ -230,11 +239,11 @@
 # Then find the wav files to construct wav.scp and put it at data/callhome2/wav.scp.
 # After iteratively perform SOAP, you will get DER results like:
 # iters : oracle_vad  |  system_vad
-# iter_0:   9.68      |     10.51
-# iter_1:   9.26      |     10.14  (reported in the paper)
-# iter_2:   9.18      |     10.08
-# iter_3:   9.24      |     10.15
-# iter_4:   9.27      |     10.17
+# iter_0:   9.63      |     10.43
+# iter_1:   9.17      |     10.03
+# iter_2:   9.11      |     9.98
+# iter_3:   9.08      |     9.96
+# iter_4:   9.07      |     9.95
 if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
   if [ ! -e ${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ]; then
     git lfs install
diff --git a/egs/callhome/diarization/sond/run.sh b/egs/callhome/diarization/sond/run.sh
index 3758f0c..c0ecd35 100644
--- a/egs/callhome/diarization/sond/run.sh
+++ b/egs/callhome/diarization/sond/run.sh
@@ -8,6 +8,15 @@
 # [2] Speaker Overlap-aware Neural Diarization for Multi-party Meeting Analysis, EMNLP 2022
 # We recommend you run this script stage by stage.
 
+# [developing] This recipe includes:
+# 1. simulating data with switchboard and NIST.
+# 2. training the model from scratch for 3 stages:
+#   2-1. pre-train on simu_swbd_sre
+#   2-2. train on simu_swbd_sre
+#   2-3. finetune on callhome1
+# 3. evaluating model with the results from the first stage EEND-OLA,
+# Finally, you will get a similar DER result claimed in the paper.
+
 # environment configuration
 kaldi_root=
 
@@ -26,8 +35,8 @@
 fi
 
 # machines configuration
-gpu_devices="6,7"
-gpu_num=2
+gpu_devices="4,5,6,7"  # for V100-16G, use 4 GPUs
+gpu_num=4
 count=1
 
 # general configuration
@@ -417,7 +426,7 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $gpu_devices | cut -d',' -f$[$i+1])
-            diar_train.py \
+            python -m funasr.bin.diar_train \
                 --gpu_id $gpu_id \
                 --use_preprocessor false \
                 --token_type char \
@@ -565,7 +574,7 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $gpu_devices | cut -d',' -f$[$i+1])
-            diar_train.py \
+            python -m funasr.bin.diar_train \
                 --gpu_id $gpu_id \
                 --use_preprocessor false \
                 --token_type char \
@@ -710,7 +719,7 @@
             rank=$i
             local_rank=$i
             gpu_id=$(echo $gpu_devices | cut -d',' -f$[$i+1])
-            diar_train.py \
+            python -m funasr.bin.diar_train \
                 --gpu_id $gpu_id \
                 --use_preprocessor false \
                 --token_type char \
@@ -942,62 +951,3 @@
     echo "Done."
   done
 fi
-
-
-if [ ${stage} -le 30 ] && [ ${stop_stage} -ge 30 ]; then
-    echo "stage 30: training phase 1, pretraining on simulated data"
-    world_size=$gpu_num  # run on one machine
-    mkdir -p ${expdir}/${model_dir}
-    mkdir -p ${expdir}/${model_dir}/log
-    mkdir -p /tmp/${model_dir}
-    INIT_FILE=/tmp/${model_dir}/ddp_init
-    if [ -f $INIT_FILE ];then
-        rm -f $INIT_FILE
-    fi
-    init_opt=""
-    if [ ! -z "${init_param}" ]; then
-        init_opt="--init_param ${init_param}"
-        echo ${init_opt}
-    fi
-
-    freeze_opt=""
-    if [ ! -z "${freeze_param}" ]; then
-        freeze_opt="--freeze_param ${freeze_param}"
-        echo ${freeze_opt}
-    fi
-
-    init_method=file://$(readlink -f $INIT_FILE)
-    echo "$0: init method is $init_method"
-    for ((i = 0; i < $gpu_num; ++i)); do
-        {
-            rank=$i
-            local_rank=$i
-            gpu_id=$(echo $gpu_devices | cut -d',' -f$[$i+1])
-            diar_train.py \
-                --gpu_id $gpu_id \
-                --use_preprocessor false \
-                --token_type char \
-                --token_list $token_list \
-                --dataset_type large \
-                --train_data_file ${datadir}/${train_set}/dumped_files/data_file.list \
-                --valid_data_file ${datadir}/${valid_set}/dumped_files/data_file.list \
-                --init_param ${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch/sv.pth:encoder:encoder \
-                --freeze_param encoder \
-                ${init_opt} \
-                ${freeze_opt} \
-                --ignore_init_mismatch true \
-                --resume true \
-                --output_dir ${expdir}/${model_dir} \
-                --config $train_config \
-                --ngpu $gpu_num \
-                --num_worker_count $count \
-                --multiprocessing_distributed true \
-                --dist_init_method $init_method \
-                --dist_world_size $world_size \
-                --dist_rank $rank \
-                --local_rank $local_rank 1> ${expdir}/${model_dir}/log/train.log.$i 2>&1
-        } &
-        done
-        echo "Training log can be found at ${expdir}/${model_dir}/log/train.log.*"
-        wait
-fi
\ No newline at end of file

--
Gitblit v1.9.1