smohan-speech
2023-05-07 d76aea23d9f5daac4df7ee1985d07f7428abc719
add speaker-attributed ASR task for alimeeting
16个文件已修改
4个文件已添加
12个文件已重命名
3个文件已删除
1606 ■■■■■ 已修改文件
egs/alimeeting/sa-asr/asr_local.sh 33 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/asr_local_infer.sh 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/apply_map.pl 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/combine_data.sh 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/copy_data_dir.sh 28 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/data/get_reco2dur.sh 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/data/get_utt2dur.sh 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/data/split_data.sh 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/fix_data_dir.sh 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/format_wav_scp.py 243 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/format_wav_scp.sh 142 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh 116 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/validate_data_dir.sh 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/local/validate_text.pl 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/path.sh 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/utils 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/utils/filter_scp.pl 87 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/utils/parse_options.sh 97 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/utils/split_scp.pl 246 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference.py 28 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_train.py 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/sa_asr_inference.py 24 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/sa_asr_train.py 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/losses/label_smoothing_loss.py 46 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/decoder/transformer_decoder.py 428 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_sa_asr.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/sa_asr.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs/alimeeting/sa-asr/asr_local.sh
@@ -434,14 +434,14 @@
           log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp"
           for factor in ${speed_perturb_factors}; do
               if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then
                   scripts/utils/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
                   local/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}"
                   _dirs+="data/${train_set}_sp${factor} "
               else
                   # If speed factor is 1, same as the original
                   _dirs+="data/${train_set} "
               fi
           done
           utils/combine_data.sh "data/${train_set}_sp" ${_dirs}
           local/combine_data.sh "data/${train_set}_sp" ${_dirs}
        else
           log "Skip stage 2: Speed perturbation"
        fi
@@ -473,7 +473,7 @@
                        _suf=""
                    fi
                fi
                utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
                local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
                
                cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/"
@@ -488,7 +488,7 @@
                    _opts+="--segments data/${dset}/segments "
                fi
                # shellcheck disable=SC2086
                scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
                local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \
                    --audio-format "${audio_format}" --fs "${fs}" ${_opts} \
                    "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}"
@@ -515,7 +515,7 @@
        for dset in $rm_dset; do
            # Copy data dir
            utils/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
            local/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}"
            cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type"
            # Remove short utterances
@@ -564,7 +564,7 @@
                awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text"
            # fix_data_dir.sh leaves only utts which exist in all files
            utils/fix_data_dir.sh "${data_feats}/${dset}"
            local/fix_data_dir.sh "${data_feats}/${dset}"
            # generate uttid
            cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid"
@@ -1283,6 +1283,7 @@
            ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
                python -m funasr.bin.asr_inference_launch \
                    --batch_size 1 \
                    --mc True   \
                    --nbest 1   \
                    --ngpu "${_ngpu}" \
                    --njob ${njob_infer} \
@@ -1312,10 +1313,10 @@
            _data="${data_feats}/${dset}"
            _dir="${asr_exp}/${inference_tag}/${dset}"
            python local/proce_text.py ${_data}/text ${_data}/text.proc
            python local/proce_text.py ${_dir}/text ${_dir}/text.proc
            python utils/proce_text.py ${_data}/text ${_data}/text.proc
            python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
            python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
            python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
            tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
            cat ${_dir}/text.cer.txt
            
@@ -1390,6 +1391,7 @@
            ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
                python -m funasr.bin.asr_inference_launch \
                    --batch_size 1 \
                    --mc True   \
                    --nbest 1   \
                    --ngpu "${_ngpu}" \
                    --njob ${njob_infer} \
@@ -1421,10 +1423,10 @@
            _data="${data_feats}/${dset}"
            _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}"
            python local/proce_text.py ${_data}/text ${_data}/text.proc
            python local/proce_text.py ${_dir}/text ${_dir}/text.proc
            python utils/proce_text.py ${_data}/text ${_data}/text.proc
            python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
            python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
            python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
            tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
            cat ${_dir}/text.cer.txt
@@ -1506,6 +1508,7 @@
            ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
                python -m funasr.bin.asr_inference_launch \
                    --batch_size 1 \
                    --mc True   \
                    --nbest 1   \
                    --ngpu "${_ngpu}" \
                    --njob ${njob_infer} \
@@ -1536,10 +1539,10 @@
            _data="${data_feats}/${dset}"
            _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}"
            python local/proce_text.py ${_data}/text ${_data}/text.proc
            python local/proce_text.py ${_dir}/text ${_dir}/text.proc
            python utils/proce_text.py ${_data}/text ${_data}/text.proc
            python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
            python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
            python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
            tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
            cat ${_dir}/text.cer.txt
egs/alimeeting/sa-asr/asr_local_infer.sh
@@ -436,7 +436,7 @@
            
                _suf=""
                utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
                local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}"
                
                rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur}
                _opts=
@@ -548,6 +548,7 @@
            ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
                python -m funasr.bin.asr_inference_launch \
                    --batch_size 1 \
                    --mc True   \
                    --nbest 1   \
                    --ngpu "${_ngpu}" \
                    --njob ${njob_infer} \
egs/alimeeting/sa-asr/conf/train_asr_conformer.yaml
@@ -4,7 +4,6 @@
    n_fft: 400
    win_length: 400
    hop_length: 160
    use_channel: 0
    
# encoder related
encoder: conformer
egs/alimeeting/sa-asr/conf/train_sa_asr_conformer.yaml
@@ -4,7 +4,6 @@
    n_fft: 400
    win_length: 400
    hop_length: 160
    use_channel: 0
