jmwang66
2023-08-07 cf8e000a84e888495dcf30c4dbfecea1ee7ab4e2
egs/callhome/diarization/sond/finetune.sh
@@ -30,20 +30,24 @@
  ln -s ${kaldi_root}/egs/callhome_diarization/v2/utils ./utils
fi
# Root directory of the CALLHOME corpus, e.g. path/to/NIST/LDC2001S97
callhome_root=
if [ -z "${kaldi_root}" ]; then
  echo "We need callhome corpus to prepare data."
  exit;
fi
# machines configuration
gpu_devices="0,1,2,3"  # for V100-16G, need 4 gpus.
gpu_num=4
count=1
# general configuration
stage=1
stop_stage=1
stage=0
stop_stage=10
# number of parallel jobs for data processing
nj=16
sr=8000
# dataset related
callhome_root=path/to/NIST/LDC2001S97
# experiment configuration
lang=en
@@ -97,15 +101,18 @@
  # split ref.rttm
  for dset in callhome1 callhome2; do
    rm -rf data/${dset}/ref.rttm
    for name in `awk '{print $1}' data/${dset}/wav.scp`; do
      grep ${name} data/callhome/fullref.rttm >> data/${dset}/ref.rttm;
    rm -rf ${datadir}/${dset}/ref.rttm
    for name in `awk '{print $1}' ${datadir}/${dset}/wav.scp`; do
      grep ${name} ${datadir}/callhome/fullref.rttm >> ${datadir}/${dset}/ref.rttm;
    done
    # Filter out wav.scp entries that have no RTTM labels.
    awk '{print $2}' data/${dset}/ref.rttm | sort | uniq > data/${dset}/uttid
    mv data/${dset}/wav.scp data/${dset}/wav.scp.bak
    awk '{if (NR==FNR){a[$1]=1}else{if (a[$1]==1){print $0}}}' data/${dset}/uttid data/${dset}/wav.scp.bak > data/${dset}/wav.scp
    awk '{print $2}' ${datadir}/${dset}/ref.rttm | sort | uniq > ${datadir}/${dset}/uttid
    mv ${datadir}/${dset}/wav.scp ${datadir}/${dset}/wav.scp.bak
    awk '{if (NR==FNR){a[$1]=1}else{if (a[$1]==1){print $0}}}' ${datadir}/${dset}/uttid ${datadir}/${dset}/wav.scp.bak > ${datadir}/${dset}/wav.scp
    mkdir ${datadir}/${dset}/raw
    mv ${datadir}/${dset}/{reco2num_spk,segments,spk2utt,utt2spk,uttid,wav.scp.bak} ${datadir}/${dset}/raw/
    awk '{print $1,$1}' ${datadir}/${dset}/wav.scp > ${datadir}/${dset}/utt2spk
  done
fi
@@ -157,16 +164,9 @@
  ln -s ${kaldi_root}/egs/callhome_diarization/v2/steps ./
  for dset in callhome1 callhome2; do
    mv ${datadir}/${dset}/segments ${datadir}/${dset}/segs
    utils/fix_data_dir.sh ${datadir}/${dset}
    steps/make_fbank.sh --write-utt2num-frames true --fbank-config conf/fbank.conf --nj ${nj} --cmd "$train_cmd" \
        ${datadir}/${dset} ${expdir}/make_fbank/${dset} ${dumpdir}/${dset}/fbank
    utils/fix_data_dir.sh ${datadir}/${dset}
  done
  for dset in callhome1/nonoverlap_0s callhome2/nonoverlap_0s; do
    steps/make_fbank.sh --write-utt2num-frames true --fbank-config conf/fbank.conf --nj ${nj} --cmd "$train_cmd" \
        ${datadir}/${dset} ${expdir}/make_fbank/${dset} ${dumpdir}/${dset}/fbank
    utils/fix_data_dir.sh ${datadir}/${dset}
  done
  rm -f steps
@@ -174,14 +174,19 @@
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "Stage 4: Extract speaker embeddings."
  git lfs install
  git clone https://www.modelscope.cn/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch.git
  mv speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ${expdir}/
  sv_exp_dir=exp/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch
  sed "s/input_size: null/input_size: 80/g" ${sv_exp_dir}/sv.yaml > ${sv_exp_dir}/sv_fbank.yaml
  if [ ! -e ${sv_exp_dir} ]; then
    echo "start to download sv models"
    git lfs install
    git clone https://www.modelscope.cn/damo/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch.git
    mv speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ${expdir}/
    echo "Done."
  fi
  for dset in callhome1/nonoverlap_0s callhome2/nonoverlap_0s; do
    key_file=${datadir}/${dset}/feats.scp
    echo "Start to extract speaker embeddings for ${dset}"
    key_file=${datadir}/${dset}/wav.scp
    num_scp_file="$(<${key_file} wc -l)"
    _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
    _logdir=${dumpdir}/${dset}/xvecs
@@ -199,12 +204,15 @@
        --njob ${njob} \
        --ngpu "${_ngpu}" \
        --gpuid_list ${gpuid_list} \
        --data_path_and_name_and_type "${key_file},speech,kaldi_ark" \
        --data_path_and_name_and_type "${key_file},speech,sound" \
        --key_file "${_logdir}"/keys.JOB.scp \
        --sv_train_config ${sv_exp_dir}/sv_fbank.yaml \
        --sv_train_config ${sv_exp_dir}/sv.yaml \
        --sv_model_file ${sv_exp_dir}/sv.pth \
        --output_dir "${_logdir}"/output.JOB
    cat ${_logdir}/output.*/xvector.scp | sort > ${datadir}/${dset}/utt2xvec
    python script/calc_num_frames.py ${key_file} ${datadir}/${dset}/utt2num_frames
    echo "Done."
  done
fi
@@ -217,7 +225,7 @@
    python -Wignore script/calc_real_meeting_frame_labels.py \
          ${datadir}/${dset} ${dumpdir}/${dset}/labels \
          --n_spk 8 --frame_shift 0.01 --nj 16 --sr 8000
    find `pwd`/${dumpdir}/${dset}/labels -iname "*.lbl.mat" | awk -F'[/.]' '{print $(NF-2),$0}' | sort > ${datadir}/${dset}/labels.scp
    find `pwd`/${dumpdir}/${dset}/labels/ -iname "*.lbl.mat" | awk -F'[/.]' '{print $(NF-2),$0}' | sort > ${datadir}/${dset}/labels.scp
  done
fi
@@ -235,7 +243,7 @@
  echo "Stage 6: start to dump for callhome1."
  python -Wignore script/dump_meeting_chunks.py --dir ${data_dir} \
    --out ${dumpdir}/callhome1/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode test \
    --out ${dumpdir}/callhome1/dumped_files/data --n_spk 16 --no_pbar --sr 8000 --mode train \
    --chunk_size 1600 --chunk_shift 400 --add_mid_to_speaker true
  mkdir -p ${datadir}/callhome1/dumped_files
@@ -272,6 +280,15 @@
# Finetune the model on callhome1; this will take about 1.5 hours.
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
  echo "Stage 7: Finetune pretrained model on callhome1."
  if [ ! -e ${expdir}/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch ]; then
    echo "start to download pretrained models"
    git lfs install
    git clone https://www.modelscope.cn/damo/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch.git
    mv speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch ${expdir}/
    echo "Done."
  fi
  world_size=$gpu_num  # run on one machine
  mkdir -p ${expdir}/${model_dir}
  mkdir -p ${expdir}/${model_dir}/log
@@ -312,7 +329,7 @@
              --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/profile.scp,profile,kaldi_ark \
              --valid_data_path_and_name_and_type ${datadir}/${valid_set}/dumped_files/label.scp,binary_labels,kaldi_ark \
              --valid_shape_file ${expdir}/${valid_set}_states/speech_shape \
              --init_param exp/pretrained_models/phase2.pth \
              --init_param ${expdir}/speech_diarization_sond-en-us-swbd_sre-8k-n16k4-pytorch/sond.pth \
              --unused_parameters true \
              ${init_opt} \
              ${freeze_opt} \
@@ -381,7 +398,7 @@
# Scoring for the finetuned model; you may get a DER like:
# oracle_vad  |  system_vad
#   7.28      |     8.06
#   7.32      |     8.14
if [ ${stage} -le 9 ] && [ ${stop_stage} -ge 9 ]; then
  echo "stage 9: Scoring finetuned models"
  if [ ! -e dscore ]; then
@@ -420,11 +437,11 @@
# Then find the wav files to construct wav.scp and put it at data/callhome2/wav.scp.
# After iteratively performing SOAP, you will get DER results like:
# iters : oracle_vad  |  system_vad
# iter_0:   9.63      |     10.43
# iter_1:   9.17      |     10.03
# iter_2:   9.11      |     9.98
# iter_3:   9.08      |     9.96
# iter_4:   9.07      |     9.95
# iter_0:   9.58      |     10.46
# iter_1:   9.22      |     10.15
# iter_2:   9.21      |     10.14
# iter_3:   9.30      |     10.24
# iter_4:   9.29      |     10.23
if [ ${stage} -le 10 ] && [ ${stop_stage} -ge 10 ]; then
  if [ ! -e ${expdir}/speech_xvector_sv-en-us-callhome-8k-spk6135-pytorch ]; then
    git lfs install