Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
| | |
| | | - main |
| | | push: |
| | | branches: |
| | | - dev_wjm |
| | | - dev_jy |
| | | - dev_wjm_infer |
| | | |
| | | jobs: |
| | | build: |
| | |
| | | - main |
| | | push: |
| | | branches: |
| | | - dev_wjm |
| | | - main |
| | | - dev_lyh |
| | | |
| New file |
| | |
| | | beam_size: 10 |
| | | penalty: 0.0 |
| | | maxlenratio: 0.0 |
| | | minlenratio: 0.0 |
| | | ctc_weight: 0.4 |
| | | lm_weight: 0.0 |
| New file |
| | |
| | | # network architecture |
| | | # encoder related |
| | | encoder: branchformer |
| | | encoder_conf: |
| | | output_size: 256 |
| | | use_attn: true |
| | | attention_heads: 4 |
| | | attention_layer_type: rel_selfattn |
| | | pos_enc_layer_type: rel_pos |
| | | rel_pos_type: latest |
| | | use_cgmlp: true |
| | | cgmlp_linear_units: 2048 |
| | | cgmlp_conv_kernel: 31 |
| | | use_linear_after_conv: false |
| | | gate_activation: identity |
| | | merge_method: concat |
| | | cgmlp_weight: 0.5 # used only if merge_method is "fixed_ave" |
| | | attn_branch_drop_rate: 0.0 # used only if merge_method is "learned_ave" |
| | | num_blocks: 24 |
| | | dropout_rate: 0.1 |
| | | positional_dropout_rate: 0.1 |
| | | attention_dropout_rate: 0.1 |
| | | input_layer: conv2d |
| | | stochastic_depth_rate: 0.0 |
| | | |
| | | # decoder related |
| | | decoder: transformer |
| | | decoder_conf: |
| | | attention_heads: 4 |
| | | linear_units: 2048 |
| | | num_blocks: 6 |
| | | dropout_rate: 0.1 |
| | | positional_dropout_rate: 0.1 |
| | | self_attention_dropout_rate: 0. |
| | | src_attention_dropout_rate: 0. |
| | | |
| | | # frontend related |
| | | frontend: wav_frontend |
| | | frontend_conf: |
| | | fs: 16000 |
| | | window: hamming |
| | | n_mels: 80 |
| | | frame_length: 25 |
| | | frame_shift: 10 |
| | | lfr_m: 1 |
| | | lfr_n: 1 |
| | | |
| | | # hybrid CTC/attention |
| | | model_conf: |
| | | ctc_weight: 0.3 |
| | | lsm_weight: 0.1 # label smoothing option |
| | | length_normalized_loss: false |
| | | |
| | | # optimization related |
| | | accum_grad: 1 |
| | | grad_clip: 5 |
| | | max_epoch: 180 |
| | | val_scheduler_criterion: |
| | | - valid |
| | | - acc |
| | | best_model_criterion: |
| | | - - valid |
| | | - acc |
| | | - max |
| | | keep_nbest_models: 10 |
| | | |
| | | optim: adam |
| | | optim_conf: |
| | | lr: 0.001 |
| | | weight_decay: 0.000001 |
| | | scheduler: warmuplr |
| | | scheduler_conf: |
| | | warmup_steps: 35000 |
| | | |
| | | specaug: specaug |
| | | specaug_conf: |
| | | apply_time_warp: true |
| | | time_warp_window: 5 |
| | | time_warp_mode: bicubic |
| | | apply_freq_mask: true |
| | | freq_mask_width_range: |
| | | - 0 |
| | | - 27 |
| | | num_freq_mask: 2 |
| | | apply_time_mask: true |
| | | time_mask_width_ratio_range: |
| | | - 0. |
| | | - 0.05 |
| | | num_time_mask: 10 |
| | | |
| | | dataset_conf: |
| | | data_names: speech,text |
| | | data_types: sound,text |
| | | shuffle: True |
| | | shuffle_conf: |
| | | shuffle_size: 2048 |
| | | sort_size: 500 |
| | | batch_conf: |
| | | batch_type: token |
| | | batch_size: 10000 |
| | | num_workers: 8 |
| | | |
| | | log_interval: 50 |
| | | normalize: None |
| New file |
| | |
#!/bin/bash

# Copyright 2017 Xingyu Na
# Apache 2.0
#
# Build wav.scp and text files for the AISHELL train/dev/test splits.
#
# Usage: aishell_data_prep.sh <audio-path> <text-path> <output-path>
#   <audio-path>   directory searched recursively for *.wav files
#   <text-path>    directory containing aishell_transcript_v0.8.txt
#   <output-path>  results go to <output-path>/data/{train,dev,test}/{wav.scp,text}
# Exits 1 on bad arguments or any failed step, 0 on success.

#. ./path.sh || exit 1;

