add speaker-attributed ASR task for alimeeting
16个文件已修改
4个文件已添加
12个文件已重命名
3个文件已删除
| | |
| | | log "Stage 2: Speed perturbation: data/${train_set} -> data/${train_set}_sp" |
| | | for factor in ${speed_perturb_factors}; do |
| | | if [[ $(bc <<<"${factor} != 1.0") == 1 ]]; then |
| | | scripts/utils/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}" |
| | | local/perturb_data_dir_speed.sh "${factor}" "data/${train_set}" "data/${train_set}_sp${factor}" |
| | | _dirs+="data/${train_set}_sp${factor} " |
| | | else |
| | | # If speed factor is 1, same as the original |
| | | _dirs+="data/${train_set} " |
| | | fi |
| | | done |
| | | utils/combine_data.sh "data/${train_set}_sp" ${_dirs} |
| | | local/combine_data.sh "data/${train_set}_sp" ${_dirs} |
| | | else |
| | | log "Skip stage 2: Speed perturbation" |
| | | fi |
| | |
| | | _suf="" |
| | | fi |
| | | fi |
| | | utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" |
| | | local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" |
| | | |
| | | cp data/"${dset}"/utt2spk_all_fifo "${data_feats}${_suf}/${dset}/" |
| | | |
| | |
| | | _opts+="--segments data/${dset}/segments " |
| | | fi |
| | | # shellcheck disable=SC2086 |
| | | scripts/audio/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ |
| | | local/format_wav_scp.sh --nj "${nj}" --cmd "${train_cmd}" \ |
| | | --audio-format "${audio_format}" --fs "${fs}" ${_opts} \ |
| | | "data/${dset}/wav.scp" "${data_feats}${_suf}/${dset}" |
| | | |
| | |
| | | for dset in $rm_dset; do |
| | | |
| | | # Copy data dir |
| | | utils/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}" |
| | | local/copy_data_dir.sh --validate_opts --non-print "${data_feats}/org/${dset}" "${data_feats}/${dset}" |
| | | cp "${data_feats}/org/${dset}/feats_type" "${data_feats}/${dset}/feats_type" |
| | | |
| | | # Remove short utterances |
| | |
| | | awk ' { if( NF != 1 ) print $0; } ' >"${data_feats}/${dset}/text" |
| | | |
| | | # fix_data_dir.sh leaves only utts which exist in all files |
| | | utils/fix_data_dir.sh "${data_feats}/${dset}" |
| | | local/fix_data_dir.sh "${data_feats}/${dset}" |
| | | |
| | | # generate uttid |
| | | cut -d ' ' -f 1 "${data_feats}/${dset}/wav.scp" > "${data_feats}/${dset}/uttid" |
| | |
| | | ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ |
| | | python -m funasr.bin.asr_inference_launch \ |
| | | --batch_size 1 \ |
| | | --mc True \ |
| | | --nbest 1 \ |
| | | --ngpu "${_ngpu}" \ |
| | | --njob ${njob_infer} \ |
| | |
| | | _data="${data_feats}/${dset}" |
| | | _dir="${asr_exp}/${inference_tag}/${dset}" |
| | | |
| | | python local/proce_text.py ${_data}/text ${_data}/text.proc |
| | | python local/proce_text.py ${_dir}/text ${_dir}/text.proc |
| | | python utils/proce_text.py ${_data}/text ${_data}/text.proc |
| | | python utils/proce_text.py ${_dir}/text ${_dir}/text.proc |
| | | |
| | | python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer |
| | | python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer |
| | | tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt |
| | | cat ${_dir}/text.cer.txt |
| | | |
| | |
| | | ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ |
| | | python -m funasr.bin.asr_inference_launch \ |
| | | --batch_size 1 \ |
| | | --mc True \ |
| | | --nbest 1 \ |
| | | --ngpu "${_ngpu}" \ |
| | | --njob ${njob_infer} \ |
| | |
| | | _data="${data_feats}/${dset}" |
| | | _dir="${sa_asr_exp}/${sa_asr_inference_tag}.oracle/${dset}" |
| | | |
| | | python local/proce_text.py ${_data}/text ${_data}/text.proc |
| | | python local/proce_text.py ${_dir}/text ${_dir}/text.proc |
| | | python utils/proce_text.py ${_data}/text ${_data}/text.proc |
| | | python utils/proce_text.py ${_dir}/text ${_dir}/text.proc |
| | | |
| | | python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer |
| | | python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer |
| | | tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt |
| | | cat ${_dir}/text.cer.txt |
| | | |
| | |
| | | ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ |
| | | python -m funasr.bin.asr_inference_launch \ |
| | | --batch_size 1 \ |
| | | --mc True \ |
| | | --nbest 1 \ |
| | | --ngpu "${_ngpu}" \ |
| | | --njob ${njob_infer} \ |
| | |
| | | _data="${data_feats}/${dset}" |
| | | _dir="${sa_asr_exp}/${sa_asr_inference_tag}.cluster/${dset}" |
| | | |
| | | python local/proce_text.py ${_data}/text ${_data}/text.proc |
| | | python local/proce_text.py ${_dir}/text ${_dir}/text.proc |
| | | python utils/proce_text.py ${_data}/text ${_data}/text.proc |
| | | python utils/proce_text.py ${_dir}/text ${_dir}/text.proc |
| | | |
| | | python local/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer |
| | | python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer |
| | | tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt |
| | | cat ${_dir}/text.cer.txt |
| | | |
| | |
| | | |
| | | _suf="" |
| | | |
| | | utils/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" |
| | | local/copy_data_dir.sh --validate_opts --non-print data/"${dset}" "${data_feats}${_suf}/${dset}" |
| | | |
| | | rm -f ${data_feats}${_suf}/${dset}/{segments,wav.scp,reco2file_and_channel,reco2dur} |
| | | _opts= |
| | |
| | | ${_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \ |
| | | python -m funasr.bin.asr_inference_launch \ |
| | | --batch_size 1 \ |
| | | --mc True \ |
| | | --nbest 1 \ |
| | | --ngpu "${_ngpu}" \ |
| | | --njob ${njob_infer} \ |
| | |
| | | n_fft: 400 |
| | | win_length: 400 |
| | | hop_length: 160 |
| | | use_channel: 0 |
| | | |
| | | # encoder related |
| | | encoder: conformer |
| | |
| | | n_fft: 400 |
| | | win_length: 400 |
| | | hop_length: 160 |
| | | use_channel: 0 |
| | | |
| | | # encoder related |
| | | asr_encoder: conformer |
| | |
| | | utils/filter_scp.pl -f 1 $near_dir/text $near_dir/utt2spk_all | sort -u > $near_dir/utt2spk |
| | | #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $near_dir/utt2spk_old >$near_dir/tmp1 |
| | | #sed -e 's/-[a-z,A-Z,0-9]\+$//' $near_dir/tmp1 | sort -u > $near_dir/utt2spk |
| | | utils/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt |
| | | local/utt2spk_to_spk2utt.pl $near_dir/utt2spk > $near_dir/spk2utt |
| | | utils/filter_scp.pl -f 1 $near_dir/text $near_dir/segments_all | sort -u > $near_dir/segments |
| | | sed -e 's/ $//g' $near_dir/text> $near_dir/tmp1 |
| | | sed -e 's/!//g' $near_dir/tmp1> $near_dir/tmp2 |
| | |
| | | utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk |
| | | #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk |
| | | |
| | | utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt |
| | | local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt |
| | | utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments |
| | | sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1 |
| | | sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2 |
| | |
| | | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then |
| | | log "stage 3: finali data process" |
| | | |
| | | utils/copy_data_dir.sh $near_dir data/${tgt}_Ali_near |
| | | utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far |
| | | local/copy_data_dir.sh $near_dir data/${tgt}_Ali_near |
| | | local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far |
| | | |
| | | sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo |
| | | sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo |
| | |
| | | python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir |
| | | |
| | | cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text |
| | | utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt |
| | | local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt |
| | | |
| | | ./utils/fix_data_dir.sh $far_single_speaker_dir |
| | | utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker |
| | | ./local/fix_data_dir.sh $far_single_speaker_dir |
| | | local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker |
| | | |
| | | # remove space in text |
| | | for x in ${tgt}_Ali_far_single_speaker; do |
| | |
| | | utils/filter_scp.pl -f 1 $far_dir/text $far_dir/utt2spk_all | sort -u > $far_dir/utt2spk |
| | | #sed -e 's/ [a-z,A-Z,_,0-9,-]\+SPK/ SPK/' $far_dir/utt2spk_old >$far_dir/utt2spk |
| | | |
| | | utils/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt |
| | | local/utt2spk_to_spk2utt.pl $far_dir/utt2spk > $far_dir/spk2utt |
| | | utils/filter_scp.pl -f 1 $far_dir/text $far_dir/segments_all | sort -u > $far_dir/segments |
| | | sed -e 's/SRC/$/g' $far_dir/text> $far_dir/tmp1 |
| | | sed -e 's/ $//g' $far_dir/tmp1> $far_dir/tmp2 |
| | |
| | | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then |
| | | log "stage 2: finali data process" |
| | | |
| | | utils/copy_data_dir.sh $far_dir data/${tgt}_Ali_far |
| | | local/copy_data_dir.sh $far_dir data/${tgt}_Ali_far |
| | | |
| | | sort $far_dir/utt2spk_all_fifo > data/${tgt}_Ali_far/utt2spk_all_fifo |
| | | sed -i "s/src/$/g" data/${tgt}_Ali_far/utt2spk_all_fifo |
| | |
| | | python local/process_textgrid_to_single_speaker_wav.py --path $far_single_speaker_dir |
| | | |
| | | cp $far_single_speaker_dir/utt2spk $far_single_speaker_dir/text |
| | | utils/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt |
| | | local/utt2spk_to_spk2utt.pl $far_single_speaker_dir/utt2spk > $far_single_speaker_dir/spk2utt |
| | | |
| | | ./utils/fix_data_dir.sh $far_single_speaker_dir |
| | | utils/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker |
| | | ./local/fix_data_dir.sh $far_single_speaker_dir |
| | | local/copy_data_dir.sh $far_single_speaker_dir data/${tgt}_Ali_far_single_speaker |
| | | |
| | | # remove space in text |
| | | for x in ${tgt}_Ali_far_single_speaker; do |
| File was renamed from egs/alimeeting/sa-asr/utils/combine_data.sh |
| | |
| | | for in_dir in $*; do |
| | | if [ ! -f $in_dir/segments ]; then |
| | | echo "$0 [info]: will generate missing segments for $in_dir" 1>&2 |
| | | utils/data/get_segments_for_data.sh $in_dir |
| | | local/data/get_segments_for_data.sh $in_dir |
| | | else |
| | | cat $in_dir/segments |
| | | fi |
| | |
| | | fi |
| | | done |
| | | |
| | | utils/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt |
| | | local/utt2spk_to_spk2utt.pl <$dest/utt2spk >$dest/spk2utt |
| | | |
| | | if [[ $dir_with_frame_shift ]]; then |
| | | cp $dir_with_frame_shift/frame_shift $dest |
| | | fi |
| | | |
| | | if ! $skip_fix ; then |
| | | utils/fix_data_dir.sh $dest || exit 1; |
| | | local/fix_data_dir.sh $dest || exit 1; |
| | | fi |
| | | |
| | | exit 0 |
| File was renamed from egs/alimeeting/sa-asr/utils/copy_data_dir.sh |
| | |
| | | cat $srcdir/utt2uniq | awk -v p=$utt_prefix -v s=$utt_suffix '{printf("%s%s%s %s\n", p, $1, s, $2);}' > $destdir/utt2uniq |
| | | fi |
| | | |
| | | cat $srcdir/utt2spk | utils/apply_map.pl -f 1 $destdir/utt_map | \ |
| | | utils/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk |
| | | cat $srcdir/utt2spk | local/apply_map.pl -f 1 $destdir/utt_map | \ |
| | | local/apply_map.pl -f 2 $destdir/spk_map >$destdir/utt2spk |
| | | |
| | | utils/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt |
| | | local/utt2spk_to_spk2utt.pl <$destdir/utt2spk >$destdir/spk2utt |
| | | |
| | | if [ -f $srcdir/feats.scp ]; then |
| | | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp |
| | | local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/feats.scp >$destdir/feats.scp |
| | | fi |
| | | |
| | | if [ -f $srcdir/vad.scp ]; then |
| | | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp |
| | | local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/vad.scp >$destdir/vad.scp |
| | | fi |
| | | |
| | | if [ -f $srcdir/segments ]; then |
| | | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments |
| | | local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/segments >$destdir/segments |
| | | cp $srcdir/wav.scp $destdir |
| | | else # no segments->wav indexed by utt. |
| | | if [ -f $srcdir/wav.scp ]; then |
| | | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp |
| | | local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/wav.scp >$destdir/wav.scp |
| | | fi |
| | | fi |
| | | |
| | |
| | | fi |
| | | |
| | | if [ -f $srcdir/text ]; then |
| | | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text |
| | | local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/text >$destdir/text |
| | | fi |
| | | if [ -f $srcdir/utt2dur ]; then |
| | | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur |
| | | local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2dur >$destdir/utt2dur |
| | | fi |
| | | if [ -f $srcdir/utt2num_frames ]; then |
| | | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames |
| | | local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/utt2num_frames >$destdir/utt2num_frames |
| | | fi |
| | | if [ -f $srcdir/reco2dur ]; then |
| | | if [ -f $srcdir/segments ]; then |
| | | cp $srcdir/reco2dur $destdir/reco2dur |
| | | else |
| | | utils/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur |
| | | local/apply_map.pl -f 1 $destdir/utt_map <$srcdir/reco2dur >$destdir/reco2dur |
| | | fi |
| | | fi |
| | | if [ -f $srcdir/spk2gender ]; then |
| | | utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender |
| | | local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/spk2gender >$destdir/spk2gender |
| | | fi |
| | | if [ -f $srcdir/cmvn.scp ]; then |
| | | utils/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp |
| | | local/apply_map.pl -f 1 $destdir/spk_map <$srcdir/cmvn.scp >$destdir/cmvn.scp |
| | | fi |
| | | for f in frame_shift stm glm ctm; do |
| | | if [ -f $srcdir/$f ]; then |
| | |
| | | [ ! -f $srcdir/feats.scp ] && validate_opts="$validate_opts --no-feats" |
| | | [ ! -f $srcdir/text ] && validate_opts="$validate_opts --no-text" |
| | | |
| | | utils/validate_data_dir.sh $validate_opts $destdir |
| | | local/validate_data_dir.sh $validate_opts $destdir |
| File was renamed from egs/alimeeting/sa-asr/utils/data/get_segments_for_data.sh |
| | |
| | | data=$1 |
| | | |
| | | if [ ! -s $data/utt2dur ]; then |
| | | utils/data/get_utt2dur.sh $data 1>&2 || exit 1; |
| | | local/data/get_utt2dur.sh $data 1>&2 || exit 1; |
| | | fi |
| | | |
| | | # <utt-id> <utt-id> 0 <utt-dur> |
| File was renamed from egs/alimeeting/sa-asr/utils/data/get_utt2dur.sh |
| | |
| | | nj=$num_utts |
| | | fi |
| | | |
| | | utils/data/split_data.sh --per-utt $data $nj |
| | | local/data/split_data.sh --per-utt $data $nj |
| | | sdata=$data/split${nj}utt |
| | | |
| | | $cmd JOB=1:$nj $data/log/get_durations.JOB.log \ |
| File was renamed from egs/alimeeting/sa-asr/utils/data/split_data.sh |
| | |
| | | nt=`cat $data/text 2>/dev/null | wc -l` # take it as zero if no such file |
| | | if [ -f $data/feats.scp ] && [ $nu -ne $nf ]; then |
| | | echo "** split_data.sh: warning, #lines is (utt2spk,feats.scp) is ($nu,$nf); you can " |
| | | echo "** use utils/fix_data_dir.sh $data to fix this." |
| | | echo "** use local/fix_data_dir.sh $data to fix this." |
| | | fi |
| | | if [ -f $data/text ] && [ $nu -ne $nt ]; then |
| | | echo "** split_data.sh: warning, #lines is (utt2spk,text) is ($nu,$nt); you can " |
| | | echo "** use utils/fix_data_dir.sh to fix this." |
| | | echo "** use local/fix_data_dir.sh to fix this." |
| | | fi |
| | | |
| | | |
| | |
| | | |
| | | for n in `seq $numsplit`; do |
| | | dsn=$data/split${numsplit}${utt}/$n |
| | | utils/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1; |
| | | local/utt2spk_to_spk2utt.pl $dsn/utt2spk > $dsn/spk2utt || exit 1; |
| | | done |
| | | |
| | | maybe_wav_scp= |
| File was renamed from egs/alimeeting/sa-asr/utils/fix_data_dir.sh |
| | |
| | | |
| | | function filter_speakers { |
| | | # throughout this program, we regard utt2spk as primary and spk2utt as derived, so... |
| | | utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt |
| | | local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt |
| | | |
| | | cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers |
| | | for s in cmvn.scp spk2gender; do |
| | |
| | | done |
| | | |
| | | filter_file $tmpdir/speakers $data/spk2utt |
| | | utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk |
| | | local/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk |
| | | |
| | | for s in cmvn.scp spk2gender $spk_extra_files; do |
| | | f=$data/$s |
| | |
| | | filter_speakers |
| | | filter_recordings |
| | | |
| | | utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt |
| | | local/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt |
| | | |
| | | echo "fix_data_dir.sh: old files are kept in $data/.backup" |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | import argparse |
| | | import logging |
| | | from io import BytesIO |
| | | from pathlib import Path |
| | | from typing import Tuple, Optional |
| | | |
| | | import kaldiio |
| | | import humanfriendly |
| | | import numpy as np |
| | | import resampy |
| | | import soundfile |
| | | from tqdm import tqdm |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.fileio.read_text import read_2column_text |
| | | from funasr.fileio.sound_scp import SoundScpWriter |
| | | |
| | | |
def humanfriendly_or_none(value: str):
    """Parse a human-friendly size string (e.g. "16k") via humanfriendly.

    The exact spellings "none", "None" and "NONE" are mapped to None so
    the corresponding command-line option can be disabled explicitly.
    """
    none_aliases = ("none", "None", "NONE")
    return None if value in none_aliases else humanfriendly.parse_size(value)
| | | |
| | | |
def str2int_tuple(integers: str) -> Optional[Tuple[int, ...]]:
    """Parse a comma-separated string of integers into a tuple.

    >>> str2int_tuple('3,4,5')
    (3, 4, 5)

    The exact spellings none/None/NONE/null/Null/NULL yield None.
    """
    assert check_argument_types()
    stripped = integers.strip()
    if stripped in ("none", "None", "NONE", "null", "Null", "NULL"):
        return None
    return tuple(int(token) for token in stripped.split(","))
| | | |
| | | |
def main():
    """Convert a Kaldi-style 'wav.scp' (optionally with 'segments') to formatted audio.

    For each entry in the input scp this script optionally selects reference
    channels, optionally resamples to ``--fs``, and writes either per-utterance
    audio files (via SoundScpWriter / soundfile) or a single extended-ark file
    (when ``--audio-format`` ends with "ark").  It also emits a rewritten
    "<name>.scp" and an "utt2num_samples" file under ``outdir``.
    """
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=logfmt)
    logging.info(get_commandline_args())

    parser = argparse.ArgumentParser(
        description='Create waves list from "wav.scp"',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("scp")
    parser.add_argument("outdir")
    parser.add_argument(
        "--name",
        default="wav",
        help="Specify the prefix word of output file name " 'such as "wav.scp"',
    )
    parser.add_argument("--segments", default=None)
    parser.add_argument(
        "--fs",
        type=humanfriendly_or_none,
        default=None,
        help="If the sampling rate specified, " "Change the sampling rate.",
    )
    parser.add_argument("--audio-format", default="wav")
    # --ref-channels (one fixed tuple for all utts) and --utt2ref-channels
    # (per-utterance mapping file) are mutually exclusive.
    group = parser.add_mutually_exclusive_group()
    group.add_argument("--ref-channels", default=None, type=str2int_tuple)
    group.add_argument("--utt2ref-channels", default=None, type=str)
    args = parser.parse_args()

    out_num_samples = Path(args.outdir) / f"utt2num_samples"

    # Build an uttid -> channel-tuple resolver, or None when no channel
    # selection was requested.
    if args.ref_channels is not None:

        def utt2ref_channels(x) -> Tuple[int, ...]:
            return args.ref_channels

    elif args.utt2ref_channels is not None:
        utt2ref_channels_dict = read_2column_text(args.utt2ref_channels)

        def utt2ref_channels(x, d=utt2ref_channels_dict) -> Tuple[int, ...]:
            chs_str = d[x]
            return tuple(map(int, chs_str.split()))

    else:
        utt2ref_channels = None

    Path(args.outdir).mkdir(parents=True, exist_ok=True)
    out_wavscp = Path(args.outdir) / f"{args.name}.scp"
    if args.segments is not None:
        # With a 'segments' file, kaldiio cuts each segment out of its
        # recording while reading sequentially.
        # Note: kaldiio supports only wav-pcm-int16le file.
        loader = kaldiio.load_scp_sequential(args.scp, segments=args.segments)
        if args.audio_format.endswith("ark"):
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
            fscp = out_wavscp.open("w")
        else:
            writer = SoundScpWriter(
                args.outdir,
                out_wavscp,
                format=args.audio_format,
            )

        with out_num_samples.open("w") as fnum_samples:
            for uttid, (rate, wave) in tqdm(loader):
                # wave: (Time,) or (Time, Nmic)
                if wave.ndim == 2 and utt2ref_channels is not None:
                    wave = wave[:, utt2ref_channels(uttid)]

                if args.fs is not None and args.fs != rate:
                    # FIXME(kamo): To use sox?
                    wave = resampy.resample(
                        wave.astype(np.float64), rate, args.fs, axis=0
                    )
                    # resampy works in float64; cast back to 16-bit PCM.
                    wave = wave.astype(np.int16)
                    rate = args.fs
                if args.audio_format.endswith("ark"):
                    if "flac" in args.audio_format:
                        suf = "flac"
                    elif "wav" in args.audio_format:
                        suf = "wav"
                    else:
                        raise RuntimeError("wav.ark or flac")

                    # NOTE(kamo): Using extended ark format style here.
                    # This format is incompatible with Kaldi
                    kaldiio.save_ark(
                        fark,
                        {uttid: (wave, rate)},
                        scp=fscp,
                        append=True,
                        write_function=f"soundfile_{suf}",
                    )

                else:
                    writer[uttid] = rate, wave
                fnum_samples.write(f"{uttid} {len(wave)}\n")
    else:
        # No 'segments': iterate the scp directly; each entry is either a
        # unix-pipe command (trailing "|") or a plain file path.
        if args.audio_format.endswith("ark"):
            fark = open(Path(args.outdir) / f"data_{args.name}.ark", "wb")
        else:
            wavdir = Path(args.outdir) / f"data_{args.name}"
            wavdir.mkdir(parents=True, exist_ok=True)

        with Path(args.scp).open("r") as fscp, out_wavscp.open(
            "w"
        ) as fout, out_num_samples.open("w") as fnum_samples:
            for line in tqdm(fscp):
                uttid, wavpath = line.strip().split(None, 1)

                if wavpath.endswith("|"):
                    # Streaming input e.g. cat a.wav |
                    with kaldiio.open_like_kaldi(wavpath, "rb") as f:
                        with BytesIO(f.read()) as g:
                            wave, rate = soundfile.read(g, dtype=np.int16)
                            if wave.ndim == 2 and utt2ref_channels is not None:
                                wave = wave[:, utt2ref_channels(uttid)]

                    if args.fs is not None and args.fs != rate:
                        # FIXME(kamo): To use sox?
                        wave = resampy.resample(
                            wave.astype(np.float64), rate, args.fs, axis=0
                        )
                        wave = wave.astype(np.int16)
                        rate = args.fs

                    if args.audio_format.endswith("ark"):
                        if "flac" in args.audio_format:
                            suf = "flac"
                        elif "wav" in args.audio_format:
                            suf = "wav"
                        else:
                            raise RuntimeError("wav.ark or flac")

                        # NOTE(kamo): Using extended ark format style here.
                        # This format is incompatible with Kaldi
                        kaldiio.save_ark(
                            fark,
                            {uttid: (wave, rate)},
                            scp=fout,
                            append=True,
                            write_function=f"soundfile_{suf}",
                        )
                    else:
                        owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                        soundfile.write(owavpath, wave, rate)
                        fout.write(f"{uttid} {owavpath}\n")
                else:
                    wave, rate = soundfile.read(wavpath, dtype=np.int16)
                    # Decide whether the original file can be referenced as-is
                    # (no channel selection, no resampling, same container).
                    if wave.ndim == 2 and utt2ref_channels is not None:
                        wave = wave[:, utt2ref_channels(uttid)]
                        save_asis = False

                    elif args.audio_format.endswith("ark"):
                        save_asis = False

                    elif Path(wavpath).suffix == "." + args.audio_format and (
                        args.fs is None or args.fs == rate
                    ):
                        save_asis = True

                    else:
                        save_asis = False

                    if save_asis:
                        # Neither --segments nor --fs are specified and
                        # the line doesn't end with "|",
                        # i.e. not using unix-pipe,
                        # only in this case,
                        # just using the original file as is.
                        fout.write(f"{uttid} {wavpath}\n")
                    else:
                        if args.fs is not None and args.fs != rate:
                            # FIXME(kamo): To use sox?
                            wave = resampy.resample(
                                wave.astype(np.float64), rate, args.fs, axis=0
                            )
                            wave = wave.astype(np.int16)
                            rate = args.fs

                        if args.audio_format.endswith("ark"):
                            if "flac" in args.audio_format:
                                suf = "flac"
                            elif "wav" in args.audio_format:
                                suf = "wav"
                            else:
                                raise RuntimeError("wav.ark or flac")

                            # NOTE(kamo): Using extended ark format style here.
                            # This format is not supported in Kaldi.
                            kaldiio.save_ark(
                                fark,
                                {uttid: (wave, rate)},
                                scp=fout,
                                append=True,
                                write_function=f"soundfile_{suf}",
                            )
                        else:
                            owavpath = str(wavdir / f"{uttid}.{args.audio_format}")
                            soundfile.write(owavpath, wave, rate)
                            fout.write(f"{uttid} {owavpath}\n")
                fnum_samples.write(f"{uttid} {len(wave)}\n")
| | | |
| | | |
| | | if __name__ == "__main__": |
| | | main() |
| New file |
| | |
#!/usr/bin/env bash

# Format a Kaldi-style 'wav.scp' (optionally cut by 'segments') into real
# audio files plus a rewritten scp, splitting the work across ${nj} parallel
# jobs that each invoke local/format_wav_scp.py.

set -euo pipefail
SECONDS=0

# log <msg...>: print a timestamped message tagged with caller script/line.
log() {
    local fname=${BASH_SOURCE[1]##*/}
    echo -e "$(date '+%Y-%m-%dT%H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
help_message=$(cat << EOF
Usage: $0 <in-wav.scp> <out-datadir> [<logdir> [<outdir>]]
e.g.
$0 data/test/wav.scp data/test_format/

Format 'wav.scp': In short words,
changing "kaldi-datadir" to "modified-kaldi-datadir"

The 'wav.scp' format in kaldi is very flexible,
e.g. It can use unix-pipe as describing that wav file,
but it sometime looks confusing and make scripts more complex.
This tools creates actual wav files from 'wav.scp'
and also segments wav files using 'segments'.

Options
--fs <fs>
--segments <segments>
--nj <nj>
--cmd <cmd>
EOF
)

# Default option values; overridable via utils/parse_options.sh.
out_filename=wav.scp
cmd=utils/run.pl
nj=30
fs=none
segments=

ref_channels=
utt2ref_channels=

audio_format=wav
write_utt2num_samples=true

log "$0 $*"
. utils/parse_options.sh

if [ $# -ne 2 ] && [ $# -ne 3 ] && [ $# -ne 4 ]; then
    log "${help_message}"
    log "Error: invalid command line arguments"
    exit 1
fi

. ./path.sh # Setup the environment

scp=$1
if [ ! -f "${scp}" ]; then
    log "${help_message}"
    echo "$0: Error: No such file: ${scp}"
    exit 1
fi
dir=$2

# Optional 3rd/4th positional arguments override the default log/output dirs.
if [ $# -eq 2 ]; then
    logdir=${dir}/logs
    outdir=${dir}/data

elif [ $# -eq 3 ]; then
    logdir=$3
    outdir=${dir}/data

elif [ $# -eq 4 ]; then
    logdir=$3
    outdir=$4
fi


mkdir -p "${logdir}"

# Remove any stale output scp so the final concatenation starts clean.
rm -f "${dir}/${out_filename}"


# Channel-selection options are forwarded verbatim to the python helper.
opts=
if [ -n "${utt2ref_channels}" ]; then
    opts="--utt2ref-channels ${utt2ref_channels} "
elif [ -n "${ref_channels}" ]; then
    opts="--ref-channels ${ref_channels} "
fi


if [ -n "${segments}" ]; then
    log "[info]: using ${segments}"
    # Never launch more jobs than there are segment lines.
    nutt=$(<"${segments}" wc -l)
    nj=$((nj<nutt?nj:nutt))

    split_segments=""
    for n in $(seq ${nj}); do
        split_segments="${split_segments} ${logdir}/segments.${n}"
    done

    utils/split_scp.pl "${segments}" ${split_segments}

    # Fix: the python option is spelled "--segments"; the previous
    # "--segment=" relied on argparse prefix abbreviation to work.
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        local/format_wav_scp.py \
            ${opts} \
            --fs "${fs}" \
            --audio-format "${audio_format}" \
            --segments "${logdir}/segments.JOB" \
            "${scp}" "${outdir}/format.JOB"

else
    log "[info]: without segments"
    nutt=$(<"${scp}" wc -l)
    nj=$((nj<nutt?nj:nutt))

    split_scps=""
    for n in $(seq ${nj}); do
        split_scps="${split_scps} ${logdir}/wav.${n}.scp"
    done

    utils/split_scp.pl "${scp}" ${split_scps}
    # Fix: quote the output-dir argument properly (was: ${outdir}/format.JOB"").
    ${cmd} "JOB=1:${nj}" "${logdir}/format_wav_scp.JOB.log" \
        local/format_wav_scp.py \
            ${opts} \
            --fs "${fs}" \
            --audio-format "${audio_format}" \
            "${logdir}/wav.JOB.scp" "${outdir}/format.JOB"
fi

# Workaround for the NFS problem
ls ${outdir}/format.* > /dev/null

# concatenate the .scp files together.
for n in $(seq ${nj}); do
    cat "${outdir}/format.${n}/wav.scp" || exit 1;
done > "${dir}/${out_filename}" || exit 1

if "${write_utt2num_samples}"; then
    for n in $(seq ${nj}); do
        cat "${outdir}/format.${n}/utt2num_samples" || exit 1;
    done > "${dir}/utt2num_samples" || exit 1
fi

log "Successfully finished. [elapsed=${SECONDS}s]"
| New file |
| | |
| | | #!/usr/bin/env bash |
| | | |
| | | # 2020 @kamo-naoyuki |
| | | # This file was copied from Kaldi and |
| | | # I deleted parts related to wav duration |
| | | # because we shouldn't use kaldi's command here |
| | | # and we don't need the files actually. |
| | | |
# Copyright 2013 Johns Hopkins University (author: Daniel Povey)
# 2014 Tom Ko
# 2018 Emotech LTD (author: Pawel Swietojanski)
# Apache 2.0

# This script operates on a directory, such as in data/train/,
# that contains some subset of the following files:
# wav.scp
# spk2utt
# utt2spk
# text
#
# It generates the files which are used for perturbing the speed of the original data.
# All utterance/speaker/recording ids in the output directory get an
# "sp<factor>-" prefix, audio is piped through "sox ... speed <factor>",
# and segment timestamps are rescaled by 1/<factor>.

export LC_ALL=C   # byte-wise sorting, as required by Kaldi-style data dirs
set -euo pipefail

if [[ $# != 3 ]]; then
  echo "Usage: perturb_data_dir_speed.sh <warping-factor> <srcdir> <destdir>"
  echo "e.g.:"
  echo " $0 0.9 data/train_si284 data/train_si284p"
  exit 1
fi

factor=$1
srcdir=$2
destdir=$3
label="sp"
# Speakers, utterances and recordings share the same prefix so that the
# perturbed copies can never collide with the original ids.
spk_prefix="${label}${factor}-"
utt_prefix="${label}${factor}-"

# Check that sox is available on PATH.

! command -v sox &>/dev/null && echo "sox: command not found" && exit 1;

if [[ ! -f ${srcdir}/utt2spk ]]; then
  echo "$0: no such file ${srcdir}/utt2spk"
  exit 1;
fi

if [[ ${destdir} == "${srcdir}" ]]; then
  echo "$0: this script requires <srcdir> and <destdir> to be different."
  exit 1
fi

mkdir -p "${destdir}"

# Build old-id -> new-id maps for utterances, speakers and recordings.
# NOTE(review): wav.scp is read unconditionally here although its use is
# guarded further below; under `set -e` a missing wav.scp aborts the script.
<"${srcdir}"/utt2spk awk -v p="${utt_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/utt_map"
<"${srcdir}"/spk2utt awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/spk_map"
<"${srcdir}"/wav.scp awk -v p="${spk_prefix}" '{printf("%s %s%s\n", $1, p, $1);}' > "${destdir}/reco_map"
# utt2uniq maps each (possibly already-perturbed) utterance back to its
# original utterance id, so downstream tools can recognize duplicates.
if [[ ! -f ${srcdir}/utt2uniq ]]; then
  <"${srcdir}/utt2spk" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $1);}' > "${destdir}/utt2uniq"
else
  <"${srcdir}/utt2uniq" awk -v p="${utt_prefix}" '{printf("%s%s %s\n", p, $1, $2);}' > "${destdir}/utt2uniq"
fi


# Rewrite utt2spk with the prefixed utterance and speaker ids.
<"${srcdir}"/utt2spk utils/apply_map.pl -f 1 "${destdir}"/utt_map | \
  utils/apply_map.pl -f 2 "${destdir}"/spk_map >"${destdir}"/utt2spk

utils/utt2spk_to_spk2utt.pl <"${destdir}"/utt2spk >"${destdir}"/spk2utt

if [[ -f ${srcdir}/segments ]]; then

  # Rescale segment start/end times by 1/factor and drop any segment that
  # would become shorter than 10 ms after perturbation.
  utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/segments | \
    utils/apply_map.pl -f 2 "${destdir}"/reco_map | \
    awk -v factor="${factor}" \
      '{s=$3/factor; e=$4/factor; if (e > s + 0.01) { printf("%s %s %.2f %.2f\n", $1, $2, $3/factor, $4/factor);} }' \
      >"${destdir}"/segments

  # The sed call normalizes a trailing pipe so piped rxfilenames end in " |".
  utils/apply_map.pl -f 1 "${destdir}"/reco_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
  # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
  awk -v factor="${factor}" \
    '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
    else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
    else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
    > "${destdir}"/wav.scp
  if [[ -f ${srcdir}/reco2file_and_channel ]]; then
    utils/apply_map.pl -f 1 "${destdir}"/reco_map \
      <"${srcdir}"/reco2file_and_channel >"${destdir}"/reco2file_and_channel
  fi

else # no segments->wav indexed by utterance.
  if [[ -f ${srcdir}/wav.scp ]]; then
    # Same three-case rxfilename handling as above, keyed by utterance id.
    utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/wav.scp | sed 's/| *$/ |/' | \
    # Handle three cases of rxfilenames appropriately; "input piped command", "file offset" and "filename"
    awk -v factor="${factor}" \
      '{wid=$1; $1=""; if ($NF=="|") {print wid $_ " sox -t wav - -t wav - speed " factor " |"}
      else if (match($0, /:[0-9]+$/)) {print wid " wav-copy" $_ " - | sox -t wav - -t wav - speed " factor " |" }
      else {print wid " sox" $_ " -t wav - speed " factor " |"}}' \
      > "${destdir}"/wav.scp
  fi
fi

# Carry over the optional per-utterance / per-speaker metadata files.
if [[ -f ${srcdir}/text ]]; then
  utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/text >"${destdir}"/text
fi
if [[ -f ${srcdir}/spk2gender ]]; then
  utils/apply_map.pl -f 1 "${destdir}"/spk_map <"${srcdir}"/spk2gender >"${destdir}"/spk2gender
fi
if [[ -f ${srcdir}/utt2lang ]]; then
  utils/apply_map.pl -f 1 "${destdir}"/utt_map <"${srcdir}"/utt2lang >"${destdir}"/utt2lang
fi

# Clean up the temporary id maps (all three were created above).
rm "${destdir}"/spk_map "${destdir}"/utt_map "${destdir}"/reco_map 2>/dev/null
echo "$0: generated speed-perturbed version of data in ${srcdir}, in ${destdir}"

# Features and text are optional at this point, so skip those checks.
utils/validate_data_dir.sh --no-feats --no-text "${destdir}"
| File was renamed from egs/alimeeting/sa-asr/utils/validate_data_dir.sh |
| | |
| | | check_sorted_and_uniq $data/spk2utt |
| | | |
| | | ! cmp -s <(cat $data/utt2spk | awk '{print $1, $2;}') \ |
| | | <(utils/spk2utt_to_utt2spk.pl $data/spk2utt) && \ |
| | | <(local/spk2utt_to_utt2spk.pl $data/spk2utt) && \ |
| | | echo "$0: spk2utt and utt2spk do not seem to match" && exit 1; |
| | | |
| | | cat $data/utt2spk | awk '{print $1;}' > $tmpdir/utts |
| | |
| | | echo "$0: text contains $n_non_print lines with non-printable characters" &&\ |
| | | exit 1; |
| | | fi |
| | | utils/validate_text.pl $data/text || exit 1; |
| | | local/validate_text.pl $data/text || exit 1; |
| | | check_sorted_and_uniq $data/text |
| | | text_len=`cat $data/text | wc -l` |
| | | illegal_sym_list="<s> </s> #0" |
| | |
| | | |
| | | # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C |
| | | export PYTHONIOENCODING=UTF-8 |
| | | export PATH=$FUNASR_DIR/funasr/bin:$PATH |
| | | export PATH=$PWD/utils/:$PATH |
| | | export PATH=$FUNASR_DIR/funasr/bin:$PATH |
| New file |
| | |
| | | ../../aishell/transformer/utils |
| | |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils import asr_utils, wav_utils, postprocess_utils |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.tasks.asr import frontend_choices |
| | | |
| | | |
| | | header_colors = '\033[95m' |
| | |
| | | asr_train_config, asr_model_file, cmvn_file, device |
| | | ) |
| | | frontend = None |
| | | if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: |
| | | if asr_train_args.frontend=='wav_frontend': |
| | | frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval() |
| | | else: |
| | | frontend_class=frontend_choices.get_class(asr_train_args.frontend) |
| | | frontend = frontend_class(**asr_train_args.frontend_conf).eval() |
| | | |
| | | logging.info("asr_model: {}".format(asr_model)) |
| | | logging.info("asr_train_args: {}".format(asr_train_args)) |
| | |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | |
| | | # Input as audio signal |
| | | if isinstance(speech, np.ndarray): |
| | | speech = torch.tensor(speech) |
| | | |
| | | batch = {"speech": speech, "speech_lengths": speech_lengths} |
| | | if self.frontend is not None: |
| | | feats, feats_len = self.frontend.forward(speech, speech_lengths) |
| | | feats = to_device(feats, device=self.device) |
| | | feats_len = feats_len.int() |
| | | self.asr_model.frontend = None |
| | | else: |
| | | feats = speech |
| | | feats_len = speech_lengths |
| | | lfr_factor = max(1, (feats.size()[-1] // 80) - 1) |
| | | batch = {"speech": feats, "speech_lengths": feats_len} |
| | | |
| | | # a. To device |
| | | batch = to_device(batch, device=self.device) |
| | |
| | | ngram_weight: float = 0.9, |
| | | nbest: int = 1, |
| | | num_workers: int = 1, |
| | | mc: bool = False, |
| | | **kwargs, |
| | | ): |
| | | inference_pipeline = inference_modelscope( |
| | |
| | | ngram_weight=ngram_weight, |
| | | nbest=nbest, |
| | | num_workers=num_workers, |
| | | mc=mc, |
| | | **kwargs, |
| | | ) |
| | | return inference_pipeline(data_path_and_name_and_type, raw_inputs) |
| | |
| | | ngram_weight: float = 0.9, |
| | | nbest: int = 1, |
| | | num_workers: int = 1, |
| | | mc: bool = False, |
| | | param_dict: dict = None, |
| | | **kwargs, |
| | | ): |
| | |
| | | data_path_and_name_and_type, |
| | | dtype=dtype, |
| | | fs=fs, |
| | | mc=True, |
| | | mc=mc, |
| | | batch_size=batch_size, |
| | | key_file=key_file, |
| | | num_workers=num_workers, |
| | |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | ) |
| | | |
| | | |
| | | finish_count = 0 |
| | | file_count = 1 |
| | | # 7 .Start for-loop |
| | |
| | | ) |
| | | group.add_argument("--key_file", type=str_or_none) |
| | | group.add_argument("--allow_variable_data_keys", type=str2bool, default=False) |
| | | |
| | | group.add_argument( |
| | | "--mc", |
| | | type=bool, |
| | | default=False, |
| | | help="MultiChannel input", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("The model configuration related") |
| | | group.add_argument( |
| | | "--vad_infer_config", |
| | |
| | | |
| | | import os |
| | | |
| | | import logging |
| | | |
| | | logging.basicConfig( |
| | | level='INFO', |
| | | format=f"[{os.uname()[1].split('.')[0]}]" |
| | | f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | from funasr.tasks.asr import ASRTask |
| | | |
| | | |
| | |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils import asr_utils, wav_utils, postprocess_utils |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.tasks.asr import frontend_choices |
| | | |
| | | |
| | | header_colors = '\033[95m' |
| | |
| | | asr_train_config, asr_model_file, cmvn_file, device |
| | | ) |
| | | frontend = None |
| | | if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: |
| | | if asr_train_args.frontend=='wav_frontend': |
| | | frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf).eval() |
| | | else: |
| | | frontend_class=frontend_choices.get_class(asr_train_args.frontend) |
| | | frontend = frontend_class(**asr_train_args.frontend_conf).eval() |
| | | |
| | | logging.info("asr_model: {}".format(asr_model)) |
| | | logging.info("asr_train_args: {}".format(asr_train_args)) |
| | |
| | | if isinstance(profile, np.ndarray): |
| | | profile = torch.tensor(profile) |
| | | |
| | | batch = {"speech": speech, "speech_lengths": speech_lengths} |
| | | if self.frontend is not None: |
| | | feats, feats_len = self.frontend.forward(speech, speech_lengths) |
| | | feats = to_device(feats, device=self.device) |
| | | feats_len = feats_len.int() |
| | | self.asr_model.frontend = None |
| | | else: |
| | | feats = speech |
| | | feats_len = speech_lengths |
| | | lfr_factor = max(1, (feats.size()[-1] // 80) - 1) |
| | | batch = {"speech": feats, "speech_lengths": feats_len} |
| | | |
| | | # a. To device |
| | | batch = to_device(batch, device=self.device) |
| | |
| | | ngram_weight: float = 0.9, |
| | | nbest: int = 1, |
| | | num_workers: int = 1, |
| | | mc: bool = False, |
| | | **kwargs, |
| | | ): |
| | | inference_pipeline = inference_modelscope( |
| | |
| | | ngram_weight=ngram_weight, |
| | | nbest=nbest, |
| | | num_workers=num_workers, |
| | | mc=mc, |
| | | **kwargs, |
| | | ) |
| | | return inference_pipeline(data_path_and_name_and_type, raw_inputs) |
| | |
| | | ngram_weight: float = 0.9, |
| | | nbest: int = 1, |
| | | num_workers: int = 1, |
| | | mc: bool = False, |
| | | param_dict: dict = None, |
| | | **kwargs, |
| | | ): |
| | |
| | | data_path_and_name_and_type, |
| | | dtype=dtype, |
| | | fs=fs, |
| | | mc=True, |
| | | mc=mc, |
| | | batch_size=batch_size, |
| | | key_file=key_file, |
| | | num_workers=num_workers, |
| | |
| | | |
| | | import os |
| | | |
| | | import logging |
| | | |
| | | logging.basicConfig( |
| | | level='INFO', |
| | | format=f"[{os.uname()[1].split('.')[0]}]" |
| | | f" %(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | from funasr.tasks.sa_asr import ASRTask |
| | | |
| | | |
| | |
| | | loss = self.criterion(pred, label) |
| | | denom = (~pad_mask).sum() if self.normalize_length else pred.shape[0] |
| | | return loss.masked_fill(pad_mask, 0).sum() / denom |
| | | |
| | | |
class NllLoss(nn.Module):
    """Negative log-likelihood loss with padding handling.

    Expects log-probabilities as the prediction input (the default criterion
    is ``nn.NLLLoss``), masks out targets equal to ``padding_idx``, and
    normalizes the summed loss either by the number of non-padded targets or
    by the batch size.

    :param int size: the number of classes
    :param int padding_idx: target id that is ignored (padding)
    :param bool normalize_length: normalize loss by the number of non-padded
        tokens if True, otherwise by batch size
    :param torch.nn.Module criterion: elementwise loss function; must use
        ``reduction='none'``. Defaults to ``nn.NLLLoss(reduction='none')``.
    """

    def __init__(
        self,
        size,
        padding_idx,
        normalize_length=False,
        criterion=None,
    ):
        """Construct an NllLoss object."""
        super(NllLoss, self).__init__()
        # BUGFIX: the criterion used to be a mutable default argument (one
        # nn.NLLLoss instance shared by every NllLoss created without an
        # explicit criterion); build a fresh instance per object instead.
        self.criterion = (
            criterion if criterion is not None else nn.NLLLoss(reduction='none')
        )
        self.padding_idx = padding_idx
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x, target):
        """Compute loss between x and target.

        :param torch.Tensor x: prediction log-probabilities
            (batch, seqlen, class)
        :param torch.Tensor target:
            target signal masked with self.padding_idx (batch, seqlen)
        :return: scalar float value
        :rtype torch.Tensor
        """
        assert x.size(2) == self.size
        batch_size = x.size(0)
        x = x.view(-1, self.size)
        target = target.view(-1)
        with torch.no_grad():
            ignore = target == self.padding_idx  # (B*L,)
            total = len(target) - ignore.sum().item()
            # Replace padding ids with a valid class to avoid indexing
            # errors; their contribution is zeroed out below.
            target = target.masked_fill(ignore, 0)
        loss = self.criterion(x, target)
        # NOTE: if every target is padding and normalize_length is True,
        # denom is 0 and the result is non-finite (unchanged legacy behavior).
        denom = total if self.normalize_length else batch_size
        return loss.masked_fill(ignore, 0).sum() / denom
| | |
| | | |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | from funasr.modules.attention import MultiHeadedAttention |
| | | from funasr.modules.attention import CosineDistanceAttention |
| | | from funasr.modules.dynamic_conv import DynamicConvolution |
| | | from funasr.modules.dynamic_conv2d import DynamicConvolution2D |
| | | from funasr.modules.embedding import PositionalEncoding |
| | |
| | | normalize_before, |
| | | concat_after, |
| | | ), |
| | | ) |
| | | ) |
| | | |
class BaseSAAsrTransformerDecoder(AbsDecoder, BatchScorerInterface):
    """Base class for speaker-attributed (SA) ASR transformer decoders.

    The decoder interleaves two branches:

    * a speaker branch (``decoder1`` + ``decoder2``) that predicts a speaker
      embedding per output token, which is matched against the supplied
      speaker ``profile`` via cosine-distance attention;
    * an ASR branch (``decoder3`` + ``decoder4``) that attends to the ASR
      encoder output, conditioned on the matched speaker embedding, and
      predicts the token distribution.

    The four ``decoderN`` attributes are placeholders (``None``) here and
    must be instantiated by a subclass such as ``SAAsrTransformerDecoder``.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
    ):
        assert check_argument_types()
        super().__init__()
        attention_dim = encoder_output_size

        # Token input layer: a learned embedding table ("embed") for token
        # ids, or a linear projection ("linear") for vector-valued inputs.
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            # NOTE(review): this single LayerNorm is applied to both the
            # speaker branch and the ASR branch in forward(); confirm that
            # sharing one norm between the two heads is intentional.
            self.after_norm = LayerNorm(attention_dim)
        if use_asr_output_layer:
            self.asr_output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.asr_output_layer = None

        if use_spk_output_layer:
            self.spk_output_layer = torch.nn.Linear(attention_dim, spker_embedding_dim)
        else:
            self.spk_output_layer = None

        # Matches predicted speaker embeddings against profile embeddings.
        self.cos_distance_att = CosineDistanceAttention()

        # Placeholders; instantiated by the concrete subclass.
        self.decoder1 = None
        self.decoder2 = None
        self.decoder3 = None
        self.decoder4 = None

    def forward(
        self,
        asr_hs_pad: torch.Tensor,
        spk_hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        profile: torch.Tensor,
        profile_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward decoder (training / full-sequence mode).

        Args:
            asr_hs_pad: padded ASR encoder output (presumably (B, T, D) --
                TODO confirm against the encoder).
            spk_hs_pad: padded speaker encoder output, same layout.
            hlens: encoder output lengths (B,), shared by both memories.
            ys_in_pad: padded input token ids (B, L).
            ys_in_lens: input token lengths (B,).
            profile: speaker profile embeddings to match against.
            profile_lens: number of valid profiles per batch entry.

        Returns:
            x: token scores (B, L, vocab_size) when ``asr_output_layer`` is
                set, otherwise raw decoder states.
            weights: speaker attention weights from cosine-distance attention.
            olens: ``tgt_mask.sum(1)`` (NOTE(review): tgt_mask is (B, L, L)
                here, so this is per-position valid counts -- verify callers).
        """
        tgt = ys_in_pad
        # tgt_mask: (B, 1, L)
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        # m: (1, L, L)
        m = subsequent_mask(tgt_mask.size(-1), device=tgt_mask.device).unsqueeze(0)
        # tgt_mask: (B, L, L)
        tgt_mask = tgt_mask & m

        asr_memory = asr_hs_pad
        spk_memory = spk_hs_pad
        memory_mask = (~make_pad_mask(hlens))[:, None, :].to(asr_memory.device)
        # Spk decoder
        x = self.embed(tgt)

        # decoder1 also returns z, its self-attention output, which later
        # seeds the ASR branch (decoder3).
        x, tgt_mask, asr_memory, spk_memory, memory_mask, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, memory_mask
        )
        x, tgt_mask, spk_memory, memory_mask = self.decoder2(
            x, tgt_mask, spk_memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        # dn: per-token matched speaker embedding; weights: match scores.
        dn, weights = self.cos_distance_att(x, profile, profile_lens)
        # Asr decoder
        x, tgt_mask, asr_memory, memory_mask = self.decoder3(
            z, tgt_mask, asr_memory, memory_mask, dn
        )
        x, tgt_mask, asr_memory, memory_mask = self.decoder4(
            x, tgt_mask, asr_memory, memory_mask
        )

        if self.normalize_before:
            x = self.after_norm(x)
        if self.asr_output_layer is not None:
            x = self.asr_output_layer(x)

        olens = tgt_mask.sum(1)
        return x, weights, olens


    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        asr_memory: torch.Tensor,
        spk_memory: torch.Tensor,
        profile: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
        """Forward one step of incremental decoding.

        Args:
            tgt: token ids decoded so far (batch, maxlen_out).
            tgt_mask: mask for tgt.
            asr_memory: ASR encoder output.
            spk_memory: speaker encoder output.
            profile: speaker profile embeddings.
            cache: cached per-layer states from the previous step, laid out
                as [decoder1, *decoder2 layers, decoder3, *decoder4 layers].

        Returns:
            y: next-token log-probabilities (batch, vocab) when
                ``asr_output_layer`` is set, otherwise last-frame states.
            weights: speaker attention weights.
            new_cache: updated cache with the same layout as ``cache``.
        """
        x = self.embed(tgt)

        if cache is None:
            # One slot each for decoder1 and decoder3, plus one per layer in
            # decoder2 and decoder4.
            cache = [None] * (2 + len(self.decoder2) + len(self.decoder4))
        new_cache = []
        x, tgt_mask, asr_memory, spk_memory, _, z = self.decoder1(
            x, tgt_mask, asr_memory, spk_memory, None, cache=cache[0]
        )
        new_cache.append(x)
        for c, decoder in zip(cache[1: len(self.decoder2) + 1], self.decoder2):
            x, tgt_mask, spk_memory, _ = decoder(
                x, tgt_mask, spk_memory, None, cache=c
            )
            new_cache.append(x)
        if self.normalize_before:
            x = self.after_norm(x)
        else:
            x = x
        if self.spk_output_layer is not None:
            x = self.spk_output_layer(x)
        dn, weights = self.cos_distance_att(x, profile, None)

        # The ASR branch restarts from z (decoder1's self-attention output).
        x, tgt_mask, asr_memory, _ = self.decoder3(
            z, tgt_mask, asr_memory, None, dn, cache=cache[len(self.decoder2) + 1]
        )
        new_cache.append(x)

        for c, decoder in zip(cache[len(self.decoder2) + 2: ], self.decoder4):
            x, tgt_mask, asr_memory, _ = decoder(
                x, tgt_mask, asr_memory, None, cache=c
            )
            new_cache.append(x)

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.asr_output_layer is not None:
            y = torch.log_softmax(self.asr_output_layer(y), dim=-1)

        return y, weights, new_cache

    def score(self, ys, state, asr_enc, spk_enc, profile):
        """Score a single hypothesis (BatchScorerInterface hook).

        Args:
            ys: token id sequence decoded so far (maxlen_out,).
            state: decoder cache from the previous call, or None.
            asr_enc: ASR encoder output for this utterance.
            spk_enc: speaker encoder output for this utterance.
            profile: speaker profile embeddings for this utterance.

        Returns:
            Tuple of (next-token log-probabilities, speaker weights for the
            last step, updated cache).
        """
        ys_mask = subsequent_mask(len(ys), device=ys.device).unsqueeze(0)
        logp, weights, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, asr_enc.unsqueeze(0), spk_enc.unsqueeze(0), profile.unsqueeze(0), cache=state
        )
        return logp.squeeze(0), weights.squeeze(), state
| | | |
class SAAsrTransformerDecoder(BaseSAAsrTransformerDecoder):
    """Concrete speaker-attributed ASR transformer decoder.

    Instantiates the four sub-decoders declared by the base class:

    * ``decoder1``: first speaker-branch layer (self-attention plus a cross
      attention over the two encoder memories);
    * ``decoder2``: remaining ``spk_num_blocks - 1`` standard decoder layers
      attending to the speaker encoder output;
    * ``decoder3``: first ASR-branch layer (source attention conditioned on
      the matched speaker embedding, no self-attention);
    * ``decoder4``: remaining ``asr_num_blocks - 1`` standard decoder layers
      attending to the ASR encoder output.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        spker_embedding_dim: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        asr_num_blocks: int = 6,
        spk_num_blocks: int = 3,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_asr_output_layer: bool = True,
        use_spk_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
    ):
        assert check_argument_types()
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            spker_embedding_dim=spker_embedding_dim,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_asr_output_layer=use_asr_output_layer,
            use_spk_output_layer=use_spk_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size

        # First speaker-branch layer: self-attention + cross-attention.
        self.decoder1 = SpeakerAttributeSpkDecoderFirstLayer(
            attention_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, self_attention_dropout_rate
            ),
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        # Remaining speaker-branch layers (standard transformer decoder).
        self.decoder2 = repeat(
            spk_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )


        # First ASR-branch layer: source attention + speaker conditioning.
        self.decoder3 = SpeakerAttributeAsrDecoderFirstLayer(
            attention_dim,
            spker_embedding_dim,
            MultiHeadedAttention(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        # Remaining ASR-branch layers (standard transformer decoder).
        self.decoder4 = repeat(
            asr_num_blocks - 1,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
| | | |
class SpeakerAttributeSpkDecoderFirstLayer(nn.Module):
    """First layer of the speaker branch of the SA-ASR decoder.

    Like a standard transformer decoder layer (self-attention, source
    attention, feed-forward), but with two differences:

    * the source attention uses the ASR encoder output as keys and the
      speaker encoder output as values (asymmetric key/value memories);
    * it additionally returns ``z``, the state right after self-attention,
      which the outer decoder feeds to the ASR branch.

    :param int size: attention (model) dimension
    :param torch.nn.Module self_attn: self-attention module
    :param torch.nn.Module src_attn: cross-attention module over the memories
    :param torch.nn.Module feed_forward: position-wise feed-forward module
    :param float dropout_rate: dropout rate
    :param bool normalize_before: pre-norm if True, post-norm otherwise
    :param bool concat_after: concatenate attention input and output and
        project, instead of the usual residual + dropout
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeSpkDecoderFirstLayer object."""
        super(SpeakerAttributeSpkDecoderFirstLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        # NOTE(review): norm1 is applied before BOTH the self-attention and
        # the source-attention sub-blocks (a standard DecoderLayer uses a
        # separate norm per sub-block); confirm the shared norm is intended.
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, asr_memory, spk_memory, memory_mask, cache=None):
        """Compute decoded features.

        Args:
            tgt: decoder input states (batch, maxlen_out, size).
            tgt_mask: mask for tgt, or None.
            asr_memory: ASR encoder output (used as attention keys).
            spk_memory: speaker encoder output (used as attention values).
            memory_mask: mask for the memories, or None.
            cache: cached output from the previous decoding step; when given,
                only the last frame is computed.

        Returns:
            Tuple of (output states, tgt_mask, asr_memory, spk_memory,
            memory_mask, z) where z is the post-self-attention state handed
            over to the ASR branch.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        # Self-attention sub-block.
        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)
        # z is handed to the ASR branch (decoder3) by the outer decoder.
        z = x

        residual = x
        if self.normalize_before:
            x = self.norm1(x)

        # Cross-attention: query x, keys from asr_memory, values from
        # spk_memory.
        skip = self.src_attn(x, asr_memory, spk_memory, memory_mask)

        if self.concat_after:
            x_concat = torch.cat(
                (x, skip), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(skip)
        if not self.normalize_before:
            x = self.norm1(x)

        # Feed-forward sub-block.
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, asr_memory, spk_memory, memory_mask, z
| | | |
class SpeakerAttributeAsrDecoderFirstLayer(nn.Module):
    """First layer of the ASR branch of the SA-ASR decoder.

    Unlike a standard transformer decoder layer it has no self-attention
    block: it consumes the self-attended states handed over from the speaker
    branch, applies source attention over the ASR encoder memory, injects the
    matched speaker embedding ``dn`` through a linear projection, and ends
    with a position-wise feed-forward block.

    :param int size: attention (model) dimension
    :param int d_size: speaker embedding dimension, projected into ``size``
    :param torch.nn.Module src_attn: source attention over the encoder memory
    :param torch.nn.Module feed_forward: position-wise feed-forward module
    :param float dropout_rate: dropout rate
    :param bool normalize_before: pre-norm if True, post-norm otherwise
    :param bool concat_after: concatenate attention input and output and
        project, instead of the usual residual + dropout
    """

    def __init__(
        self,
        size,
        d_size,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a SpeakerAttributeAsrDecoderFirstLayer object."""
        super(SpeakerAttributeAsrDecoderFirstLayer, self).__init__()
        self.size = size
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        # Projects the speaker embedding into the model dimension so that it
        # can be added to the decoder states.
        self.spk_linear = nn.Linear(d_size, size, bias=False)
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, dn, cache=None):
        """Compute decoded features.

        Args:
            tgt: decoder input states (batch, maxlen_out, size).
            tgt_mask: mask for tgt, or None.
            memory: ASR encoder output.
            memory_mask: mask for memory, or None.
            dn: matched speaker embedding per token (projected by
                ``spk_linear``), or None to skip speaker conditioning.
            cache: cached output from the previous decoding step; when given,
                only the last frame is computed.

        Returns:
            Tuple of (output states, tgt_mask, memory, memory_mask).
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)

        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # Incremental decoding: compute only the last frame,
            # max_time_out -> 1.
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]

        # Source attention over the encoder memory.
        x = tgt_q
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
        if not self.normalize_before:
            x = self.norm2(x)
        residual = x

        # Speaker conditioning.  BUGFIX: use an identity test instead of
        # `dn != None`; `!=` on a Tensor goes through the Tensor comparison
        # overloads rather than a plain None check (PEP 8 mandates
        # `is not None`).
        if dn is not None:
            x = x + self.spk_linear(dn)
        if self.normalize_before:
            x = self.norm3(x)

        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        if cache is not None:
            x = torch.cat([cache, x], dim=1)

        return x, tgt_mask, memory, memory_mask
| | |
| | | |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.losses.label_smoothing_loss import ( |
| | | LabelSmoothingLoss, # noqa: H301 |
| | | LabelSmoothingLoss, NllLoss # noqa: H301 |
| | | ) |
| | | from funasr.losses.nll_loss import NllLoss |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | |
| | | from funasr.models.decoder.transformer_decoder import ( |
| | | DynamicConvolution2DTransformerDecoder, # noqa: H301 |
| | | ) |
| | | from funasr.models.decoder.transformer_decoder_sa_asr import SAAsrTransformerDecoder |
| | | from funasr.models.decoder.transformer_decoder import SAAsrTransformerDecoder |
| | | from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder |
| | | from funasr.models.decoder.transformer_decoder import ( |
| | | LightweightConvolution2DTransformerDecoder, # noqa: H301 |