# encoder related
asr_encoder: conformer
egs/alimeeting/sa-asr/local/alimeeting_data_prep.sh
@@ -78,7 +78,7 @@
    utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk
    #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/'  $near_dir/utt2spk_old >$near_dir/tmp1
    #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk
    utils/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
    local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt
    utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments
    sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1
    sed -e 's/!//g' $near_dir/tmp1> $near_dir/tmp2
@@ -109,7 +109,7 @@
    utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
    #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/'  $far_dir/utt2spk_old >$far_dir/utt2spk
    
    utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
    local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
    utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
    sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
    sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
@@ -121,8 +121,8 @@
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    log "stage 3: finali data process"
    utils/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
    utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
    local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near
    local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
    sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
    sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
@@ -146,10 +146,10 @@
    python local/process_textgrid_to_single_speaker_wav.py  --path $far_single_speaker_dir
    
    cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text    
    utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
    local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
    ./utils/fix_data_dir.sh $far_single_speaker_dir
    utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
    ./local/fix_data_dir.sh $far_single_speaker_dir
    local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
    # remove space in text
    for x in ${tgt}_Ali_far_single_speaker; do
egs/alimeeting/sa-asr/local/alimeeting_data_prep_test_2023.sh
@@ -77,7 +77,7 @@
    utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk
    #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/'  $far_dir/utt2spk_old >$far_dir/utt2spk
    
    utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
    local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt
    utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments
    sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1
    sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2
@@ -89,7 +89,7 @@
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    log "stage 2: finali data process"
    utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
    local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far
    sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo
    sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo
@@ -113,10 +113,10 @@
    python local/process_textgrid_to_single_speaker_wav.py  --path $far_single_speaker_dir
    
    cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text    
    utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
    local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt
    ./utils/fix_data_dir.sh $far_single_speaker_dir
    utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
    ./local/fix_data_dir.sh $far_single_speaker_dir
    local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker
    # remove space in text
    for x in ${tgt}_Ali_far_single_speaker; do
egs/alimeeting/sa-asr/local/apply_map.pl
egs/alimeeting/sa-asr/local/combine_data.sh
File was renamed from egs/alimeeting/sa-asr/utils/combine_data.sh
@@ -98,7 +98,7 @@
  for in_dir in $*; do
    if [ ! -f $in_dir/segments ]; then
      echo "$0 [info]: will generate missing segments for $in_dir" 1>&2
      utils/data/get_segments_for_data.sh $in_dir
      local/data/get_segments_for_data.sh $in_dir
    else
      cat $in_dir/segments
    fi
@@ -133,14 +133,14 @@
  fi
done
utils/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
local/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt
if [[ $dir_with_frame_shift ]]; then
  cp $dir_with_frame_shift/frame_shift $dest
fi
if ! $skip_fix ; then
  utils/fix_data_dir.sh $dest || exit 1;
  local/fix_data_dir.sh $dest || exit 1;
fi
exit 0
egs/alimeeting/sa-asr/local/copy_data_dir.sh
File was renamed from egs/alimeeting/sa-asr/utils/copy_data_dir.sh
@@ -71,25 +71,25 @@
  cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq
fi
cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map  | \
  utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
cat $srcdir/utt2spk | local/apply_map.pl -f 1 $destdir/utt_map  | \
  local/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk
utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
local/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt
if [ -f $srcdir/feats.scp ]; then
  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp
fi
if [ -f $srcdir/vad.scp ]; then
  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp
fi
if [ -f $srcdir/segments ]; then
  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments
  cp $srcdir/wav.scp $destdir
else # no segments->wav indexed by utt.
  if [ -f $srcdir/wav.scp ]; then
    utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
    local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp
  fi
fi
@@ -98,26 +98,26 @@
fi
if [ -f $srcdir/text ]; then
  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text
fi
if [ -f $srcdir/utt2dur ]; then
  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur
fi
if [ -f $srcdir/utt2num_frames ]; then
  utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
  local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames
fi
if [ -f $srcdir/reco2dur ]; then
  if [ -f $srcdir/segments ]; then
    cp $srcdir/reco2dur $destdir/reco2dur
  else
    utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
    local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur
  fi
fi
if [ -f $srcdir/spk2gender ]; then
  utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
  local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender
fi
if [ -f $srcdir/cmvn.scp ]; then
  utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
  local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp
fi
for f in frame_shift stm glm ctm; do
  if [ -f $srcdir/$f ]; then
@@ -142,4 +142,4 @@
[ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats"
[ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text"
utils/validate_data_dir.sh $validate_opts $destdir
local/validate_data_dir.sh $validate_opts $destdir
egs/alimeeting/sa-asr/local/data/get_reco2dur.sh
egs/alimeeting/sa-asr/local/data/get_segments_for_data.sh
File was renamed from egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh
@@ -20,7 +20,7 @@
data=$1
if [ ! -s $data/utt2dur ]; then
  utils/data/get_utt2dur.sh $data 1>&2 || exit 1;
  local/data/get_utt2dur.sh $data 1>&2 || exit 1;
fi
# <utt-id> <utt-id> 0 <utt-dur>
egs/alimeeting/sa-asr/local/data/get_utt2dur.sh
File was renamed from egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh
@@ -94,7 +94,7 @@
      nj=$num_utts
    fi
    utils/data/split_data.sh --per-utt $data $nj
    local/data/split_data.sh --per-utt $data $nj
    sdata=$data/split${nj}utt
    $cmd JOB=1:$nj $data/log/get_durations.JOB.log \
egs/alimeeting/sa-asr/local/data/split_data.sh
File was renamed from egs/alimeeting/sa-asr/utils/data/split_data.sh
@@ -60,11 +60,11 @@
nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file
if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then
  echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can "
  echo "**  use utils/fix_data_dir.sh $data to fix this."
  echo "**  use local/fix_data_dir.sh $data to fix this."
fi
if [ -f $data/text ] && [ $nu -ne $nt ]; then
  echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can "
  echo "** use utils/fix_data_dir.sh to fix this."
  echo "** use local/fix_data_dir.sh to fix this."
fi
@@ -112,7 +112,7 @@
for n in `seq $numsplit`; do
  dsn=$data/split${numsplit}${utt}/$n
  utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
  local/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1;
done
maybe_wav_scp=
egs/alimeeting/sa-asr/local/fix_data_dir.sh
File was renamed from egs/alimeeting/sa-asr/utils/fix_data_dir.sh
@@ -112,7 +112,7 @@
function filter_speakers {
  # throughout this program, we regard utt2spk as primary and spk2utt as derived, so...
  utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
  local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
  cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers
  for s in cmvn.scp spk2gender; do
@@ -123,7 +123,7 @@
  done
  filter_file $tmpdir/speakers $data/spk2utt
  utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
  local/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk
  for s in cmvn.scp spk2gender $spk_extra_files; do
    f=$data/$s
@@ -210,6 +210,6 @@
filter_speakers
filter_recordings
utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt
echo "fix_data_dir.sh: old files are kept in $data/.backup"
egs/alimeeting/sa-asr/local/format_wav_scp.py
New file
@@ -0,0 +1,243 @@
#!/usr/bin/env python3
import argparse
import logging
from io import BytesIO
from pathlib import Path
from typing import Tuple, Optional
import kaldiio
import humanfriendly
import numpy as np
import resampy
import soundfile
from tqdm import tqdm
from typeguard import check_argument_types
from funasr.utils.cli_utils import get_commandline_args
from funasr.fileio.read_text import read_2column_text
from funasr.fileio.sound_scp import SoundScpWriter
def humanfriendly_or_none(value: str):
    """Parse a human-friendly size string (e.g. '16k') via humanfriendly.parse_size.

    The sentinel strings "none"/"None"/"NONE" map to None so the CLI can
    disable resampling explicitly.
    """
    return None if value in {"none", "None", "NONE"} else humanfriendly.parse_size(value)
def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
    """Parse a comma-separated string of ints into a tuple.

    >>> str2int_tuple('3,4,5')
    (3, 4, 5)

    Any of the sentinel spellings "none"/"None"/"NONE"/"null"/"Null"/"NULL"
    (after stripping whitespace) yields None.
    """
    assert check_argument_types()
    stripped = integers.strip()
    if stripped in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    return tuple(int(token) for token in stripped.split(","))
def main():
    """Rewrite a Kaldi-style 'wav.scp' into concrete audio under an output dir.

    Each scp entry (optionally cut by a 'segments' file, channel-selected, and
    resampled) is written either into an extended-ark archive or as one audio
    file per utterance.  Also emits '<name>.scp' and 'utt2num_samples' in the
    output directory.
    """
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())
    parser = argparse.ArgumentParser(
        description='Create waves list from "wav.scp"',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("scp")
    parser.add_argument("outdir")
    parser.add_argument(
        "--name",
        default="wav",
        help="Specify the prefix word of output file name " 'such as "wav.scp"',
    )
    parser.add_argument("--segments", default=None)
    parser.add_argument(
        "--fs",
        type=humanfriendly_or_none,
        default=None,
        help="If the sampling rate specified, " "Change the sampling rate.",
    )
    parser.add_argument("--audio-format", default="wav")
    # Channel selection: either a fixed channel tuple for all utterances, or a
    # per-utterance map file.  The two options are mutually exclusive.
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
    group.add_argument("--utt2ref-channels", default=None, type=str)
    args = parser.parse_args()
    out_num_samples = Path(args.outdir) / f"utt2num_samples"
    # Build utt2ref_channels: callable uttid -> tuple of channel indices,
    # or None when no channel selection was requested.
    if args.ref_channels is not None:
        def utt2ref_channels(x) -> Tuple[int, ...]:
            # Same fixed channel tuple for every utterance.
            return args.ref_channels
    elif args.utt2ref_channels is not None:
        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)
        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
            chs_str = d[x]
            return tuple(map(int, chs_str.split()))
    else:
        utt2ref_channels = None
    Path(args.outdir).mkdir(parents=True, exist_ok=True)
    out_wavscp = Path(args.outdir) / f"{args.name}.scp"
    if args.segments is not None:
        # Segment mode: kaldiio slices each recording by the segments file.
        # Note: kaldiio supports only wav-pcm-int16le file.
        loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
        if args.audio_format.endswith("ark"):
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
            fscp = out_wavscp.open("w")
        else:
            writer = SoundScpWriter(
                args.outdir,
                out_wavscp,
                format=args.audio_format,
            )
        with out_num_samples.open("w") as fnum_samples:
            for uttid, (rate, wave) in tqdm(loader):
                # wave: (Time,) or (Time, Nmic)
                if wave.ndim == 2 and utt2ref_channels is not None:
                    wave = wave[:, utt2ref_channels(uttid)]
                if args.fs is not None and args.fs != rate:
                    # FIXME(kamo): To use sox?
                    wave = resampy.resample(
                        wave.astype(np.float64), rate, args.fs, axis=0
                    )
                    wave = wave.astype(np.int16)
                    rate = args.fs
                if args.audio_format.endswith("ark"):
                    if "flac" in args.audio_format:
                        suf = "flac"
                    elif "wav" in args.audio_format:
                        suf = "wav"
                    else:
                        raise RuntimeError("wav.ark or flac")
                    # NOTE(kamo): Using extended ark format style here.
                    # This format is incompatible with Kaldi
                    kaldiio.save_ark(
                        fark,
                        {uttid: (wave, rate)},
                        scp=fscp,
                        append=True,
                        write_function=f"soundfile_{suf}",
                    )
                else:
                    writer[uttid] = rate, wave
                fnum_samples.write(f"{uttid} {len(wave)}\n")
    else:
        # No-segments mode: process each scp line directly.
        if args.audio_format.endswith("ark"):
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
        else:
            wavdir = Path(args.outdir) / f"data_{args.name}"
            wavdir.mkdir(parents=True, exist_ok=True)
        with Path(args.scp).open("r") as fscp, out_wavscp.open(
            "w"
        ) as fout, out_num_samples.open("w") as fnum_samples:
            for line in tqdm(fscp):
                uttid, wavpath = line.strip().split(None, 1)
                if wavpath.endswith("|"):
                    # Streaming input e.g. cat a.wav |
                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
                        with BytesIO(f.read()) as g:
                            wave, rate = soundfile.read(g, dtype=np.int16)
                            if wave.ndim == 2 and utt2ref_channels is not None:
                                wave = wave[:, utt2ref_channels(uttid)]
                        if args.fs is not None and args.fs != rate:
                            # FIXME(kamo): To use sox?
                            wave = resampy.resample(
                                wave.astype(np.float64), rate, args.fs, axis=0
                            )
                            wave = wave.astype(np.int16)
                            rate = args.fs
                        if args.audio_format.endswith("ark"):
                            if "flac" in args.audio_format:
                                suf = "flac"
                            elif "wav" in args.audio_format:
                                suf = "wav"
                            else:
                                raise RuntimeError("wav.ark or flac")
                            # NOTE(kamo): Using extended ark format style here.
                            # This format is incompatible with Kaldi
                            kaldiio.save_ark(
                                fark,
                                {uttid: (wave, rate)},
                                scp=fout,
                                append=True,
                                write_function=f"soundfile_{suf}",
                            )
                        else:
                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                            soundfile.write(owavpath, wave, rate)
                            fout.write(f"{uttid} {owavpath}\n")
                else:
                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
                    # Decide whether the original file can be referenced as-is
                    # (no channel selection, no re-encoding, no resampling).
                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]
                        save_asis = False
                    elif args.audio_format.endswith("ark"):
                        save_asis = False
                    elif Path(wavpath).suffix == "." + args.audio_format and (
                        args.fs is None or args.fs == rate
                    ):
                        save_asis = True
                    else:
                        save_asis = False
                    if save_asis:
                        # Neither --segments nor --fs are specified and
                        # the line doesn't end with "|",
                        # i.e. not using unix-pipe,
                        # only in this case,
                        # just using the original file as is.
                        fout.write(f"{uttid} {wavpath}\n")
                    else:
                        if args.fs is not None and args.fs != rate:
                            # FIXME(kamo): To use sox?
                            wave = resampy.resample(
                                wave.astype(np.float64), rate, args.fs, axis=0
                            )
                            wave = wave.astype(np.int16)
                            rate = args.fs
                        if args.audio_format.endswith("ark"):
                            if "flac" in args.audio_format:
                                suf = "flac"
                            elif "wav" in args.audio_format:
                                suf = "wav"
                            else:
                                raise RuntimeError("wav.ark or flac")
                            # NOTE(kamo): Using extended ark format style here.
                            # This format is not supported in Kaldi.
                            kaldiio.save_ark(
                                fark,
                                {uttid: (wave, rate)},
                                scp=fout,
                                append=True,
                                write_function=f"soundfile_{suf}",
                            )
                        else:
                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                            soundfile.write(owavpath, wave, rate)
                            fout.write(f"{uttid} {owavpath}\n")
                fnum_samples.write(f"{uttid} {len(wave)}\n")
if __name__ == "__main__":
    main()
egs/alimeeting/sa-asr/local/format_wav_scp.sh
New file
@@ -0,0 +1,142 @@
#!/usr/bin/env bash
set -euo pipefail
SECONDS=0
# Print "<ISO-8601 timestamp> (caller-file:line:function) message..." to stdout.
log() {
    local src_file call_line call_func stamp
    src_file=${BASH_SOURCE[1]##*/}
    call_line=${BASH_LINENO[0]}
    call_func=${FUNCNAME[1]}
    stamp=$(date '+%Y-%m-%dT%H:%M:%S')
    echo -e "${stamp} (${src_file}:${call_line}:${call_func}) $*"
}
help_message=$(cat << EOF
Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
e.g.
$0 data/test/wav.scp data/test_format/

Format 'wav.scp': In short words,
changing "kaldi-datadir" to "modified-kaldi-datadir"

The 'wav.scp' format in kaldi is very flexible,
e.g. It can use unix-pipe as describing that wav file,
but it sometime looks confusing and make scripts more complex.
This tools creates actual wav files from 'wav.scp'
and also segments wav files using 'segments'.

Options
  --fs <fs>
  --segments <segments>
  --nj <nj>
  --cmd <cmd>
EOF
)

# ---- defaults, overridable via command line (parsed by parse_options.sh) ----
out_filename=wav.scp
cmd=utils/run.pl
nj=30
fs=none
segments=
ref_channels=
utt2ref_channels=
audio_format=wav
write_utt2num_samples=true

log "$0 $*"
. utils/parse_options.sh

if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
    log "${help_message}"
    log "Error: invalid command line arguments"
    exit 1
fi

. ./path.sh  # Setup the environment

scp=$1
if [ ! -f "${scp}" ]; then
    log "${help_message}"
    echo "$0: Error: No such file: ${scp}"
    exit 1
fi
dir=$2

# Optional 3rd/4th positional args override the default log/output dirs.
if [ $# -eq 2 ]; then
    logdir=${dir}/logs
    outdir=${dir}/data
elif [ $# -eq 3 ]; then
    logdir=$3
    outdir=${dir}/data
elif [ $# -eq 4 ]; then
    logdir=$3
    outdir=$4
fi
mkdir -p "${logdir}"

rm -f "${dir}/${out_filename}"

# Channel-selection options forwarded to format_wav_scp.py (at most one set).
opts=
if [ -n "${utt2ref_channels}" ]; then
    opts="--utt2ref-channels ${utt2ref_channels} "
elif [ -n "${ref_channels}" ]; then
    opts="--ref-channels ${ref_channels} "
fi

if [ -n "${segments}" ]; then
    log "[info]: using ${segments}"
    nutt=$(<"${segments}" wc -l)
    nj=$((nj<nutt?nj:nutt))

    split_segments=""
    for n in $(seq ${nj}); do
        split_segments="${split_segments} ${logdir}/segments.${n}"
    done
    utils/split_scp.pl "${segments}" ${split_segments}

    # NOTE(review): option name fixed to '--segments' (plural); '--segment='
    # only worked through argparse prefix abbreviation and would break if a
    # similarly named option were ever added to format_wav_scp.py.
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        local/format_wav_scp.py \
            ${opts} \
            --fs "${fs}" \
            --audio-format "${audio_format}" \
            "--segments=${logdir}/segments.JOB" \
            "${scp}" "${outdir}/format.JOB"
else
    log "[info]: without segments"
    nutt=$(<"${scp}" wc -l)
    nj=$((nj<nutt?nj:nutt))

    split_scps=""
    for n in $(seq ${nj}); do
        split_scps="${split_scps} ${logdir}/wav.${n}.scp"
    done
    utils/split_scp.pl "${scp}" ${split_scps}

    # NOTE(review): fixed misplaced quotes — the output dir was written as
    # ${outdir}/format.JOB"" (path left unquoted, followed by an empty word).
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        local/format_wav_scp.py \
            ${opts} \
            --fs "${fs}" \
            --audio-format "${audio_format}" \
            "${logdir}/wav.JOB.scp" "${outdir}/format.JOB"
fi

# Workaround for the NFS problem
ls ${outdir}/format.* > /dev/null

# concatenate the .scp files together.
for n in $(seq ${nj}); do
    cat "${outdir}/format.${n}/wav.scp" || exit 1;
done > "${dir}/${out_filename}" || exit 1

if "${write_utt2num_samples}"; then
    for n in $(seq ${nj}); do
        cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
    done > "${dir}/utt2num_samples"  || exit 1
fi

log "Successfully finished. [elapsed=${SECONDS}s]"
egs/alimeeting/sa-asr/local/perturb_data_dir_speed.sh
New file
@@ -0,0 +1,116 @@
#!/usr/bin/env bash
# 2020 @kamo-naoyuki
# This file was copied from Kaldi and
# I deleted parts related to wav duration
# because we shouldn't use kaldi's command here
# and we don't need the files actually.
# Copyright 2013  Johns Hopkins University (author: Daniel Povey)
#           2014  Tom Ko
#           2018  Emotech LTD (author: Pawel Swietojanski)
# Apache 2.0
# This script operates on a directory, such as in data/train/,
# that contains some subset of the following files:
#  wav.scp
#  spk2utt
#  utt2spk
#  text
#
# It generates the files which are used for perturbing the speed of the original data.
# Byte-wise sorting so Kaldi-style data files stay consistently ordered.
export LC_ALL=C
set -euo pipefail
if [[ $# != 3 ]]; then
    echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
    echo "e.g.:"
    echo " $0 0.9 data/train_si284 data/train_si284p"
    exit 1
fi
factor=$1
srcdir=$2
destdir=$3
# Every utterance/speaker/recording id in destdir gets this prefix
# (e.g. "sp0.9-"), so perturbed ids never collide with the originals.
label="sp"
spk_prefix="${label}${factor}-"
utt_prefix="${label}${factor}-"
#check is sox on the path (the generated wav.scp pipes audio through sox)
! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;
if [[ ! -f ${srcdir}/utt2spk ]]; then
  echo "$0: no such file ${srcdir}/utt2spk"
  exit 1;
fi
if [[ ${destdir} == "${srcdir}" ]]; then
  echo "$0: this script requires <srcdir> and <destdir> to be different."
  exit 1
fi
mkdir -p "${destdir}"
# Build old-id -> new-id maps for utterances, speakers and recordings.
<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
# utt2uniq maps each perturbed utterance back to its original utterance so
# downstream tools can detect copies of the same audio.
if [[ ! -f ${srcdir}/utt2uniq ]]; then
    <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
else
    <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
fi
# Rewrite utt2spk with the new ids and regenerate spk2utt from it.
<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
  utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk
utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt
if [[ -f ${srcdir}/segments ]]; then
  # Segment start/end times scale by 1/factor; segments that would become
  # shorter than ~10ms after warping are dropped.
  utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
      utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
          awk -v factor="${factor}" \
            '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
            >"${destdir}"/segments
  utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
      # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
      awk -v factor="${factor}" \
          '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
            else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
            else  {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
             > "${destdir}"/wav.scp
  if [[ -f ${srcdir}/reco2file_and_channel ]]; then
      utils/apply_map.pl -f 1 "${destdir}"/reco_map \
       <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
  fi
else # no segments->wav indexed by utterance.
    if [[ -f ${srcdir}/wav.scp ]]; then
        utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
         # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
         awk -v factor="${factor}" \
           '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
             else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
             else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
                 > "${destdir}"/wav.scp
    fi
fi
# Carry over optional per-utterance / per-speaker files with remapped ids.
if [[ -f ${srcdir}/text ]]; then
    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
fi
if [[ -f ${srcdir}/spk2gender ]]; then
    utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
fi
if [[ -f ${srcdir}/utt2lang ]]; then
    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
fi
# The id maps are intermediate files; remove them once everything is written.
rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"
utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
egs/alimeeting/sa-asr/local/spk2utt_to_utt2spk.pl
egs/alimeeting/sa-asr/local/utt2spk_to_spk2utt.pl
egs/alimeeting/sa-asr/local/validate_data_dir.sh
File was renamed from egs/alimeeting/sa-asr/utils/validate_data_dir.sh
@@ -113,7 +113,7 @@
check_sorted_and_uniq $data/spk2utt
! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \
     <(utils/spk2utt_to_utt2spk.pl $data/spk2utt)  && \
     <(local/spk2utt_to_utt2spk.pl $data/spk2utt)  && \
   echo "$0: spk2utt and utt2spk do not seem to match" && exit 1;
cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts
@@ -135,7 +135,7 @@
    echo "$0: text contains $n_non_print lines with non-printable characters" &&\
    exit 1;
  fi
  utils/validate_text.pl $data/text || exit 1;
  local/validate_text.pl $data/text || exit 1;
  check_sorted_and_uniq $data/text
  text_len=`cat $data/text | wc -l`
  illegal_sym_list="<s> </s> #0"
egs/alimeeting/sa-asr/local/validate_text.pl
egs/alimeeting/sa-asr/path.sh
@@ -2,5 +2,4 @@
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
export PATH=$PWD/utils/:$PATH
export PATH=$FUNASR_DIR/funasr/bin:$PATH
egs/alimeeting/sa-asr/utils
New file
@@ -0,0 +1 @@
../../aishell/transformer/utils
egs/alimeeting/sa-asr/utils/filter_scp.pl
File was deleted
egs/alimeeting/sa-asr/utils/parse_options.sh
File was deleted
egs/alimeeting/sa-asr/utils/split_scp.pl
File was deleted
funasr/bin/asr_inference.py
@@ -40,6 +40,8 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
@@ -90,6 +92,12 @@
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            if asr_train_args.frontend=='wav_frontend':
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
            else:
                frontend_class=frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
@@ -197,12 +205,21 @@
        """
        assert check_argument_types()
        # Input as audio signal
        if isinstance(speech, np.ndarray):
            speech = torch.tensor(speech)
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
        batch = to_device(batch, device=self.device)
@@ -275,6 +292,7 @@
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        mc: bool = False,
        **kwargs,
):
    inference_pipeline = inference_modelscope(
@@ -305,6 +323,7 @@
        ngram_weight=ngram_weight,
        nbest=nbest,
        num_workers=num_workers,
        mc=mc,
        **kwargs,
    )
    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -337,6 +356,7 @@
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
):
@@ -406,7 +426,7 @@
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=True,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
@@ -415,7 +435,7 @@
            allow_variable_data_keys=allow_variable_data_keys,
            inference=True,
        )
        finish_count = 0
        file_count = 1
        # 7 .Start for-loop
funasr/bin/asr_inference_launch.py
@@ -71,7 +71,13 @@
    )
    group.add_argument("--key_file", type=str_or_none)
    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
    group.add_argument(
            "--mc",
            type=bool,
            default=False,
            help="MultiChannel input",
        )
    group = parser.add_argument_group("The model configuration related")
    group.add_argument(
        "--vad_infer_config",
funasr/bin/asr_train.py
@@ -2,14 +2,6 @@
import os
import logging
logging.basicConfig(
    level='INFO',
    format=f"[{os.uname()[1].split('.')[0]}]"
           f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
from funasr.tasks.asr import ASRTask
funasr/bin/sa_asr_inference.py
@@ -35,6 +35,8 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.asr import frontend_choices
header_colors = '\033[95m'
@@ -85,6 +87,12 @@
            asr_train_config, asr_model_file, cmvn_file, device
        )
        frontend = None
        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
            if asr_train_args.frontend=='wav_frontend':
                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval()
            else:
                frontend_class=frontend_choices.get_class(asr_train_args.frontend)
                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
        logging.info("asr_model: {}".format(asr_model))
        logging.info("asr_train_args: {}".format(asr_train_args))
@@ -201,7 +209,16 @@
        if isinstance(profile, np.ndarray):
            profile = torch.tensor(profile)
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, speech_lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
            self.asr_model.frontend = None
        else:
            feats = speech
            feats_len = speech_lengths
        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
        batch = to_device(batch, device=self.device)
@@ -308,6 +325,7 @@
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        mc: bool = False,
        **kwargs,
):
    inference_pipeline = inference_modelscope(
@@ -338,6 +356,7 @@
        ngram_weight=ngram_weight,
        nbest=nbest,
        num_workers=num_workers,
        mc=mc,
        **kwargs,
    )
    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -370,6 +389,7 @@
    ngram_weight: float = 0.9,
    nbest: int = 1,
    num_workers: int = 1,
    mc: bool = False,
    param_dict: dict = None,
    **kwargs,
):
@@ -437,7 +457,7 @@
            data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            mc=True,
            mc=mc,
            batch_size=batch_size,
            key_file=key_file,
            num_workers=num_workers,
funasr/bin/sa_asr_train.py
@@ -2,14 +2,6 @@
import os
import logging
logging.basicConfig(
    level='INFO',
    format=f"[{os.uname()[1].split('.')[0]}]"
           f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
from funasr.tasks.sa_asr import ASRTask
funasr/losses/label_smoothing_loss.py
@@ -79,3 +79,49 @@
        loss = self.criterion(pred, label)
        denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0]
        return loss.masked_fill(pad_mask, 0).sum() / denom
class NllLoss(nn.Module):
    """Negative log-likelihood loss with padding-aware masking.

    :param int size: the number of classes
    :param int padding_idx: target id that marks padding positions (ignored)
    :param bool normalize_length: if True, normalize the loss by the number
        of non-padding tokens; otherwise by the batch size
    :param torch.nn.Module criterion: loss function; when None a fresh
        ``nn.NLLLoss(reduction='none')`` is created per instance (avoids a
        single module instance being shared via a mutable default argument)
    """

    def __init__(
        self,
        size,
        padding_idx,
        normalize_length=False,
        criterion=None,
    ):
        """Construct an NllLoss object."""
        super(NllLoss, self).__init__()
        # Build the default criterion here rather than in the signature so
        # each NllLoss owns its own module (no shared mutable default).
        self.criterion = criterion if criterion is not None else nn.NLLLoss(reduction='none')
        self.padding_idx = padding_idx
        self.size = size
        self.true_dist = None  # kept for interface parity with LabelSmoothingLoss
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: log-probabilities (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            ignore = target == self.padding_idx  # (batch * seqlen,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid negative index
        loss = self.criterion(x, target)
        # Zero-out padding positions before reduction.
        denom = total if self.normalize_length else batch_size
        return loss.masked_fill(ignore, 0).sum() / denom
funasr/models/decoder/transformer_decoder.py
@@ -13,6 +13,7 @@
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.attention import CosineDistanceAttention
from funasr.modules.dynamic_conv import DynamicConvolution
from funasr.modules.dynamic_conv2d import DynamicConvolution2D
from funasr.modules.embedding import PositionalEncoding
@@ -763,4 +764,429 @@
                normalize_before,
                concat_after,
            ),
        )
        )
class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base class of the speaker-attributed (SA) ASR transformer decoder.

    The decoder runs two coupled branches over the same target sequence:

    * a speaker branch (``decoder1`` -> ``decoder2``) whose output is matched
      against the speaker profiles via cosine-distance attention, and
    * an ASR branch (``decoder3`` -> ``decoder4``) that consumes the weighted
      speaker profile produced by the speaker branch.

    The concrete layer stacks ``decoder1``..``decoder4`` are left as ``None``
    here and must be created by subclasses (see ``SAAsrTransformerDecoder``).

    Args:
        vocab_size: output vocabulary size.
        encoder_output_size: encoder output dimension (= attention dim).
        spker_embedding_dim: dimension of the speaker profile embeddings.
        dropout_rate: dropout rate of the "linear" input layer.
        positional_dropout_rate: dropout rate of the positional encoding.
        input_layer: "embed" (token embedding) or "linear" input projection.
        use_asr_output_layer: add a final linear projection to vocab_size.
        use_spk_output_layer: add a final projection to spker_embedding_dim.
        pos_enc_class: positional encoding class.
        normalize_before: use pre-norm (LayerNorm before each block).
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            # Shared final LayerNorm, applied to both branch outputs.
            self.after_norm = LayerNorm(attention_dim)
        if use_asr_output_layer:
            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.asr_output_layer = None
        if use_spk_output_layer:
            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
        else:
            self.spk_output_layer = None
        # Matches the speaker-branch output against the profile vectors.
        self.cos_distance_att = CosineDistanceAttention()
        # Layer stacks to be filled in by subclasses.
        self.decoder1 = None
        self.decoder2 = None
        self.decoder3 = None
        self.decoder4 = None

    def forward(
        self,
        asr_hs_pad: torch.Tensor,
        spk_hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward both branches for training.

        Args:
            asr_hs_pad: padded ASR encoder output (batch, maxlen_in, dim).
            spk_hs_pad: padded speaker encoder output (batch, maxlen_in, dim).
            hlens: encoder output lengths (batch,).
            ys_in_pad: padded input token ids (batch, maxlen_out).
            ys_in_lens: input token lengths (batch,).
            profile: speaker profile vectors (batch, n_spk, profile_dim).
            profile_lens: number of valid profiles per utterance (batch,).

        Returns:
            Tuple of (token logits, speaker attention weights, output lengths).
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m
        asr_memory = asr_hs_pad
        spk_memory = spk_hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
        # Spk decoder: the first layer also returns `z`, the pre-src-attention
        # activations that feed the ASR branch below.
        x = self.embed(tgt)
        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, memory_mask
        )
        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
            x, tgt_mask, spk_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        # dn: profile-weighted speaker embedding; weights: speaker posteriors.
        dn, weights = self.cos_distance_att(x, profile, profile_lens)
        # Asr decoder: starts from `z` (not from the speaker-branch output).
        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
            z, tgt_mask, asr_memory, memory_mask, dn
        )
        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
            x, tgt_mask, asr_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.asr_output_layer is not None:
            x = self.asr_output_layer(x)
        olens = tgt_mask.sum(1)
        return x, weights, olens

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        asr_memory: torch.Tensor,
        spk_memory: torch.Tensor,
        profile: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Incrementally decode one step (used during beam search).

        The cache is a flat list with one entry per layer, laid out as:
        [decoder1, decoder2 layers..., decoder3, decoder4 layers...],
        i.e. length 2 + len(decoder2) + len(decoder4).

        Returns:
            Tuple of (log-probabilities for the last position,
            speaker attention weights, updated cache list).
        """
        x = self.embed(tgt)
        if cache is None:
            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
        new_cache = []
        # cache[0]: shared first speaker layer.
        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
                x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
        )
        new_cache.append(x)
        # cache[1 : len(decoder2)+1]: remaining speaker-branch layers.
        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
            x, tgt_mask, spk_memory, _ = decoder(
                x, tgt_mask, spk_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            x = self.after_norm(x)
        else:
            x = x
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        dn, weights = self.cos_distance_att(x, profile, None)
        # cache[len(decoder2)+1]: first ASR-branch layer (consumes z and dn).
        x, tgt_mask, asr_memory, _ = self.decoder3(
            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
        )
        new_cache.append(x)
        # Remaining ASR-branch layers.
        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
            x, tgt_mask, asr_memory, _ = decoder(
                x, tgt_mask, asr_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.asr_output_layer is not None:
            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)
        return y, weights, new_cache

    def score(self, ys, state, asr_enc, spk_enc, profile):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
        logp, weights, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), weights.squeeze(), state
class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
    """SA-ASR transformer decoder with concrete layer stacks.

    Builds the four stacks consumed by the base class:

    * ``decoder1``: first layer shared by both branches, attending jointly
      over the ASR and speaker encoder memories.
    * ``decoder2``: remaining ``spk_num_blocks - 1`` standard decoder layers
      of the speaker branch.
    * ``decoder3``: first ASR-branch layer, which injects the weighted
      speaker profile (``dn``).
    * ``decoder4``: remaining ``asr_num_blocks - 1`` standard decoder layers.

    NOTE: the attribute assignment order below fixes the parameter
    registration order (state_dict layout); keep it stable for checkpoints.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        asr_num_blocks: int = 6,
        spk_num_blocks: int = 3,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            spker_embedding_dim=spker_embedding_dim,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_asr_output_layer=use_asr_output_layer,
            use_spk_output_layer=use_spk_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
            attention_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, self_attention_dropout_rate
            ),
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder2 = repeat(
            spk_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
            attention_dim,
            spker_embedding_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        self.decoder4 = repeat(
            asr_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
    """First (shared) layer of the SA-ASR decoder's speaker branch.

    Runs self-attention over the targets, then source attention where the
    ASR encoder memory provides the queries/keys pairing against the
    speaker encoder memory.  Besides the usual layer output it returns
    ``z``, the post-self-attention activations, which seed the ASR branch.

    Args:
        size: attention dimension.
        self_attn: self-attention module over the target sequence.
        src_attn: source-attention module (queried with x, keys = asr memory,
            values = spk memory).
        feed_forward: position-wise feed-forward module.
        dropout_rate: dropout applied to sub-block outputs.
        normalize_before: pre-norm (True) vs post-norm (False).
        concat_after: concatenate attention input/output and project instead
            of the residual-plus-dropout path.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct an DecoderLayer object."""
        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # NOTE(review): norm1 is reused for both the self-attention and the
        # source-attention pre-norms; only two LayerNorms exist — confirm
        # this sharing is intended.
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
        """Forward one decoder layer.

        Args:
            tgt: decoder input (batch, maxlen_out, size).
            tgt_mask: mask for tgt, or None.
            asr_memory: ASR encoder output (keys of src_attn).
            spk_memory: speaker encoder output (values of src_attn).
            memory_mask: mask for the encoder memory, or None.
            cache: cached outputs (batch, maxlen_out - 1, size); when given
                only the last frame is computed and appended to the cache.

        Returns:
            Tuple (x, tgt_mask, asr_memory, spk_memory, memory_mask, z) where
            z is the post-self-attention tensor used by the ASR branch.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]
        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)
        # z is handed to the ASR branch before the source attention is applied.
        z = x
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        # Cross attention: queries from x, keys from ASR memory, values from
        # speaker memory.
        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)
        if self.concat_after:
            x_concat = torch.cat(
                (x, skip), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(skip)
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
    """First layer of the ASR branch of the SA-ASR decoder.

    Unlike a standard decoder layer it has no self-attention: it consumes the
    shared first speaker-layer output ``z``, attends over the ASR encoder
    memory, and injects the weighted speaker profile ``dn`` (projected from
    ``d_size`` to ``size`` by ``spk_linear``) into the feed-forward input.

    Args:
        size: attention dimension.
        d_size: dimension of the speaker profile vectors.
        src_attn: source-attention module over the ASR encoder output.
        feed_forward: position-wise feed-forward module.
        dropout_rate: dropout applied to sub-block outputs.
        normalize_before: pre-norm (True) vs post-norm (False).
        concat_after: concatenate attention input/output and project instead
            of the residual-plus-dropout path.
    """

    def __init__(
        self,
        size,
        d_size,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeAsrDecoderFirstLayer object."""
        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
        self.size = size
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        # Projects the d_size-dim speaker vector into the attention space.
        self.spk_linear = nn.Linear(d_size, size, bias=False)
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
        """Run source attention plus speaker-profile injection.

        Args:
            tgt: decoder input (batch, maxlen_out, size).
            tgt_mask: mask for tgt, or None.
            memory: ASR encoder output (batch, maxlen_in, size).
            memory_mask: mask for the encoder memory, or None.
            dn: weighted speaker profile (batch, maxlen_out, d_size), or None
                to skip the injection.
            cache: cached outputs (batch, maxlen_out - 1, size); when given
                only the last frame is computed and appended to the cache.

        Returns:
            Tuple of (output, tgt_mask, memory, memory_mask).
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # Incremental decoding: compute only the newest query position.
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]
        x = tgt_q
        # NOTE(review): in the pre-norm path both norm1 and norm2 are applied
        # before src_attn — confirm the double normalization is intended.
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)
        residual = x
        # Fixed: was `if dn!=None:` — comparing a Tensor to None with `!=`
        # relies on Tensor.__ne__ semantics; use an identity check instead.
        if dn is not None:
            # Speaker information enters only the feed-forward input; the
            # residual path keeps the pre-injection activations.
            x = x + self.spk_linear(dn)
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, memory, memory_mask
funasr/models/e2e_sa_asr.py
@@ -16,9 +16,8 @@
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
    LabelSmoothingLoss, NllLoss  # noqa: H301
)
from funasr.losses.nll_loss import NllLoss
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
funasr/tasks/sa_asr.py
@@ -28,7 +28,7 @@
from funasr.models.decoder.transformer_decoder import (
    DynamicConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder
from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolution2DTransformerDecoder,  # noqa: H301