From 2a80f66ffec4948d1e1cee8aa6c5a5743c422b73 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 11 七月 2023 00:37:39 +0800
Subject: [PATCH] update
---
funasr/build_utils/build_asr_model.py | 2
egs/aishell/e_branchformer/conf/train_asr_e_branchformer.yaml | 101 +++++
egs/aishell/e_branchformer/run.sh | 225 +++++++++++++
egs/aishell/e_branchformer/conf/decode_asr_transformer.yaml | 6
egs/aishell/e_branchformer/path.sh | 5
funasr/models/encoder/e_branchformer_encoder.py | 467 +++++++++++++++++++++++++++
egs/aishell/e_branchformer/local/aishell_data_prep.sh | 66 +++
egs/aishell/e_branchformer/local/download_and_untar.sh | 105 ++++++
8 files changed, 977 insertions(+), 0 deletions(-)
diff --git a/egs/aishell/e_branchformer/conf/decode_asr_transformer.yaml b/egs/aishell/e_branchformer/conf/decode_asr_transformer.yaml
new file mode 100644
index 0000000..e87a293
--- /dev/null
+++ b/egs/aishell/e_branchformer/conf/decode_asr_transformer.yaml
@@ -0,0 +1,6 @@
+beam_size: 10
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.4
+lm_weight: 0.0
diff --git a/egs/aishell/e_branchformer/conf/train_asr_e_branchformer.yaml b/egs/aishell/e_branchformer/conf/train_asr_e_branchformer.yaml
new file mode 100644
index 0000000..a30e9a2
--- /dev/null
+++ b/egs/aishell/e_branchformer/conf/train_asr_e_branchformer.yaml
@@ -0,0 +1,101 @@
+# network architecture
+# encoder related
+encoder: e_branchformer
+encoder_conf:
+ output_size: 256
+ attention_heads: 4
+ attention_layer_type: rel_selfattn
+ pos_enc_layer_type: rel_pos
+ rel_pos_type: latest
+ cgmlp_linear_units: 1024
+ cgmlp_conv_kernel: 31
+ use_linear_after_conv: false
+ gate_activation: identity
+ num_blocks: 12
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.1
+ input_layer: conv2d
+ layer_drop_rate: 0.0
+ linear_units: 1024
+ positionwise_layer_type: linear
+ use_ffn: true
+ macaron_ffn: true
+ merge_conv_kernel: 31
+
+# decoder related
+decoder: transformer
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 6
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.
+ src_attention_dropout_rate: 0.
+
+# frontend related
+frontend: wav_frontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 1
+ lfr_n: 1
+
+# hybrid CTC/attention
+model_conf:
+ ctc_weight: 0.3
+ lsm_weight: 0.1 # label smoothing option
+ length_normalized_loss: false
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 180
+best_model_criterion:
+- - valid
+ - acc
+ - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.001
+ weight_decay: 0.000001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 35000
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 27
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_ratio_range:
+ - 0.
+ - 0.05
+ num_time_mask: 10
+
+dataset_conf:
+ data_names: speech,text
+ data_types: sound,text
+ shuffle: True
+ shuffle_conf:
+ shuffle_size: 2048
+ sort_size: 500
+ batch_conf:
+ batch_type: token
+ batch_size: 10000
+ num_workers: 8
+
+log_interval: 50
+normalize: None
\ No newline at end of file
diff --git a/egs/aishell/e_branchformer/local/aishell_data_prep.sh b/egs/aishell/e_branchformer/local/aishell_data_prep.sh
new file mode 100755
index 0000000..83f489b
--- /dev/null
+++ b/egs/aishell/e_branchformer/local/aishell_data_prep.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+
+# Copyright 2017 Xingyu Na
+# Apache 2.0
+
+#. ./path.sh || exit 1;
+
+if [ $# != 3 ]; then
+ echo "Usage: $0 <audio-path> <text-path> <output-path>"
+ echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
+ exit 1;
+fi
+
+aishell_audio_dir=$1
+aishell_text=$2/aishell_transcript_v0.8.txt
+output_dir=$3
+
+train_dir=$output_dir/data/local/train
+dev_dir=$output_dir/data/local/dev
+test_dir=$output_dir/data/local/test
+tmp_dir=$output_dir/data/local/tmp
+
+mkdir -p $train_dir
+mkdir -p $dev_dir
+mkdir -p $test_dir
+mkdir -p $tmp_dir
+
+# data directory check
+if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
+ echo "Error: $0 requires two directory arguments"
+ exit 1;
+fi
+
+# find wav audio file for train, dev and test resp.
+find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
+n=`cat $tmp_dir/wav.flist | wc -l`
+[ $n -ne 141925 ] && \
+ echo Warning: expected 141925 data data files, found $n
+
+grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
+grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
+grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
+
+rm -r $tmp_dir
+
+# Transcriptions preparation
+for dir in $train_dir $dev_dir $test_dir; do
+ echo Preparing $dir transcriptions
+ sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
+ paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
+ utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
+ awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
+ utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
+ sort -u $dir/transcripts.txt > $dir/text
+done
+
+mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
+
+for f in wav.scp text; do
+ cp $train_dir/$f $output_dir/data/train/$f || exit 1;
+ cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
+ cp $test_dir/$f $output_dir/data/test/$f || exit 1;
+done
+
+echo "$0: AISHELL data preparation succeeded"
+exit 0;
diff --git a/egs/aishell/e_branchformer/local/download_and_untar.sh b/egs/aishell/e_branchformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/e_branchformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
+# 2017 Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+ remove_archive=true
+ shift
+fi
+
+if [ $# -ne 3 ]; then
+ echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+ echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+ echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+ echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+ echo "$0: no such directory $data"
+ exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+ if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+ echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+ exit 1;
+fi
+
+if [ -z "$url" ]; then
+ echo "$0: empty URL base."
+ exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+ echo "$0: data part $part was already successfully extracted, nothing to do."
+ exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+ size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+ size_ok=false
+ for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+ if ! $size_ok; then
+ echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+ echo "does not equal the size of one of the archives."
+ rm $data/$part.tgz
+ else
+ echo "$data/$part.tgz exists and appears to be complete."
+ fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+ if ! command -v wget >/dev/null; then
+ echo "$0: wget is not installed."
+ exit 1;
+ fi
+ full_url=$url/$part.tgz
+ echo "$0: downloading data from $full_url. This may take some time, please be patient."
+
+ cd $data || exit 1
+ if ! wget --no-check-certificate $full_url; then
+ echo "$0: error executing wget $full_url"
+ exit 1;
+ fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+ echo "$0: error un-tarring archive $data/$part.tgz"
+ exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+ cd $data/$part/wav || exit 1
+ for wav in ./*.tar.gz; do
+ echo "Extracting wav from $wav"
+ tar -zxf $wav && rm $wav
+ done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+ echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+ rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/e_branchformer/path.sh b/egs/aishell/e_branchformer/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/egs/aishell/e_branchformer/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/egs/aishell/e_branchformer/run.sh b/egs/aishell/e_branchformer/run.sh
new file mode 100755
index 0000000..bcba2d7
--- /dev/null
+++ b/egs/aishell/e_branchformer/run.sh
@@ -0,0 +1,225 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1,2,3"
+gpu_num=4
+count=1
+gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="../DATA" #feature output dictionary
+exp_dir="."
+lang=zh
+token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=0
+stop_stage=5
+
+# feature configuration
+feats_dim=80
+nj=64
+
+# data
+raw_data=../raw_data
+data_url=www.openslr.org/resources/33
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev
+test_sets="dev test"
+
+asr_config=conf/train_asr_e_branchformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_asr_transformer.yaml
+inference_asr_model=valid.acc.ave_10best.pb
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+ inference_nj=$[${ngpu}*${njob}]
+ _ngpu=1
+else
+ inference_nj=$njob
+ _ngpu=0
+fi
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+ echo "stage -1: Data Download"
+ local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+ local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
+ for x in train dev test; do
+ cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
+ > ${feats_dir}/data/${x}/text
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
+ done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Feature and CMVN Generation"
+ utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
+fi
+
+token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: Dictionary Preparation"
+ mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/
+
+ echo "make a dictionary"
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
+ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+ echo "<unk>" >> ${token_list}
+fi
+
+# LM Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: ASR Training"
+ mkdir -p ${exp_dir}/exp/${model_dir}
+ mkdir -p ${exp_dir}/exp/${model_dir}/log
+ INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $gpu_num; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ train.py \
+ --task_name asr \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --token_type $token_type \
+ --token_list $token_list \
+ --data_dir ${feats_dir}/data \
+ --train_set ${train_set} \
+ --valid_set ${valid_set} \
+ --data_file_names "wav.scp,text" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
+ --speed_perturb ${speed_perturb} \
+ --resume true \
+ --output_dir ${exp_dir}/exp/${model_dir} \
+ --config $asr_config \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+ echo "stage 5: Inference"
+ for dset in ${test_sets}; do
+ asr_exp=${exp_dir}/exp/${model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/data/${dset}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --gpuid_list ${gpuid_list} \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${asr_exp}"/config.yaml \
+ --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode asr \
+ ${_opts}
+
+ for f in token token_int score text; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+ python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+ python utils/proce_text.py ${_data}/text ${_data}/text.proc
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+ done
+fi
+
+# Prepare files for ModelScope fine-tuning and inference
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+ echo "stage 6: ModelScope Preparation"
+ cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn
+ vocab_size=$(cat ${token_list} | wc -l)
+ python utils/gen_modelscope_configuration.py \
+ --am_model_name $inference_asr_model \
+ --mode asr \
+ --model_name conformer \
+ --dataset aishell \
+ --output_dir $exp_dir/exp/$model_dir \
+ --vocab_size $vocab_size \
+ --tag $tag
+fi
\ No newline at end of file
diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
index b0734ff..5e93444 100644
--- a/funasr/build_utils/build_asr_model.py
+++ b/funasr/build_utils/build_asr_model.py
@@ -40,6 +40,7 @@
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.branchformer_encoder import BranchformerEncoder
+from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
@@ -115,6 +116,7 @@
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
branchformer=BranchformerEncoder,
+ e_branchformer=EBranchformerEncoder,
mfcca_enc=MFCCAEncoder,
chunk_conformer=ConformerChunkEncoder,
),
diff --git a/funasr/models/encoder/e_branchformer_encoder.py b/funasr/models/encoder/e_branchformer_encoder.py
new file mode 100644
index 0000000..65e481d
--- /dev/null
+++ b/funasr/models/encoder/e_branchformer_encoder.py
@@ -0,0 +1,467 @@
+# Copyright 2022 Kwangyoun Kim (ASAPP inc.)
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""E-Branchformer encoder definition.
+Reference:
+ Kwangyoun Kim, Felix Wu, Yifan Peng, Jing Pan,
+ Prashant Sridhar, Kyu J. Han, Shinji Watanabe,
+ "E-Branchformer: Branchformer with Enhanced merging
+ for speech recognition," in SLT 2022.
+"""
+
+import logging
+from typing import List, Optional, Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.models.ctc import CTC
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.modules.cgmlp import ConvolutionalGatingMLP
+from funasr.modules.fastformer import FastSelfAttention
+from funasr.modules.nets_utils import get_activation, make_pad_mask
+from funasr.modules.attention import ( # noqa: H301
+ LegacyRelPositionMultiHeadedAttention,
+ MultiHeadedAttention,
+ RelPositionMultiHeadedAttention,
+)
+from funasr.modules.embedding import ( # noqa: H301
+ LegacyRelPositionalEncoding,
+ PositionalEncoding,
+ RelPositionalEncoding,
+ ScaledPositionalEncoding,
+)
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.positionwise_feed_forward import (
+ PositionwiseFeedForward,
+)
+from funasr.modules.repeat import repeat
+from funasr.modules.subsampling import (
+ Conv2dSubsampling,
+ Conv2dSubsampling2,
+ Conv2dSubsampling6,
+ Conv2dSubsampling8,
+ TooShortUttError,
+ check_short_utt,
+)
+
+
+class EBranchformerEncoderLayer(torch.nn.Module):
+ """E-Branchformer encoder layer module.
+
+ Args:
+ size (int): model dimension
+ attn: standard self-attention or efficient attention
+ cgmlp: ConvolutionalGatingMLP
+ feed_forward: feed-forward module, optional
+ feed_forward: macaron-style feed-forward module, optional
+ dropout_rate (float): dropout probability
+ merge_conv_kernel (int): kernel size of the depth-wise conv in merge module
+ """
+
+ def __init__(
+ self,
+ size: int,
+ attn: torch.nn.Module,
+ cgmlp: torch.nn.Module,
+ feed_forward: Optional[torch.nn.Module],
+ feed_forward_macaron: Optional[torch.nn.Module],
+ dropout_rate: float,
+ merge_conv_kernel: int = 3,
+ ):
+ super().__init__()
+
+ self.size = size
+ self.attn = attn
+ self.cgmlp = cgmlp
+
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.ff_scale = 1.0
+ if self.feed_forward is not None:
+ self.norm_ff = LayerNorm(size)
+ if self.feed_forward_macaron is not None:
+ self.ff_scale = 0.5
+ self.norm_ff_macaron = LayerNorm(size)
+
+ self.norm_mha = LayerNorm(size) # for the MHA module
+ self.norm_mlp = LayerNorm(size) # for the MLP module
+ self.norm_final = LayerNorm(size) # for the final output of the block
+
+ self.dropout = torch.nn.Dropout(dropout_rate)
+
+ self.depthwise_conv_fusion = torch.nn.Conv1d(
+ size + size,
+ size + size,
+ kernel_size=merge_conv_kernel,
+ stride=1,
+ padding=(merge_conv_kernel - 1) // 2,
+ groups=size + size,
+ bias=True,
+ )
+ self.merge_proj = torch.nn.Linear(size + size, size)
+
+ def forward(self, x_input, mask, cache=None):
+ """Compute encoded features.
+
+ Args:
+ x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+ - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+ - w/o pos emb: Tensor (#batch, time, size).
+ mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time).
+ """
+
+ if cache is not None:
+ raise NotImplementedError("cache is not None, which is not tested")
+
+ if isinstance(x_input, tuple):
+ x, pos_emb = x_input[0], x_input[1]
+ else:
+ x, pos_emb = x_input, None
+
+ if self.feed_forward_macaron is not None:
+ residual = x
+ x = self.norm_ff_macaron(x)
+ x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
+
+ # Two branches
+ x1 = x
+ x2 = x
+
+ # Branch 1: multi-headed attention module
+ x1 = self.norm_mha(x1)
+
+ if isinstance(self.attn, FastSelfAttention):
+ x_att = self.attn(x1, mask)
+ else:
+ if pos_emb is not None:
+ x_att = self.attn(x1, x1, x1, pos_emb, mask)
+ else:
+ x_att = self.attn(x1, x1, x1, mask)
+
+ x1 = self.dropout(x_att)
+
+ # Branch 2: convolutional gating mlp
+ x2 = self.norm_mlp(x2)
+
+ if pos_emb is not None:
+ x2 = (x2, pos_emb)
+ x2 = self.cgmlp(x2, mask)
+ if isinstance(x2, tuple):
+ x2 = x2[0]
+
+ x2 = self.dropout(x2)
+
+ # Merge two branches
+ x_concat = torch.cat([x1, x2], dim=-1)
+ x_tmp = x_concat.transpose(1, 2)
+ x_tmp = self.depthwise_conv_fusion(x_tmp)
+ x_tmp = x_tmp.transpose(1, 2)
+ x = x + self.dropout(self.merge_proj(x_concat + x_tmp))
+
+ if self.feed_forward is not None:
+ # feed forward module
+ residual = x
+ x = self.norm_ff(x)
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+
+ x = self.norm_final(x)
+
+ if pos_emb is not None:
+ return (x, pos_emb), mask
+
+ return x, mask
+
+
+class EBranchformerEncoder(AbsEncoder):
+ """E-Branchformer encoder module."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ attention_layer_type: str = "rel_selfattn",
+ pos_enc_layer_type: str = "rel_pos",
+ rel_pos_type: str = "latest",
+ cgmlp_linear_units: int = 2048,
+ cgmlp_conv_kernel: int = 31,
+ use_linear_after_conv: bool = False,
+ gate_activation: str = "identity",
+ num_blocks: int = 12,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ zero_triu: bool = False,
+ padding_idx: int = -1,
+ layer_drop_rate: float = 0.0,
+ max_pos_emb_len: int = 5000,
+ use_ffn: bool = False,
+ macaron_ffn: bool = False,
+ ffn_activation_type: str = "swish",
+ linear_units: int = 2048,
+ positionwise_layer_type: str = "linear",
+ merge_conv_kernel: int = 3,
+ interctc_layer_idx=None,
+ interctc_use_conditioning: bool = False,
+ ):
+ assert check_argument_types()
+ super().__init__()
+ self._output_size = output_size
+
+ if rel_pos_type == "legacy":
+ if pos_enc_layer_type == "rel_pos":
+ pos_enc_layer_type = "legacy_rel_pos"
+ if attention_layer_type == "rel_selfattn":
+ attention_layer_type = "legacy_rel_selfattn"
+ elif rel_pos_type == "latest":
+ assert attention_layer_type != "legacy_rel_selfattn"
+ assert pos_enc_layer_type != "legacy_rel_pos"
+ else:
+ raise ValueError("unknown rel_pos_type: " + rel_pos_type)
+
+ if pos_enc_layer_type == "abs_pos":
+ pos_enc_class = PositionalEncoding
+ elif pos_enc_layer_type == "scaled_abs_pos":
+ pos_enc_class = ScaledPositionalEncoding
+ elif pos_enc_layer_type == "rel_pos":
+ assert attention_layer_type == "rel_selfattn"
+ pos_enc_class = RelPositionalEncoding
+ elif pos_enc_layer_type == "legacy_rel_pos":
+ assert attention_layer_type == "legacy_rel_selfattn"
+ pos_enc_class = LegacyRelPositionalEncoding
+ logging.warning(
+ "Using legacy_rel_pos and it will be deprecated in the future."
+ )
+ else:
+ raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
+
+ if input_layer == "linear":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Linear(input_size, output_size),
+ torch.nn.LayerNorm(output_size),
+ torch.nn.Dropout(dropout_rate),
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
+ )
+ elif input_layer == "conv2d":
+ self.embed = Conv2dSubsampling(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
+ )
+ elif input_layer == "conv2d2":
+ self.embed = Conv2dSubsampling2(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
+ )
+ elif input_layer == "conv2d6":
+ self.embed = Conv2dSubsampling6(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
+ )
+ elif input_layer == "conv2d8":
+ self.embed = Conv2dSubsampling8(
+ input_size,
+ output_size,
+ dropout_rate,
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
+ )
+ elif input_layer == "embed":
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
+ )
+ elif isinstance(input_layer, torch.nn.Module):
+ self.embed = torch.nn.Sequential(
+ input_layer,
+ pos_enc_class(output_size, positional_dropout_rate, max_pos_emb_len),
+ )
+ elif input_layer is None:
+ if input_size == output_size:
+ self.embed = None
+ else:
+ self.embed = torch.nn.Linear(input_size, output_size)
+ else:
+ raise ValueError("unknown input_layer: " + input_layer)
+
+ activation = get_activation(ffn_activation_type)
+ if positionwise_layer_type == "linear":
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ activation,
+ )
+ elif positionwise_layer_type is None:
+ logging.warning("no macaron ffn")
+ else:
+ raise ValueError("Support only linear.")
+
+ if attention_layer_type == "selfattn":
+ encoder_selfattn_layer = MultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+ elif attention_layer_type == "legacy_rel_selfattn":
+ assert pos_enc_layer_type == "legacy_rel_pos"
+ encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ )
+ logging.warning(
+ "Using legacy_rel_selfattn and it will be deprecated in the future."
+ )
+ elif attention_layer_type == "rel_selfattn":
+ assert pos_enc_layer_type == "rel_pos"
+ encoder_selfattn_layer = RelPositionMultiHeadedAttention
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ zero_triu,
+ )
+ elif attention_layer_type == "fast_selfattn":
+ assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
+ encoder_selfattn_layer = FastSelfAttention
+ encoder_selfattn_layer_args = (
+ output_size,
+ attention_heads,
+ attention_dropout_rate,
+ )
+ else:
+ raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)
+
+ cgmlp_layer = ConvolutionalGatingMLP
+ cgmlp_layer_args = (
+ output_size,
+ cgmlp_linear_units,
+ cgmlp_conv_kernel,
+ dropout_rate,
+ use_linear_after_conv,
+ gate_activation,
+ )
+
+ self.encoders = repeat(
+ num_blocks,
+ lambda lnum: EBranchformerEncoderLayer(
+ output_size,
+ encoder_selfattn_layer(*encoder_selfattn_layer_args),
+ cgmlp_layer(*cgmlp_layer_args),
+ positionwise_layer(*positionwise_layer_args) if use_ffn else None,
+ positionwise_layer(*positionwise_layer_args)
+ if use_ffn and macaron_ffn
+ else None,
+ dropout_rate,
+ merge_conv_kernel,
+ ),
+ layer_drop_rate,
+ )
+ self.after_norm = LayerNorm(output_size)
+
+ if interctc_layer_idx is None:
+ interctc_layer_idx = []
+ self.interctc_layer_idx = interctc_layer_idx
+ if len(interctc_layer_idx) > 0:
+ assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+ self.interctc_use_conditioning = interctc_use_conditioning
+ self.conditioning_layer = None
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ max_layer: int = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Calculate forward propagation.
+
+ Args:
+ xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
+ ilens (torch.Tensor): Input length (#batch).
+ prev_states (torch.Tensor): Not to be used now.
+ ctc (CTC): Intermediate CTC module.
+ max_layer (int): Layer depth below which InterCTC is applied.
+ Returns:
+ torch.Tensor: Output tensor (#batch, L, output_size).
+ torch.Tensor: Output length (#batch).
+ torch.Tensor: Not to be used now.
+ """
+
+ masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+
+ if (
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
+ ):
+ short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+ if short_status:
+ raise TooShortUttError(
+ f"has {xs_pad.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ xs_pad.size(1),
+ limit_size,
+ )
+ xs_pad, masks = self.embed(xs_pad, masks)
+ elif self.embed is not None:
+ xs_pad = self.embed(xs_pad)
+
+ intermediate_outs = []
+ if len(self.interctc_layer_idx) == 0:
+ if max_layer is not None and 0 <= max_layer < len(self.encoders):
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ xs_pad, masks = encoder_layer(xs_pad, masks)
+ if layer_idx >= max_layer:
+ break
+ else:
+ xs_pad, masks = self.encoders(xs_pad, masks)
+ else:
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ xs_pad, masks = encoder_layer(xs_pad, masks)
+
+ if layer_idx + 1 in self.interctc_layer_idx:
+ encoder_out = xs_pad
+
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ intermediate_outs.append((layer_idx + 1, encoder_out))
+
+ if self.interctc_use_conditioning:
+ ctc_out = ctc.softmax(encoder_out)
+
+ if isinstance(xs_pad, tuple):
+ xs_pad = list(xs_pad)
+ xs_pad[0] = xs_pad[0] + self.conditioning_layer(ctc_out)
+ xs_pad = tuple(xs_pad)
+ else:
+ xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+ if isinstance(xs_pad, tuple):
+ xs_pad = xs_pad[0]
+
+ xs_pad = self.after_norm(xs_pad)
+ olens = masks.squeeze(1).sum(1)
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), olens, None
+ return xs_pad, olens, None
--
Gitblit v1.9.1