if [ $# -ne 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi

aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3

train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp

# Validate the inputs before creating anything under <output-path>.
if [ ! -d "$aishell_audio_dir" ] || [ ! -f "$aishell_text" ]; then
  echo "Error: $0 requires an existing audio directory ($aishell_audio_dir) and transcript file ($aishell_text)"
  exit 1;
fi

mkdir -p "$train_dir" "$dev_dir" "$test_dir" "$tmp_dir"

# find wav audio file for train, dev and test resp.
find "$aishell_audio_dir" -iname "*.wav" > "$tmp_dir/wav.flist"
n=$(wc -l < "$tmp_dir/wav.flist")
# The script expects 141925 wav files in total; any other count is only a
# warning, not a hard failure.
if [ "$n" -ne 141925 ]; then
  echo "Warning: expected 141925 data files, found $n"
fi

# Split the file list by the split name that appears in each path.
grep -i "wav/train" "$tmp_dir/wav.flist" > "$train_dir/wav.flist" || exit 1;
grep -i "wav/dev" "$tmp_dir/wav.flist" > "$dev_dir/wav.flist" || exit 1;
grep -i "wav/test" "$tmp_dir/wav.flist" > "$test_dir/wav.flist" || exit 1;

rm -r "$tmp_dir"

# Transcriptions preparation
for dir in "$train_dir" "$dev_dir" "$test_dir"; do
  echo "Preparing $dir transcriptions"
  # Utterance id = wav file name with ".wav" removed.
  sed -e 's/\.wav//' "$dir/wav.flist" | awk -F '/' '{print $NF}' > "$dir/utt.list"
  paste -d' ' "$dir/utt.list" "$dir/wav.flist" > "$dir/wav.scp_all"
  # filter_scp.pl presumably keeps lines whose first field is in the id list;
  # this intersects the wav ids with the ids that have a transcript.
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$aishell_text" > "$dir/transcripts.txt"
  awk '{print $1}' "$dir/transcripts.txt" > "$dir/utt.list"
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$dir/wav.scp_all" | sort -u > "$dir/wav.scp"
  sort -u "$dir/transcripts.txt" > "$dir/text"
done

mkdir -p "$output_dir/data/train" "$output_dir/data/dev" "$output_dir/data/test"

for f in wav.scp text; do
  cp "$train_dir/$f" "$output_dir/data/train/$f" || exit 1;
  cp "$dev_dir/$f" "$output_dir/data/dev/$f" || exit 1;
  cp "$test_dir/$f" "$output_dir/data/test/$f" || exit 1;
done

echo "$0: AISHELL data preparation succeeded"
exit 0;
| New file |
| | |
#!/usr/bin/env bash

# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
#           2017 Xingyu Na
# Apache 2.0
#
# Download <url-base>/<corpus-part>.tgz into <data-base> unless a copy with an
# accepted size is already there, then extract it. For data_aishell, the
# *.tar.gz archives inside <data-base>/data_aishell/wav are extracted and
# removed as well. <data-base>/<corpus-part>/.complete marks a finished part;
# when it exists the script exits 0 immediately.

remove_archive=false

if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi

if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
  exit 1;
fi

data=$1
url=$2
part=$3

if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi
# The script cd's into $data and into $data/$part/wav below. Make $data
# absolute so that later "$data/..." paths stay valid after those cd's.
data=$(cd "$data" && pwd) || exit 1

part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
  if [ "$part" == "$x" ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1;
fi

if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1;
fi

if [ -f "$data/$part/.complete" ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0;
fi

# sizes of the archive files in bytes. A local .tgz whose size matches none of
# these is treated as a partial download and removed.
sizes="15582913665 1246920"

if [ -f "$data/$part.tgz" ]; then
  size=$(/bin/ls -l "$data/$part.tgz" | awk '{print $5}')
  size_ok=false
  for s in $sizes; do if [ "$s" == "$size" ]; then size_ok=true; fi; done
  if ! $size_ok; then
    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm "$data/$part.tgz"
  else
    echo "$data/$part.tgz exists and appears to be complete."
  fi
fi

# wget writes into the current directory, and tar reads $part.tgz from it.
cd "$data" || exit 1

if [ ! -f "$data/$part.tgz" ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1;
  fi
  full_url=$url/$part.tgz
  echo "$0: downloading data from $full_url. This may take some time, please be patient."

  if ! wget --no-check-certificate "$full_url"; then
    echo "$0: error executing wget $full_url"
    exit 1;
  fi
fi

if ! tar -xvzf "$part.tgz"; then
  echo "$0: error un-tarring archive $data/$part.tgz"
  exit 1;
fi

if [ "$part" == "data_aishell" ]; then
  cd "$data/$part/wav" || exit 1
  for wav in ./*.tar.gz; do
    # An unmatched glob stays literal; skip it instead of handing it to tar.
    [ -f "$wav" ] || continue
    echo "Extracting wav from $wav"
    if ! tar -zxf "$wav"; then
      echo "$0: error un-tarring archive $data/$part/wav/$wav"
      exit 1;
    fi
    rm "$wav"
  done
fi

# Write the marker only after all extraction (including the nested archives)
# has succeeded, so an interrupted run is retried rather than skipped.
touch "$data/$part/.complete"

echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"

if $remove_archive; then
  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
  rm "$data/$part.tgz"
fi

exit 0;
| New file |
| | |
# Recipe environment, sourced by run.sh via ". ./path.sh".
# FUNASR_DIR points three directory levels above the current directory.
FUNASR_DIR=$PWD/../../..
export FUNASR_DIR

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
PYTHONIOENCODING=UTF-8
export PYTHONIOENCODING

# Prepend funasr/bin so its scripts (presumably train.py used by run.sh) resolve first.
PATH=$FUNASR_DIR/funasr/bin:$PATH
export PATH
| New file |
| | |
#!/usr/bin/env bash
# AISHELL Branchformer recipe.
#   stage -1: corpus download        stage 0: data preparation
#   stage  1: CMVN                   stage 2: token list
#   stage  3: LM (no-op here)        stage 4: multi-GPU ASR training
#   stage  5: decoding + CER         stage 6: ModelScope packaging
# Variables assigned before utils/parse_options.sh can presumably be overridden
# on the command line (e.g. --stage 4 --stop_stage 5).

. ./path.sh || exit 1;

# machines configuration
CUDA_VISIBLE_DEVICES="0,1,2,3"
gpu_num=4
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl

# general configuration
feats_dir="../DATA" # feature output directory
exp_dir="."
lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=0
stop_stage=5

# feature configuration
feats_dim=80
nj=64

# data
raw_data=../raw_data
data_url=www.openslr.org/resources/33

# exp tag
tag="exp1"

. utils/parse_options.sh || exit 1;

# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev
test_sets="dev test"

asr_config=conf/train_asr_branchformer.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"

inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb

# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')

if ${gpu_inference}; then
  inference_nj=$((ngpu * njob))
  _ngpu=1
else
  inference_nj=$njob
  _ngpu=0
fi

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  echo "stage -1: Data Download"
  local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
  local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  echo "stage 0: Data preparation"
  # Data preparation
  local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
  for x in train dev test; do
    # Remove the spaces inside each transcript (keeping the utterance id), then
    # re-tokenize it with text2token.py; text.org is only a scratch file here.
    cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
    paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
      > ${feats_dir}/data/${x}/text
    utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
    mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
  done
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "stage 1: Feature and CMVN Generation"
  utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
fi

token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "stage 2: Dictionary Preparation"
  mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/

  echo "make a dictionary"
  # Special symbols first, then every distinct training token, then <unk>.
  echo "<blank>" > ${token_list}
  echo "<s>" >> ${token_list}
  echo "</s>" >> ${token_list}
  utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
    | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
  echo "<unk>" >> ${token_list}
fi

# LM Training Stage (nothing is trained in this recipe)
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "stage 3: LM Training"
fi

# ASR Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "stage 4: ASR Training"
  mkdir -p ${exp_dir}/exp/${model_dir}
  mkdir -p ${exp_dir}/exp/${model_dir}/log
  # Shared file used for the distributed init; drop any stale copy from an
  # earlier run first.
  INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
  rm -f $INIT_FILE
  init_method=file://$(readlink -f $INIT_FILE)
  echo "$0: init method is $init_method"
  pids=()
  for ((i = 0; i < gpu_num; ++i)); do
    {
      rank=$i
      local_rank=$i
      # The i-th entry of the comma-separated CUDA_VISIBLE_DEVICES list.
      gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
      train.py \
        --task_name asr \
        --gpu_id $gpu_id \
        --use_preprocessor true \
        --token_type $token_type \
        --token_list $token_list \
        --data_dir ${feats_dir}/data \
        --train_set ${train_set} \
        --valid_set ${valid_set} \
        --data_file_names "wav.scp,text" \
        --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
        --speed_perturb ${speed_perturb} \
        --resume true \
        --output_dir ${exp_dir}/exp/${model_dir} \
        --config $asr_config \
        --ngpu $gpu_num \
        --num_worker_count $count \
        --dist_init_method $init_method \
        --dist_world_size $world_size \
        --dist_rank $rank \
        --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
    } &
    pids+=($!)
  done
  # A bare "wait" always returns 0. Wait on every rank explicitly so that a
  # failed rank aborts the script instead of silently moving on to decoding.
  train_failed=0
  for pid in "${pids[@]}"; do
    wait "$pid" || train_failed=1
  done
  if [ $train_failed -ne 0 ]; then
    echo "$0: ASR training failed, see ${exp_dir}/exp/${model_dir}/log/train.log.*"
    exit 1
  fi
fi

# Testing Stage
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "stage 5: Inference"
  for dset in ${test_sets}; do
    asr_exp=${exp_dir}/exp/${model_dir}
    inference_tag="$(basename "${inference_config}" .yaml)"
    _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
    _logdir="${_dir}/logdir"
    if [ -d ${_dir} ]; then
      echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
      # Skip only this set; the remaining test sets are still decoded.
      continue
    fi
    mkdir -p "${_logdir}"
    _data="${feats_dir}/data/${dset}"
    key_file=${_data}/${scp}
    num_scp_file="$(<${key_file} wc -l)"
    # Never start more jobs than there are utterances.
    _nj=$(( inference_nj < num_scp_file ? inference_nj : num_scp_file ))
    split_scps=
    for n in $(seq "${_nj}"); do
      split_scps+=" ${_logdir}/keys.${n}.scp"
    done
    # shellcheck disable=SC2086
    utils/split_scp.pl "${key_file}" ${split_scps}
    _opts=
    if [ -n "${inference_config}" ]; then
      _opts+="--config ${inference_config} "
    fi
    ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
      python -m funasr.bin.asr_inference_launch \
        --batch_size 1 \
        --ngpu "${_ngpu}" \
        --njob ${njob} \
        --gpuid_list ${gpuid_list} \
        --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
        --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
        --key_file "${_logdir}"/keys.JOB.scp \
        --asr_train_config "${asr_exp}"/config.yaml \
        --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
        --output_dir "${_logdir}"/output.JOB \
        --mode asr \
        ${_opts}

    # Merge the per-job 1-best outputs into single sorted files.
    for f in token token_int score text; do
      if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
        for i in $(seq "${_nj}"); do
          cat "${_logdir}/output.${i}/1best_recog/${f}"
        done | sort -k1 >"${_dir}/${f}"
      fi
    done
    python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
    python utils/proce_text.py ${_data}/text ${_data}/text.proc
    python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
    tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
    cat ${_dir}/text.cer.txt
  done
fi

# Prepare files for ModelScope fine-tuning and inference
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
  echo "stage 6: ModelScope Preparation"
  cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn
  vocab_size=$(cat ${token_list} | wc -l)
  # NOTE(review): this recipe trains a Branchformer encoder, yet the model name
  # passed here is "conformer"; confirm which names gen_modelscope_configuration.py accepts.
  python utils/gen_modelscope_configuration.py \
    --am_model_name $inference_asr_model \
    --mode asr \
    --model_name conformer \
    --dataset aishell \
    --output_dir $exp_dir/exp/$model_dir \
    --vocab_size $vocab_size \
    --tag $tag
fi
| New file |
| | |
| | | ../transformer/utils |
| New file |
| | |
| | | beam_size: 10 |
| | | penalty: 0.0 |
| | | maxlenratio: 0.0 |
| | | minlenratio: 0.0 |
| | | ctc_weight: 0.4 |
| | | lm_weight: 0.0 |
| New file |
| | |
| | | # network architecture |
| | | # encoder related |
| | | encoder: e_branchformer |
| | | encoder_conf: |
| | | output_size: 256 |
| | | attention_heads: 4 |
| | | attention_layer_type: rel_selfattn |
| | | pos_enc_layer_type: rel_pos |
| | | rel_pos_type: latest |
| | | cgmlp_linear_units: 1024 |
| | | cgmlp_conv_kernel: 31 |
| | | use_linear_after_conv: false |
| | | gate_activation: identity |
| | | num_blocks: 12 |
| | | dropout_rate: 0.1 |
| | | positional_dropout_rate: 0.1 |
| | | attention_dropout_rate: 0.1 |
| | | input_layer: conv2d |
| | | layer_drop_rate: 0.0 |
| | | linear_units: 1024 |
| | | positionwise_layer_type: linear |
| | | use_ffn: true |
| | | macaron_ffn: true |
| | | merge_conv_kernel: 31 |
| | | |
| | | # decoder related |
| | | decoder: transformer |
| | | decoder_conf: |
| | | attention_heads: 4 |
| | | linear_units: 2048 |
| | | num_blocks: 6 |
| | | dropout_rate: 0.1 |
| | | positional_dropout_rate: 0.1 |
| | | self_attention_dropout_rate: 0. |
| | | src_attention_dropout_rate: 0. |
| | | |
| | | # frontend related |
| | | frontend: wav_frontend |
| | | frontend_conf: |
| | | fs: 16000 |
| | | window: hamming |
| | | n_mels: 80 |
| | | frame_length: 25 |
| | | frame_shift: 10 |
| | | lfr_m: 1 |
| | | lfr_n: 1 |
| | | |
| | | # hybrid CTC/attention |
| | | model_conf: |
| | | ctc_weight: 0.3 |
| | | lsm_weight: 0.1 # label smoothing option |
| | | length_normalized_loss: false |
| | | |
| | | # optimization related |
| | | accum_grad: 1 |
| | | grad_clip: 5 |
| | | max_epoch: 180 |
| | | best_model_criterion: |
| | | - - valid |
| | | - acc |
| | | - max |
| | | keep_nbest_models: 10 |
| | | |
| | | optim: adam |
| | | optim_conf: |
| | | lr: 0.001 |
| | | weight_decay: 0.000001 |
| | | scheduler: warmuplr |
| | | scheduler_conf: |
| | | warmup_steps: 35000 |
| | | |
| | | specaug: specaug |
| | | specaug_conf: |
| | | apply_time_warp: true |
| | | time_warp_window: 5 |
| | | time_warp_mode: bicubic |
| | | apply_freq_mask: true |
| | | freq_mask_width_range: |
| | | - 0 |
| | | - 27 |
| | | num_freq_mask: 2 |
| | | apply_time_mask: true |
| | | time_mask_width_ratio_range: |
| | | - 0. |
| | | - 0.05 |
| | | num_time_mask: 10 |
| | | |
| | | dataset_conf: |
| | | data_names: speech,text |
| | | data_types: sound,text |
| | | shuffle: True |
| | | shuffle_conf: |
| | | shuffle_size: 2048 |
| | | sort_size: 500 |
| | | batch_conf: |
| | | batch_type: token |
| | | batch_size: 10000 |
| | | num_workers: 8 |
| | | |
| | | log_interval: 50 |
| | | normalize: None |
| New file |
| | |
#!/bin/bash

# Copyright 2017 Xingyu Na
# Apache 2.0
#
# Build wav.scp and text files for the AISHELL train/dev/test splits.
#
# Usage: aishell_data_prep.sh <audio-path> <text-path> <output-path>
#   <audio-path>   directory searched recursively for *.wav files
#   <text-path>    directory containing aishell_transcript_v0.8.txt
#   <output-path>  results go to <output-path>/data/{train,dev,test}/{wav.scp,text}
# Exits 1 on bad arguments or any failed step, 0 on success.

#. ./path.sh || exit 1;

if [ $# -ne 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi

aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3

train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp

# Validate the inputs before creating anything under <output-path>.
if [ ! -d "$aishell_audio_dir" ] || [ ! -f "$aishell_text" ]; then
  echo "Error: $0 requires an existing audio directory ($aishell_audio_dir) and transcript file ($aishell_text)"
  exit 1;
fi

mkdir -p "$train_dir" "$dev_dir" "$test_dir" "$tmp_dir"

# find wav audio file for train, dev and test resp.
find "$aishell_audio_dir" -iname "*.wav" > "$tmp_dir/wav.flist"
n=$(wc -l < "$tmp_dir/wav.flist")
# The script expects 141925 wav files in total; any other count is only a
# warning, not a hard failure.
if [ "$n" -ne 141925 ]; then
  echo "Warning: expected 141925 data files, found $n"
fi

# Split the file list by the split name that appears in each path.
grep -i "wav/train" "$tmp_dir/wav.flist" > "$train_dir/wav.flist" || exit 1;
grep -i "wav/dev" "$tmp_dir/wav.flist" > "$dev_dir/wav.flist" || exit 1;
grep -i "wav/test" "$tmp_dir/wav.flist" > "$test_dir/wav.flist" || exit 1;

rm -r "$tmp_dir"

# Transcriptions preparation
for dir in "$train_dir" "$dev_dir" "$test_dir"; do
  echo "Preparing $dir transcriptions"
  # Utterance id = wav file name with ".wav" removed.
  sed -e 's/\.wav//' "$dir/wav.flist" | awk -F '/' '{print $NF}' > "$dir/utt.list"
  paste -d' ' "$dir/utt.list" "$dir/wav.flist" > "$dir/wav.scp_all"
  # filter_scp.pl presumably keeps lines whose first field is in the id list;
  # this intersects the wav ids with the ids that have a transcript.
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$aishell_text" > "$dir/transcripts.txt"
  awk '{print $1}' "$dir/transcripts.txt" > "$dir/utt.list"
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$dir/wav.scp_all" | sort -u > "$dir/wav.scp"
  sort -u "$dir/transcripts.txt" > "$dir/text"
done

mkdir -p "$output_dir/data/train" "$output_dir/data/dev" "$output_dir/data/test"

for f in wav.scp text; do
  cp "$train_dir/$f" "$output_dir/data/train/$f" || exit 1;
  cp "$dev_dir/$f" "$output_dir/data/dev/$f" || exit 1;
  cp "$test_dir/$f" "$output_dir/data/test/$f" || exit 1;
done

echo "$0: AISHELL data preparation succeeded"
exit 0;
| New file |
| | |
#!/usr/bin/env bash

# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
#           2017 Xingyu Na
# Apache 2.0
#
# Download <url-base>/<corpus-part>.tgz into <data-base> unless a copy with an
# accepted size is already there, then extract it. For data_aishell, the
# *.tar.gz archives inside <data-base>/data_aishell/wav are extracted and
# removed as well. <data-base>/<corpus-part>/.complete marks a finished part;
# when it exists the script exits 0 immediately.

remove_archive=false

if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi

if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
  exit 1;
fi

data=$1
url=$2
part=$3

if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi
# The script cd's into $data and into $data/$part/wav below. Make $data
# absolute so that later "$data/..." paths stay valid after those cd's.
data=$(cd "$data" && pwd) || exit 1

part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
  if [ "$part" == "$x" ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1;
fi

if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1;
fi

if [ -f "$data/$part/.complete" ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0;
fi

# sizes of the archive files in bytes. A local .tgz whose size matches none of
# these is treated as a partial download and removed.
sizes="15582913665 1246920"

if [ -f "$data/$part.tgz" ]; then
  size=$(/bin/ls -l "$data/$part.tgz" | awk '{print $5}')
  size_ok=false
  for s in $sizes; do if [ "$s" == "$size" ]; then size_ok=true; fi; done
  if ! $size_ok; then
    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm "$data/$part.tgz"
  else
    echo "$data/$part.tgz exists and appears to be complete."
  fi
fi

# wget writes into the current directory, and tar reads $part.tgz from it.
cd "$data" || exit 1

if [ ! -f "$data/$part.tgz" ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1;
  fi
  full_url=$url/$part.tgz
  echo "$0: downloading data from $full_url. This may take some time, please be patient."

  if ! wget --no-check-certificate "$full_url"; then
    echo "$0: error executing wget $full_url"
    exit 1;
  fi
fi

if ! tar -xvzf "$part.tgz"; then
  echo "$0: error un-tarring archive $data/$part.tgz"
  exit 1;
fi

if [ "$part" == "data_aishell" ]; then
  cd "$data/$part/wav" || exit 1
  for wav in ./*.tar.gz; do
    # An unmatched glob stays literal; skip it instead of handing it to tar.
    [ -f "$wav" ] || continue
    echo "Extracting wav from $wav"
    if ! tar -zxf "$wav"; then
      echo "$0: error un-tarring archive $data/$part/wav/$wav"
      exit 1;
    fi
    rm "$wav"
  done
fi

# Write the marker only after all extraction (including the nested archives)
# has succeeded, so an interrupted run is retried rather than skipped.
touch "$data/$part/.complete"

echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"

if $remove_archive; then
  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
  rm "$data/$part.tgz"
fi

exit 0;
| New file |
| | |
# Recipe environment, sourced by run.sh via ". ./path.sh".
# FUNASR_DIR points three directory levels above the current directory.
FUNASR_DIR=$PWD/../../..
export FUNASR_DIR

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
PYTHONIOENCODING=UTF-8
export PYTHONIOENCODING

# Prepend funasr/bin so its scripts (presumably train.py used by run.sh) resolve first.
PATH=$FUNASR_DIR/funasr/bin:$PATH
export PATH
| New file |
| | |
#!/usr/bin/env bash
# AISHELL E-Branchformer recipe.
#   stage -1: corpus download        stage 0: data preparation
#   stage  1: CMVN                   stage 2: token list
#   stage  3: LM (no-op here)        stage 4: multi-GPU ASR training
#   stage  5: decoding + CER
# Variables assigned before utils/parse_options.sh can presumably be overridden
# on the command line (e.g. --stage 4 --stop_stage 5).

. ./path.sh || exit 1;

# machines configuration
CUDA_VISIBLE_DEVICES="0,1,2,3"
gpu_num=4
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl

# general configuration
feats_dir="../DATA" # feature output directory
exp_dir="."
lang=zh
token_type=char
type=sound
scp=wav.scp
speed_perturb="0.9 1.0 1.1"
stage=0
stop_stage=5

# feature configuration
feats_dim=80
nj=64

# data
raw_data=../raw_data
data_url=www.openslr.org/resources/33

# exp tag
tag="exp1"

. utils/parse_options.sh || exit 1;

# Set bash to 'debug' mode, it will exit on :
# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev
test_sets="dev test"

asr_config=conf/train_asr_e_branchformer.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"

inference_config=conf/decode_asr_transformer.yaml
inference_asr_model=valid.acc.ave_10best.pb

# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')

if ${gpu_inference}; then
  inference_nj=$((ngpu * njob))
  _ngpu=1
else
  inference_nj=$njob
  _ngpu=0
fi

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  echo "stage -1: Data Download"
  local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
  local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  echo "stage 0: Data preparation"
  # Data preparation
  local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
  for x in train dev test; do
    # Remove the spaces inside each transcript (keeping the utterance id), then
    # re-tokenize it with text2token.py; text.org is only a scratch file here.
    cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
    paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
      > ${feats_dir}/data/${x}/text
    utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
    mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
  done
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "stage 1: Feature and CMVN Generation"
  utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
fi

token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "stage 2: Dictionary Preparation"
  mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/

  echo "make a dictionary"
  # Special symbols first, then every distinct training token, then <unk>.
  echo "<blank>" > ${token_list}
  echo "<s>" >> ${token_list}
  echo "</s>" >> ${token_list}
  utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
    | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
  echo "<unk>" >> ${token_list}
fi

# LM Training Stage (nothing is trained in this recipe)
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "stage 3: LM Training"
fi

# ASR Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "stage 4: ASR Training"
  mkdir -p ${exp_dir}/exp/${model_dir}
  mkdir -p ${exp_dir}/exp/${model_dir}/log
  # Shared file used for the distributed init; drop any stale copy from an
  # earlier run first.
  INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
  rm -f $INIT_FILE
  init_method=file://$(readlink -f $INIT_FILE)
  echo "$0: init method is $init_method"
  pids=()
  for ((i = 0; i < gpu_num; ++i)); do
    {
      rank=$i
      local_rank=$i
      # The i-th entry of the comma-separated CUDA_VISIBLE_DEVICES list.
      gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
      train.py \
        --task_name asr \
        --gpu_id $gpu_id \
        --use_preprocessor true \
        --token_type $token_type \
        --token_list $token_list \
        --data_dir ${feats_dir}/data \
        --train_set ${train_set} \
        --valid_set ${valid_set} \
        --data_file_names "wav.scp,text" \
        --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
        --speed_perturb ${speed_perturb} \
        --resume true \
        --output_dir ${exp_dir}/exp/${model_dir} \
        --config $asr_config \
        --ngpu $gpu_num \
        --num_worker_count $count \
        --dist_init_method $init_method \
        --dist_world_size $world_size \
        --dist_rank $rank \
        --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
    } &
    pids+=($!)
  done
  # A bare "wait" always returns 0. Wait on every rank explicitly so that a
  # failed rank aborts the script instead of silently moving on to decoding.
  train_failed=0
  for pid in "${pids[@]}"; do
    wait "$pid" || train_failed=1
  done
  if [ $train_failed -ne 0 ]; then
    echo "$0: ASR training failed, see ${exp_dir}/exp/${model_dir}/log/train.log.*"
    exit 1
  fi
fi
| | | |
# Stage 5: decode every test set with the trained model and score CER.
# Per test set: split the key file into _nj shards, run
# funasr.bin.asr_inference_launch once per shard, merge the per-shard 1-best
# outputs, normalize hypothesis and reference text, and compute CER.
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
    echo "stage 5: Inference"
    for dset in ${test_sets}; do
        asr_exp=${exp_dir}/exp/${model_dir}
        inference_tag="$(basename "${inference_config}" .yaml)"
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        if [ -d "${_dir}" ]; then
            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
            # Skip only this test set. `exit 0` here used to silently abort the
            # remaining test sets and every later stage.
            continue
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/data/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        # Never start more jobs than there are utterances.
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode asr \
                ${_opts}

        # Merge the per-job outputs, sorted by utterance id.
        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 >"${_dir}/${f}"
            fi
        done
        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
        python utils/proce_text.py ${_data}/text ${_data}/text.proc
        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi
| | | |
# Stage 6: stage the artifacts ModelScope needs for fine-tuning and inference.
# Copies the training-set CMVN stats into the experiment dir and generates the
# ModelScope configuration there.
if [ "${stage}" -le 6 ] && [ "${stop_stage}" -ge 6 ]; then
    echo "stage 6: ModelScope Preparation"
    cp "${feats_dir}/data/${train_set}/cmvn/am.mvn" "${exp_dir}/exp/${model_dir}/am.mvn"
    # The token list holds one token per line, so its line count is the vocabulary size.
    vocab_size=$(wc -l < "${token_list}")
    # NOTE(review): --model_name is "conformer" although this recipe trains a
    # branchformer encoder; confirm what gen_modelscope_configuration.py expects.
    python utils/gen_modelscope_configuration.py \
        --am_model_name $inference_asr_model \
        --mode asr \
        --model_name conformer \
        --dataset aishell \
        --output_dir $exp_dir/exp/$model_dir \
        --vocab_size $vocab_size \
        --tag $tag
fi
| New file |
| | |
| | | ../transformer/utils |
| | |
| | | right_context: Number of frames in right context AFTER subsampling. |
| | | display_partial_hypotheses: Whether to display partial hypotheses. |
| | | """ |
| | | # assert check_argument_types() |
| | | |
| | | if batch_size > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | |
| | | from funasr.models.encoder.resnet34_encoder import ResNet34Diar |
| | | from funasr.models.encoder.rnn_encoder import RNNEncoder |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt |
| | | from funasr.models.encoder.branchformer_encoder import BranchformerEncoder |
| | | from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder |
| | | from funasr.models.encoder.transformer_encoder import TransformerEncoder |
| | | from funasr.models.frontend.default import DefaultFrontend |
| | | from funasr.models.frontend.default import MultiChannelFrontend |
| | |
| | | sanm=SANMEncoder, |
| | | sanm_chunk_opt=SANMEncoderChunkOpt, |
| | | data2vec_encoder=Data2VecEncoder, |
| | | branchformer=BranchformerEncoder, |
| | | e_branchformer=EBranchformerEncoder, |
| | | mfcca_enc=MFCCAEncoder, |
| | | chunk_conformer=ConformerChunkEncoder, |
| | | ), |
| | |
| | | mode=mode, |
| | | ) |
| | | |
| | | filter_conf = conf.get('filter_conf', {}) |
| | | filter_fn = partial(filter, **filter_conf) |
| | | dataset = FilterIterDataPipe(dataset, fn=filter_fn) |
| | | |
| | | if "text" in data_names: |
| | | vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict, 'bpe_tokenizer': bpe_tokenizer, 'hw_config': hw_config} |
| | | tokenize_fn = partial(tokenize, **vocab) |
| | | dataset = MapperIterDataPipe(dataset, fn=tokenize_fn) |
| | | |
| | | filter_conf = conf.get('filter_conf', {}) |
| | | filter_fn = partial(filter, **filter_conf) |
| | | dataset = FilterIterDataPipe(dataset, fn=filter_fn) |
| | | |
| | | if shuffle: |
| | | buffer_conf = conf.get('shuffle_conf', {}) |
| | | buffer_size = buffer_conf['shuffle_size'] |
| New file |
| | |
| | | # Copyright 2022 Yifan Peng (Carnegie Mellon University) |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | """Branchformer encoder definition. |
| | | |
| | | Reference: |
| | | Yifan Peng, Siddharth Dalmia, Ian Lane, and Shinji Watanabe, |
| | | “Branchformer: Parallel MLP-Attention Architectures to Capture |
| | | Local and Global Context for Speech Recognition and Understanding,” |
| | | in Proceedings of ICML, 2022. |
| | | |
| | | """ |
| | | |
| | | import logging |
| | | from typing import List, Optional, Tuple, Union |
| | | |
| | | import numpy |
| | | import torch |
| | | |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.cgmlp import ConvolutionalGatingMLP |
| | | from funasr.modules.fastformer import FastSelfAttention |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.modules.attention import ( # noqa: H301 |
| | | LegacyRelPositionMultiHeadedAttention, |
| | | MultiHeadedAttention, |
| | | RelPositionMultiHeadedAttention, |
| | | ) |
| | | from funasr.modules.embedding import ( # noqa: H301 |
| | | LegacyRelPositionalEncoding, |
| | | PositionalEncoding, |
| | | RelPositionalEncoding, |
| | | ScaledPositionalEncoding, |
| | | ) |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.repeat import repeat |
| | | from funasr.modules.subsampling import ( |
| | | Conv2dSubsampling, |
| | | Conv2dSubsampling2, |
| | | Conv2dSubsampling6, |
| | | Conv2dSubsampling8, |
| | | TooShortUttError, |
| | | check_short_utt, |
| | | ) |
| | | |
| | | |
class BranchformerEncoderLayer(torch.nn.Module):
    """Branchformer encoder layer module.

    The input goes through two parallel branches: a (self-)attention branch and
    a convolutional gating MLP (cgMLP) branch. Their outputs are merged, projected,
    added back to the input as a residual update and layer-normalized.

    Args:
        size (int): model dimension
        attn: standard self-attention or efficient attention, optional
        cgmlp: ConvolutionalGatingMLP, optional
        dropout_rate (float): dropout probability
        merge_method (str): concat, learned_ave, fixed_ave
        cgmlp_weight (float): weight of the cgmlp branch, between 0 and 1,
            used if merge_method is fixed_ave
        attn_branch_drop_rate (float): probability of dropping the attn branch,
            used if merge_method is learned_ave
        stochastic_depth_rate (float): stochastic depth probability
    """

    def __init__(
        self,
        size: int,
        attn: Optional[torch.nn.Module],
        cgmlp: Optional[torch.nn.Module],
        dropout_rate: float,
        merge_method: str,
        cgmlp_weight: float = 0.5,
        attn_branch_drop_rate: float = 0.0,
        stochastic_depth_rate: float = 0.0,
    ):
        super().__init__()
        assert (attn is not None) or (
            cgmlp is not None
        ), "At least one branch should be valid"

        self.size = size
        self.attn = attn
        self.cgmlp = cgmlp
        self.merge_method = merge_method
        self.cgmlp_weight = cgmlp_weight
        self.attn_branch_drop_rate = attn_branch_drop_rate
        self.stochastic_depth_rate = stochastic_depth_rate
        self.use_two_branches = (attn is not None) and (cgmlp is not None)

        if attn is not None:
            self.norm_mha = LayerNorm(size)  # for the MHA module
        if cgmlp is not None:
            self.norm_mlp = LayerNorm(size)  # for the MLP module
        self.norm_final = LayerNorm(size)  # for the final output of the block

        self.dropout = torch.nn.Dropout(dropout_rate)

        if self.use_two_branches:
            if merge_method == "concat":
                self.merge_proj = torch.nn.Linear(size + size, size)

            elif merge_method == "learned_ave":
                # attention-based pooling for two branches
                self.pooling_proj1 = torch.nn.Linear(size, 1)
                self.pooling_proj2 = torch.nn.Linear(size, 1)

                # linear projections for calculating merging weights
                self.weight_proj1 = torch.nn.Linear(size, 1)
                self.weight_proj2 = torch.nn.Linear(size, 1)

                # linear projection after weighted average
                self.merge_proj = torch.nn.Linear(size, size)

            elif merge_method == "fixed_ave":
                assert (
                    0.0 <= cgmlp_weight <= 1.0
                ), "cgmlp weight should be between 0.0 and 1.0"

                # remove the other branch if only one branch is used
                if cgmlp_weight == 0.0:
                    self.use_two_branches = False
                    self.cgmlp = None
                    self.norm_mlp = None
                elif cgmlp_weight == 1.0:
                    self.use_two_branches = False
                    self.attn = None
                    self.norm_mha = None

                # linear projection after weighted average
                self.merge_proj = torch.nn.Linear(size, size)

            else:
                raise ValueError(f"unknown merge method: {merge_method}")

        else:
            self.merge_proj = torch.nn.Identity()

    def _attention_pool(self, proj, x, mask):
        """Pool a branch output over time with learned attention scores.

        Args:
            proj (torch.nn.Module): projection giving one score per frame.
            x (torch.Tensor): branch output (#batch, time, size).
            mask (torch.Tensor or None): (#batch, 1, time); zeros mark padded frames.

        Returns:
            torch.Tensor: pooled representation (#batch, size).
        """
        score = proj(x).transpose(1, 2) / self.size**0.5  # (batch, 1, time)
        if mask is not None:
            # torch.finfo covers every floating dtype, including bfloat16,
            # which numpy.finfo cannot represent.
            min_value = torch.finfo(score.dtype).min
            score = score.masked_fill(mask.eq(0), min_value)
            score = torch.softmax(score, dim=-1).masked_fill(mask.eq(0), 0.0)
        else:
            score = torch.softmax(score, dim=-1)
        return torch.matmul(score, x).squeeze(1)  # (batch, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Not supported; must be None.

        Returns:
            Union[Tuple, torch.Tensor]: Output tensor (#batch, time, size), paired
                with the unchanged pos emb when the input carried one.
            torch.Tensor: The input mask, returned unchanged (#batch, 1, time).

        Raises:
            NotImplementedError: if ``cache`` is not None.
        """

        if cache is not None:
            raise NotImplementedError("cache is not None, which is not tested")

        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            # cache is always None here (rejected above): the layer is an identity.
            if pos_emb is not None:
                return (x, pos_emb), mask
            return x, mask

        # Two branches
        x1 = x
        x2 = x

        # Branch 1: multi-headed attention module
        if self.attn is not None:
            x1 = self.norm_mha(x1)

            if isinstance(self.attn, FastSelfAttention):
                x_att = self.attn(x1, mask)
            else:
                if pos_emb is not None:
                    x_att = self.attn(x1, x1, x1, pos_emb, mask)
                else:
                    x_att = self.attn(x1, x1, x1, mask)

            x1 = self.dropout(x_att)

        # Branch 2: convolutional gating mlp
        if self.cgmlp is not None:
            x2 = self.norm_mlp(x2)

            if pos_emb is not None:
                x2 = (x2, pos_emb)
            x2 = self.cgmlp(x2, mask)
            if isinstance(x2, tuple):
                x2 = x2[0]

            x2 = self.dropout(x2)

        # Merge two branches
        if self.use_two_branches:
            if self.merge_method == "concat":
                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(torch.cat([x1, x2], dim=-1))
                )
            elif self.merge_method == "learned_ave":
                if (
                    self.training
                    and self.attn_branch_drop_rate > 0
                    and torch.rand(1).item() < self.attn_branch_drop_rate
                ):
                    # Drop the attn branch
                    w1, w2 = 0.0, 1.0
                else:
                    pooled1 = self._attention_pool(self.pooling_proj1, x1, mask)
                    weight1 = self.weight_proj1(pooled1)  # (batch, 1)

                    pooled2 = self._attention_pool(self.pooling_proj2, x2, mask)
                    weight2 = self.weight_proj2(pooled2)  # (batch, 1)

                    # normalize weights of two branches
                    merge_weights = torch.softmax(
                        torch.cat([weight1, weight2], dim=-1), dim=-1
                    )  # (batch, 2)
                    merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
                        -1
                    )  # (batch, 2, 1, 1)
                    w1, w2 = merge_weights[:, 0], merge_weights[:, 1]  # (batch, 1, 1)

                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(w1 * x1 + w2 * x2)
                )
            elif self.merge_method == "fixed_ave":
                x = x + stoch_layer_coeff * self.dropout(
                    self.merge_proj(
                        (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
                    )
                )
            else:
                raise RuntimeError(f"unknown merge method: {self.merge_method}")
        else:
            if self.attn is None:
                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x2))
            elif self.cgmlp is None:
                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x1))
            else:
                # This should not happen
                raise RuntimeError("Both branches are not None, which is unexpected.")

        x = self.norm_final(x)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
| | | |
| | | |
class BranchformerEncoder(AbsEncoder):
    """Branchformer encoder module.

    An input embedding/subsampling layer, a stack of ``num_blocks``
    BranchformerEncoderLayer blocks, and a final layer norm.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        use_attn: bool = True,
        attention_heads: int = 4,
        attention_layer_type: str = "rel_selfattn",
        pos_enc_layer_type: str = "rel_pos",
        rel_pos_type: str = "latest",
        use_cgmlp: bool = True,
        cgmlp_linear_units: int = 2048,
        cgmlp_conv_kernel: int = 31,
        use_linear_after_conv: bool = False,
        gate_activation: str = "identity",
        merge_method: str = "concat",
        cgmlp_weight: Union[float, List[float]] = 0.5,
        attn_branch_drop_rate: Union[float, List[float]] = 0.0,
        num_blocks: int = 12,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        zero_triu: bool = False,
        padding_idx: int = -1,
        stochastic_depth_rate: Union[float, List[float]] = 0.0,
    ):
        """Build the encoder.

        Args:
            input_size: Input feature dimension (vocabulary size for "embed").
            output_size: Model dimension of every block.
            use_attn: Build the attention branch in each block.
            attention_heads: Number of attention heads.
            attention_layer_type: "selfattn", "rel_selfattn",
                "legacy_rel_selfattn" or "fast_selfattn".
            pos_enc_layer_type: "abs_pos", "scaled_abs_pos", "rel_pos" or
                "legacy_rel_pos".
            rel_pos_type: "latest", or "legacy" to map rel_pos / rel_selfattn to
                their legacy variants.
            use_cgmlp: Build the cgMLP branch in each block.
            cgmlp_linear_units, cgmlp_conv_kernel, use_linear_after_conv,
            gate_activation: Forwarded to ConvolutionalGatingMLP.
            merge_method: "concat", "learned_ave" or "fixed_ave".
            cgmlp_weight, attn_branch_drop_rate, stochastic_depth_rate: A number
                shared by every block, or a list with one value per block.
            num_blocks: Number of encoder blocks.
            dropout_rate, positional_dropout_rate, attention_dropout_rate:
                Dropout probabilities.
            input_layer: "linear", "conv2d", "conv2d2", "conv2d6", "conv2d8",
                "embed", a torch.nn.Module, or None (a Linear projection, or no
                layer at all when input_size == output_size).
            zero_triu: Forwarded to RelPositionMultiHeadedAttention.
            padding_idx: Padding index of the "embed" input layer.

        Raises:
            ValueError: On an unknown layer type or a per-block list whose
                length differs from num_blocks.
        """
        super().__init__()
        self._output_size = output_size

        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if attention_layer_type == "rel_selfattn":
                attention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert attention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert attention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert attention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            # f-string: plain `+` would itself raise TypeError for non-str values.
            raise ValueError(f"unknown input_layer: {input_layer}")

        if attention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif attention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif attention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        elif attention_layer_type == "fast_selfattn":
            assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
            encoder_selfattn_layer = FastSelfAttention
            encoder_selfattn_layer_args = (
                output_size,
                attention_heads,
                attention_dropout_rate,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)

        cgmlp_layer = ConvolutionalGatingMLP
        cgmlp_layer_args = (
            output_size,
            cgmlp_linear_units,
            cgmlp_conv_kernel,
            dropout_rate,
            use_linear_after_conv,
            gate_activation,
        )

        stochastic_depth_rate = self._per_block(
            stochastic_depth_rate, "stochastic_depth_rate", num_blocks
        )
        cgmlp_weight = self._per_block(cgmlp_weight, "cgmlp_weight", num_blocks)
        attn_branch_drop_rate = self._per_block(
            attn_branch_drop_rate, "attn_branch_drop_rate", num_blocks
        )

        self.encoders = repeat(
            num_blocks,
            lambda lnum: BranchformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args)
                if use_attn
                else None,
                cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
                dropout_rate,
                merge_method,
                cgmlp_weight[lnum],
                attn_branch_drop_rate[lnum],
                stochastic_depth_rate[lnum],
            ),
        )
        self.after_norm = LayerNorm(output_size)

    @staticmethod
    def _per_block(value, name, num_blocks):
        """Broadcast a scalar hyper-parameter to one value per block and check length.

        Ints are accepted as scalars too (e.g. ``cgmlp_weight: 1`` in YAML), where
        previously they reached ``len()`` and raised TypeError.
        """
        if isinstance(value, (int, float)):
            value = [float(value)] * num_blocks
        if len(value) != num_blocks:
            raise ValueError(
                f"Length of {name} ({len(value)}) should be equal to "
                f"num_blocks ({num_blocks})"
            )
        return value

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        Raises:
            TooShortUttError: if the input is too short for the conv subsampling.
        """

        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if isinstance(
            self.embed,
            (
                Conv2dSubsampling,
                Conv2dSubsampling2,
                Conv2dSubsampling6,
                Conv2dSubsampling8,
            ),
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        elif self.embed is not None:
            xs_pad = self.embed(xs_pad)

        xs_pad, masks = self.encoders(xs_pad, masks)

        # Relative/absolute positional encodings travel with the features as a tuple.
        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]

        xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
| New file |
| | |
| | | # Copyright 2022 Kwangyoun Kim (ASAPP inc.) |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | """E-Branchformer encoder definition. |
| | | Reference: |
| | | Kwangyoun Kim, Felix Wu, Yifan Peng, Jing Pan, |
| | | Prashant Sridhar, Kyu J. Han, Shinji Watanabe, |
| | | "E-Branchformer: Branchformer with Enhanced merging |
| | | for speech recognition," in SLT 2022. |
| | | """ |
| | | |
| | | import logging |
| | | from typing import List, Optional, Tuple |
| | | |
| | | import torch |
| | | |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.cgmlp import ConvolutionalGatingMLP |
| | | from funasr.modules.fastformer import FastSelfAttention |
| | | from funasr.modules.nets_utils import get_activation, make_pad_mask |
| | | from funasr.modules.attention import ( # noqa: H301 |
| | | LegacyRelPositionMultiHeadedAttention, |
| | | MultiHeadedAttention, |
| | | RelPositionMultiHeadedAttention, |
| | | ) |
| | | from funasr.modules.embedding import ( # noqa: H301 |
| | | LegacyRelPositionalEncoding, |
| | | PositionalEncoding, |
| | | RelPositionalEncoding, |
| | | ScaledPositionalEncoding, |
| | | ) |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.positionwise_feed_forward import ( |
| | | PositionwiseFeedForward, |
| | | ) |
| | | from funasr.modules.repeat import repeat |
| | | from funasr.modules.subsampling import ( |
| | | Conv2dSubsampling, |
| | | Conv2dSubsampling2, |
| | | Conv2dSubsampling6, |
| | | Conv2dSubsampling8, |
| | | TooShortUttError, |
| | | check_short_utt, |
| | | ) |
| | | |
| | | |
class EBranchformerEncoderLayer(torch.nn.Module):
    """E-Branchformer encoder layer module.

    The layer runs these steps in order:
    1. An optional macaron feed-forward module.
    2. Two parallel branches (attention and cgMLP).
    3. A merge step that concatenates the branches, adds a depth-wise conv of the
       concatenation, and projects back to ``size``.
    4. An optional feed-forward module.
    5. A final layer norm.

    Args:
        size (int): model dimension
        attn: standard self-attention or efficient attention
        cgmlp: ConvolutionalGatingMLP
        feed_forward: feed-forward module, optional
        feed_forward_macaron: macaron-style feed-forward module, optional
        dropout_rate (float): dropout probability
        merge_conv_kernel (int): kernel size of the depth-wise conv in merge
            module; must be odd so the conv preserves the sequence length

    Raises:
        ValueError: if ``merge_conv_kernel`` is even.
    """

    def __init__(
        self,
        size: int,
        attn: torch.nn.Module,
        cgmlp: torch.nn.Module,
        feed_forward: Optional[torch.nn.Module],
        feed_forward_macaron: Optional[torch.nn.Module],
        dropout_rate: float,
        merge_conv_kernel: int = 3,
    ):
        super().__init__()

        # With padding (k - 1) // 2 and stride 1, an even kernel yields T - 1
        # frames, which cannot be added to the T-frame concatenation in forward.
        if merge_conv_kernel % 2 == 0:
            raise ValueError(
                f"merge_conv_kernel must be odd, got {merge_conv_kernel}"
            )

        self.size = size
        self.attn = attn
        self.cgmlp = cgmlp

        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        # Half-step residuals when both feed-forward modules are present (macaron style).
        self.ff_scale = 1.0
        if self.feed_forward is not None:
            self.norm_ff = LayerNorm(size)
        if self.feed_forward_macaron is not None:
            self.ff_scale = 0.5
            self.norm_ff_macaron = LayerNorm(size)

        self.norm_mha = LayerNorm(size)  # for the MHA module
        self.norm_mlp = LayerNorm(size)  # for the MLP module
        self.norm_final = LayerNorm(size)  # for the final output of the block

        self.dropout = torch.nn.Dropout(dropout_rate)

        self.depthwise_conv_fusion = torch.nn.Conv1d(
            size + size,
            size + size,
            kernel_size=merge_conv_kernel,
            stride=1,
            padding=(merge_conv_kernel - 1) // 2,
            groups=size + size,
            bias=True,
        )
        self.merge_proj = torch.nn.Linear(size + size, size)

    def forward(self, x_input, mask, cache=None):
        """Compute encoded features.

        Args:
            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
                - w/o pos emb: Tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
            cache (torch.Tensor): Not supported; must be None.

        Returns:
            Union[Tuple, torch.Tensor]: Output tensor (#batch, time, size), paired
                with the unchanged pos emb when the input carried one.
            torch.Tensor: The input mask, returned unchanged (#batch, 1, time).

        Raises:
            NotImplementedError: if ``cache`` is not None.
        """

        if cache is not None:
            raise NotImplementedError("cache is not None, which is not tested")

        if isinstance(x_input, tuple):
            x, pos_emb = x_input[0], x_input[1]
        else:
            x, pos_emb = x_input, None

        if self.feed_forward_macaron is not None:
            residual = x
            x = self.norm_ff_macaron(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))

        # Two branches
        x1 = x
        x2 = x

        # Branch 1: multi-headed attention module
        x1 = self.norm_mha(x1)

        if isinstance(self.attn, FastSelfAttention):
            x_att = self.attn(x1, mask)
        else:
            if pos_emb is not None:
                x_att = self.attn(x1, x1, x1, pos_emb, mask)
            else:
                x_att = self.attn(x1, x1, x1, mask)

        x1 = self.dropout(x_att)

        # Branch 2: convolutional gating mlp
        x2 = self.norm_mlp(x2)

        if pos_emb is not None:
            x2 = (x2, pos_emb)
        x2 = self.cgmlp(x2, mask)
        if isinstance(x2, tuple):
            x2 = x2[0]

        x2 = self.dropout(x2)

        # Merge two branches: concat + depth-wise conv over time, then project.
        x_concat = torch.cat([x1, x2], dim=-1)
        x_tmp = x_concat.transpose(1, 2)  # Conv1d expects (batch, channels, time)
        x_tmp = self.depthwise_conv_fusion(x_tmp)
        x_tmp = x_tmp.transpose(1, 2)
        x = x + self.dropout(self.merge_proj(x_concat + x_tmp))

        if self.feed_forward is not None:
            # feed forward module
            residual = x
            x = self.norm_ff(x)
            x = residual + self.ff_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)

        if pos_emb is not None:
            return (x, pos_emb), mask

        return x, mask
| | | |
| | | |
class EBranchformerEncoder(AbsEncoder):
    """E-Branchformer encoder module.

    The input is mapped to ``output_size`` by the frontend selected with
    ``input_layer``, passed through ``num_blocks`` ``EBranchformerEncoderLayer``
    blocks (each built from a self-attention module, a convolutional gating
    MLP and optional feed-forward modules), and finally layer-normalized.

    Args:
        input_size: Input feature dimension (number of embeddings for the
            ``"embed"`` frontend).
        output_size: Model dimension.
        attention_heads: Number of attention heads.
        attention_layer_type: ``"selfattn"``, ``"rel_selfattn"``,
            ``"legacy_rel_selfattn"`` or ``"fast_selfattn"``.
        pos_enc_layer_type: ``"abs_pos"``, ``"scaled_abs_pos"``, ``"rel_pos"``
            or ``"legacy_rel_pos"``.
        rel_pos_type: ``"latest"`` or ``"legacy"``; ``"legacy"`` rewrites
            ``rel_pos``/``rel_selfattn`` to their legacy variants.
        cgmlp_linear_units: Hidden size of the cgMLP branch.
        cgmlp_conv_kernel: Convolution kernel size of the cgMLP branch.
        use_linear_after_conv: Add a linear layer after the cgMLP convolution.
        gate_activation: Activation applied to the cgMLP gate.
        num_blocks: Number of encoder layers.
        dropout_rate: Dropout rate.
        positional_dropout_rate: Dropout rate of the positional encoding.
        attention_dropout_rate: Dropout rate inside attention.
        input_layer: ``"linear"``, ``"conv2d"``, ``"conv2d2"``, ``"conv2d6"``,
            ``"conv2d8"``, ``"embed"``, a ``torch.nn.Module``, or ``None``.
        zero_triu: Passed to ``RelPositionMultiHeadedAttention``.
        padding_idx: Padding index of the ``"embed"`` frontend.
        layer_drop_rate: Probability of skipping each layer during training.
        max_pos_emb_len: Maximum length of the positional encoding.
        use_ffn: Add a feed-forward module to every layer.
        macaron_ffn: Also add a second (macaron) feed-forward module; only
            effective when ``use_ffn`` is True.
        ffn_activation_type: Activation of the feed-forward modules.
        linear_units: Hidden size of the feed-forward modules.
        positionwise_layer_type: ``"linear"`` or ``None`` (no feed-forward).
        merge_conv_kernel: Kernel size passed to each layer for merging the
            two branches.
        interctc_layer_idx: 1-based layer numbers whose outputs are returned
            for intermediate CTC.
        interctc_use_conditioning: Feed intermediate CTC posteriors back into
            the encoder through ``self.conditioning_layer`` (which must be set
            by the caller; it is ``None`` here).

    Raises:
        ValueError: On unknown option values, or when ``use_ffn`` is True but
            no feed-forward layer type is configured.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        attention_layer_type: str = "rel_selfattn",
        pos_enc_layer_type: str = "rel_pos",
        rel_pos_type: str = "latest",
        cgmlp_linear_units: int = 2048,
        cgmlp_conv_kernel: int = 31,
        use_linear_after_conv: bool = False,
        gate_activation: str = "identity",
        num_blocks: int = 12,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        zero_triu: bool = False,
        padding_idx: int = -1,
        layer_drop_rate: float = 0.0,
        max_pos_emb_len: int = 5000,
        use_ffn: bool = False,
        macaron_ffn: bool = False,
        ffn_activation_type: str = "swish",
        linear_units: int = 2048,
        positionwise_layer_type: str = "linear",
        merge_conv_kernel: int = 3,
        interctc_layer_idx=None,
        interctc_use_conditioning: bool = False,
    ):
        super().__init__()
        self._output_size = output_size

        # "legacy" maps the relative-position choices onto their legacy
        # variants; "latest" forbids the legacy variants outright.
        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if attention_layer_type == "rel_selfattn":
                attention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert attention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError(f"unknown rel_pos_type: {rel_pos_type}")

        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert attention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert attention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError(f"unknown pos_enc_layer: {pos_enc_layer_type}")

        # Except for input_layer=None, every frontend ends with pos_enc_class.
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d2":
            self.embed = Conv2dSubsampling2(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
            )
        elif input_layer is None:
            if input_size == output_size:
                self.embed = None
            else:
                self.embed = torch.nn.Linear(input_size, output_size)
        else:
            # f-string so a non-str value still yields ValueError, not TypeError.
            raise ValueError(f"unknown input_layer: {input_layer}")

        activation = get_activation(ffn_activation_type)
        # Initialized so that an unconfigured feed-forward is detected below
        # instead of surfacing as a NameError inside the layer factory.
        positionwise_layer = None
        positionwise_layer_args = ()
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type is None:
            logging.warning("no macaron ffn")
        else:
            raise ValueError("Support only linear.")
        if use_ffn and positionwise_layer is None:
            raise ValueError(
                "use_ffn=True requires positionwise_layer_type='linear', "
                f"got {positionwise_layer_type!r}"
            )

        if attention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif attention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif attention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        elif attention_layer_type == "fast_selfattn":
            assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
            encoder_selfattn_layer = FastSelfAttention
            # FastSelfAttention takes (size, heads, dropout), unlike the others.
            encoder_selfattn_layer_args = (
                output_size,
                attention_heads,
                attention_dropout_rate,
            )
        else:
            raise ValueError(f"unknown encoder_attn_layer: {attention_layer_type}")

        cgmlp_layer = ConvolutionalGatingMLP
        cgmlp_layer_args = (
            output_size,
            cgmlp_linear_units,
            cgmlp_conv_kernel,
            dropout_rate,
            use_linear_after_conv,
            gate_activation,
        )

        self.encoders = repeat(
            num_blocks,
            lambda lnum: EBranchformerEncoderLayer(
                output_size,
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                cgmlp_layer(*cgmlp_layer_args),
                positionwise_layer(*positionwise_layer_args) if use_ffn else None,
                positionwise_layer(*positionwise_layer_args)
                if use_ffn and macaron_ffn
                else None,
                dropout_rate,
                merge_conv_kernel,
            ),
            layer_drop_rate,
        )
        self.after_norm = LayerNorm(output_size)

        if interctc_layer_idx is None:
            interctc_layer_idx = []
        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        # Left as None here; must be assigned by the owner before forward() is
        # called with interctc_use_conditioning=True.
        self.conditioning_layer = None

    def output_size(self) -> int:
        """Return the model (output) dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        max_layer: int = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
            ctc (CTC): Intermediate CTC module; only used when
                ``interctc_use_conditioning`` is True.
            max_layer (int): When given with ``0 <= max_layer < num_blocks``
                and no InterCTC layers are configured, only layers
                ``0..max_layer`` (inclusive) are run.

        Returns:
            torch.Tensor: Output tensor (#batch, L', output_size), or the tuple
                ``(output, intermediate_outs)`` when InterCTC layers produced
                outputs; ``intermediate_outs`` is a list of
                ``(layer_number, tensor)`` pairs with 1-based layer numbers.
            torch.Tensor: Output length (#batch).
            None: Placeholder for states (not used now).

        Raises:
            TooShortUttError: If the input is too short for a conv2d
                subsampling frontend.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)

        if isinstance(
            self.embed,
            (
                Conv2dSubsampling,
                Conv2dSubsampling2,
                Conv2dSubsampling6,
                Conv2dSubsampling8,
            ),
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            # Subsampling frontends also shorten the mask.
            xs_pad, masks = self.embed(xs_pad, masks)
        elif self.embed is not None:
            xs_pad = self.embed(xs_pad)

        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            if max_layer is not None and 0 <= max_layer < len(self.encoders):
                for layer_idx, encoder_layer in enumerate(self.encoders):
                    xs_pad, masks = encoder_layer(xs_pad, masks)
                    if layer_idx >= max_layer:
                        break
            else:
                xs_pad, masks = self.encoders(xs_pad, masks)
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                xs_pad, masks = encoder_layer(xs_pad, masks)

                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad

                    # With relative positional encoding xs_pad is (x, pos_emb).
                    if isinstance(encoder_out, tuple):
                        encoder_out = encoder_out[0]

                    intermediate_outs.append((layer_idx + 1, encoder_out))

                    if self.interctc_use_conditioning:
                        ctc_out = ctc.softmax(encoder_out)

                        if isinstance(xs_pad, tuple):
                            xs_pad = list(xs_pad)
                            xs_pad[0] = xs_pad[0] + self.conditioning_layer(ctc_out)
                            xs_pad = tuple(xs_pad)
                        else:
                            xs_pad = xs_pad + self.conditioning_layer(ctc_out)

        if isinstance(xs_pad, tuple):
            xs_pad = xs_pad[0]

        xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
| New file |
| | |
| | | """MLP with convolutional gating (cgMLP) definition. |
| | | |
| | | References: |
| | | https://openreview.net/forum?id=RA-zVvZLYIy |
| | | https://arxiv.org/abs/2105.08050 |
| | | |
| | | """ |
| | | |
| | | import torch |
| | | |
| | | from funasr.modules.nets_utils import get_activation |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | |
| | | |
class ConvolutionalSpatialGatingUnit(torch.nn.Module):
    """Convolutional Spatial Gating Unit (CSGU).

    Splits the last dimension of the input into two halves ``(x_r, x_g)``.
    ``x_g`` goes through LayerNorm, a depthwise 1-D convolution over time and
    an optional linear layer; the result gates ``x_r`` element-wise.

    Args:
        size: Input feature dimension ``D``. Must be even, because the input is
            split into two halves of ``D / 2`` channels.
        kernel_size: Depthwise convolution kernel size. Must be odd, so that
            the ``(kernel_size - 1) // 2`` padding keeps the time length.
        dropout_rate: Dropout rate applied to the gated output.
        use_linear_after_conv: Add a ``D/2 -> D/2`` linear layer after the conv.
        gate_activation: ``"identity"`` or a name accepted by ``get_activation``.

    Raises:
        ValueError: If ``size`` is odd or ``kernel_size`` is even.
    """

    def __init__(
        self,
        size: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()

        # An odd size makes x.chunk(2) return unequal halves, which then fail
        # (or silently broadcast) in the gating product.
        if size % 2 != 0:
            raise ValueError(
                f"size must be even to split into two halves, got {size}"
            )
        # An even kernel with (k - 1) // 2 padding shortens T by one frame.
        if kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be odd to preserve the time length, got {kernel_size}"
            )

        n_channels = size // 2  # split input channels
        self.norm = LayerNorm(n_channels)
        self.conv = torch.nn.Conv1d(
            n_channels,
            n_channels,
            kernel_size,
            1,
            (kernel_size - 1) // 2,
            groups=n_channels,
        )
        if use_linear_after_conv:
            self.linear = torch.nn.Linear(n_channels, n_channels)
        else:
            self.linear = None

        if gate_activation == "identity":
            self.act = torch.nn.Identity()
        else:
            self.act = get_activation(gate_activation)

        self.dropout = torch.nn.Dropout(dropout_rate)

    def espnet_initialization_fn(self):
        """Init conv/linear weights ~ N(0, 1e-6) and biases to one.

        Near-zero weights with unit bias make the gate start close to 1.
        """
        torch.nn.init.normal_(self.conv.weight, std=1e-6)
        torch.nn.init.ones_(self.conv.bias)
        if self.linear is not None:
            torch.nn.init.normal_(self.linear.weight, std=1e-6)
            torch.nn.init.ones_(self.linear.bias)

    def forward(self, x, gate_add=None):
        """Forward method

        Args:
            x (torch.Tensor): (N, T, D)
            gate_add (torch.Tensor): (N, T, D/2), added to the gate before
                the activation when given.

        Returns:
            out (torch.Tensor): (N, T, D/2)
        """
        x_r, x_g = x.chunk(2, dim=-1)

        x_g = self.norm(x_g)  # (N, T, D/2)
        # Conv1d expects (N, C, T): convolve over time, then restore layout.
        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
        if self.linear is not None:
            x_g = self.linear(x_g)

        if gate_add is not None:
            x_g = x_g + gate_add

        x_g = self.act(x_g)
        out = x_r * x_g  # (N, T, D/2)
        out = self.dropout(out)
        return out
| | | |
| | | |
class ConvolutionalGatingMLP(torch.nn.Module):
    """Convolutional Gating MLP (cgMLP).

    Computes ``channel_proj2(csgu(channel_proj1(x)))``:
    ``channel_proj1`` is ``Linear(size, linear_units)`` followed by GELU,
    ``csgu`` (a ConvolutionalSpatialGatingUnit) halves the feature dimension,
    and ``channel_proj2`` maps ``linear_units // 2`` back to ``size``.

    Args:
        size: Input/output feature dimension.
        linear_units: Hidden dimension of the first projection.
        kernel_size: Convolution kernel size of the gating unit.
        dropout_rate: Dropout rate of the gating unit.
        use_linear_after_conv: Forwarded to the gating unit.
        gate_activation: Forwarded to the gating unit.
    """

    def __init__(
        self,
        size: int,
        linear_units: int,
        kernel_size: int,
        dropout_rate: float,
        use_linear_after_conv: bool,
        gate_activation: str,
    ):
        super().__init__()

        # Submodules are built in this exact order so parameter initialization
        # and state_dict keys stay stable.
        expand = torch.nn.Linear(size, linear_units)
        self.channel_proj1 = torch.nn.Sequential(expand, torch.nn.GELU())
        self.csgu = ConvolutionalSpatialGatingUnit(
            size=linear_units,
            kernel_size=kernel_size,
            dropout_rate=dropout_rate,
            use_linear_after_conv=use_linear_after_conv,
            gate_activation=gate_activation,
        )
        self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)

    def forward(self, x, mask):
        """Apply the cgMLP.

        Args:
            x: A tensor ``(B, T, size)`` or a pair ``(tensor, pos_emb)``.
            mask: Unused; accepted for interface compatibility.

        Returns:
            The projected tensor, paired with ``pos_emb`` when one was given
            (and is not None).
        """
        pos_emb = None
        if isinstance(x, tuple):
            x, pos_emb = x

        hidden = self.channel_proj1(x)  # size -> linear_units
        hidden = self.csgu(hidden)  # linear_units -> linear_units / 2
        hidden = self.channel_proj2(hidden)  # linear_units / 2 -> size

        return hidden if pos_emb is None else (hidden, pos_emb)
| New file |
| | |
| | | """Fastformer attention definition. |
| | | |
| | | Reference: |
| | | Wu et al., "Fastformer: Additive Attention Can Be All You Need" |
| | | https://arxiv.org/abs/2108.09084 |
| | | https://github.com/wuch15/Fastformer |
| | | |
| | | """ |
| | | |
| | | import numpy |
| | | import torch |
| | | |
| | | |
class FastSelfAttention(torch.nn.Module):
    """Fast self-attention used in Fastformer.

    A global query is pooled over time with additive attention, multiplied
    into the keys, pooled again into a global key, and multiplied with the
    per-frame query (which doubles as the value due to parameter sharing).

    Args:
        size: Model dimension; must be divisible by ``attention_heads``.
        attention_heads: Number of heads.
        dropout_rate: Dropout applied to the pooled query, pooled key and
            transformed output.

    Raises:
        ValueError: If ``size`` is not a multiple of ``attention_heads``.
    """

    def __init__(
        self,
        size,
        attention_heads,
        dropout_rate,
    ):
        super().__init__()
        if size % attention_heads != 0:
            raise ValueError(
                f"Hidden size ({size}) is not an integer multiple "
                f"of attention heads ({attention_heads})"
            )
        self.attention_head_size = size // attention_heads
        self.num_attention_heads = attention_heads

        self.query = torch.nn.Linear(size, size)
        self.query_att = torch.nn.Linear(size, attention_heads)
        self.key = torch.nn.Linear(size, size)
        self.key_att = torch.nn.Linear(size, attention_heads)
        self.transform = torch.nn.Linear(size, size)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def espnet_initialization_fn(self):
        """Re-initialize every Linear submodule via ``init_weights``."""
        self.apply(self.init_weights)

    def init_weights(self, module):
        """Linear weights ~ N(0, 0.02); Linear biases set to zero."""
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def transpose_for_scores(self, x):
        """Reshape and transpose to compute scores.

        Args:
            x: (batch, time, size = n_heads * attn_dim)

        Returns:
            (batch, n_heads, time, attn_dim)
        """
        new_x_shape = x.shape[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        return x.reshape(*new_x_shape).transpose(1, 2)

    @staticmethod
    def _masked_softmax(scores, mask):
        """Softmax over the last axis; positions where ``mask`` is True get 0.

        ``torch.finfo`` provides the fill value for every floating dtype,
        including bfloat16, which NumPy cannot represent.
        """
        if mask is None:
            return torch.softmax(scores, dim=-1)
        scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
        return torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)

    def forward(self, xs_pad, mask):
        """Forward method.

        Args:
            xs_pad: (batch, time, size = n_heads * attn_dim)
            mask: (batch, 1, time), nonpadding is 1, padding is 0

        Returns:
            torch.Tensor: (batch, time, size)
        """
        seq_len = xs_pad.shape[1]
        mixed_query_layer = self.query(xs_pad)  # (batch, time, size)
        mixed_key_layer = self.key(xs_pad)  # (batch, time, size)

        if mask is not None:
            mask = mask.eq(0)  # padding is 1, nonpadding is 0

        scale = self.attention_head_size**0.5

        # (batch, n_heads, time)
        query_for_score = self.query_att(mixed_query_layer).transpose(1, 2) / scale
        # (batch, n_heads, 1, time)
        query_weight = self._masked_softmax(query_for_score, mask).unsqueeze(2)
        # (batch, n_heads, time, attn_dim)
        query_layer = self.transpose_for_scores(mixed_query_layer)

        pooled_query = (
            torch.matmul(query_weight, query_layer)
            .transpose(1, 2)
            .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
        )  # (batch, 1, size = n_heads * attn_dim)
        pooled_query = self.dropout(pooled_query)
        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)

        # (batch, time, size)
        mixed_query_key_layer = mixed_key_layer * pooled_query_repeat

        # (batch, n_heads, time)
        query_key_score = (self.key_att(mixed_query_key_layer) / scale).transpose(1, 2)
        # (batch, n_heads, 1, time)
        query_key_weight = self._masked_softmax(query_key_score, mask).unsqueeze(2)
        # (batch, n_heads, time, attn_dim)
        key_layer = self.transpose_for_scores(mixed_query_key_layer)
        # (batch, n_heads, 1, attn_dim)
        pooled_key = torch.matmul(query_key_weight, key_layer)
        pooled_key = self.dropout(pooled_key)

        # NOTE: value = query, due to param sharing
        # (batch, time, n_heads, attn_dim)
        weighted_value = (pooled_key * query_layer).transpose(1, 2)
        weighted_value = weighted_value.reshape(
            weighted_value.shape[:-2]
            + (self.num_attention_heads * self.attention_head_size,)
        )  # (batch, time, size)
        weighted_value = (
            self.dropout(self.transform(weighted_value)) + mixed_query_layer
        )

        return weighted_value
| | |
class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential with optional layer drop.

    Each child is called with the unpacked output of the previous child, so
    every child must return an iterable (typically a tuple) of the arguments
    for the next one.
    """

    def __init__(self, *args, layer_drop_rate=0.0):
        """Initialize MultiSequential with layer_drop.

        Args:
            layer_drop_rate (float): Probability of dropping out each fn (layer).

        """
        super(MultiSequential, self).__init__(*args)
        self.layer_drop_rate = layer_drop_rate

    def forward(self, *args):
        """Apply the children in order, each exactly once.

        In training mode a child is skipped when its uniform draw is below
        ``layer_drop_rate``; in eval mode every child is applied.
        """
        # One draw per layer, taken in both modes (keeps the RNG stream the
        # same as before); the draws are only consulted in training.
        _probs = torch.empty(len(self)).uniform_()
        for idx, m in enumerate(self):
            if not self.training or (_probs[idx] >= self.layer_drop_rate):
                args = m(*args)
        return args
| | | |
| | | |
def repeat(N, fn, layer_drop_rate=0.0):
    """Repeat module N times.

    Args:
        N (int): Number of repeat time.
        fn (Callable): Function to generate module; called with the 0-based
            index of the layer being created.
        layer_drop_rate (float): Probability of dropping out each fn (layer).

    Returns:
        MultiSequential: Repeated model instance.

    """
    return MultiSequential(*[fn(n) for n in range(N)], layer_drop_rate=layer_drop_rate)
| | | |
| | | |
| | | class MultiBlocks(torch.nn.Module): |
| | |
| | | Return: |
| | | model: ASR BAT model. |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |