From 964a50d24613998c0dcfa802285d1b5ab54263f4 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期三, 05 七月 2023 20:36:27 +0800
Subject: [PATCH] update

---
 funasr/models/encoder/branchformer_encoder.py             |  547 +++++++++++++++++++++++
 egs/aishell/branchformer/conf/decode_asr_transformer.yaml |    6 
 egs/aishell/branchformer/path.sh                          |    5 
 egs/aishell/branchformer/run.sh                           |  225 +++++++++
 egs/aishell/branchformer/local/download_and_untar.sh      |  105 ++++
 funasr/modules/cgmlp.py                                   |  124 +++++
 egs/aishell/branchformer/README.md                        |   17 
 egs/aishell/branchformer/conf/train_asr_branchformer.yaml |   86 +++
 egs/aishell/branchformer/local/aishell_data_prep.sh       |   66 ++
 egs/aishell/branchformer/utils                            |    1 
 funasr/modules/fastformer.py                              |  153 ++++++
 11 files changed, 1,335 insertions(+), 0 deletions(-)

diff --git a/egs/aishell/branchformer/README.md b/egs/aishell/branchformer/README.md
new file mode 100644
index 0000000..a67b183
--- /dev/null
+++ b/egs/aishell/branchformer/README.md
@@ -0,0 +1,17 @@
+
+# Conformer Result
+
+## Training Config
+- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train info: lr 5e-4, batch_size 25000, 2 gpu(Tesla V100), acc_grad 1, 50 epochs
+- Train config: conf/train_asr_transformer.yaml
+- LM config: LM was not used
+- Model size: 46M
+
+## Results (CER)
+- Decode config: conf/decode_asr_transformer.yaml (ctc weight:0.5)
+
+|   testset   | CER(%)  |
+|:-----------:|:-------:|
+|     dev     |  4.42   |
+|    test     |  4.87   |
\ No newline at end of file
diff --git a/egs/aishell/branchformer/conf/decode_asr_transformer.yaml b/egs/aishell/branchformer/conf/decode_asr_transformer.yaml
new file mode 100644
index 0000000..e87a293
--- /dev/null
+++ b/egs/aishell/branchformer/conf/decode_asr_transformer.yaml
@@ -0,0 +1,6 @@
+beam_size: 10
+penalty: 0.0
+maxlenratio: 0.0
+minlenratio: 0.0
+ctc_weight: 0.4
+lm_weight: 0.0
diff --git a/egs/aishell/branchformer/conf/train_asr_branchformer.yaml b/egs/aishell/branchformer/conf/train_asr_branchformer.yaml
new file mode 100644
index 0000000..cb8ac3b
--- /dev/null
+++ b/egs/aishell/branchformer/conf/train_asr_branchformer.yaml
@@ -0,0 +1,86 @@
+# network architecture
+# encoder related
+encoder: branchformer
+encoder_conf:
+    output_size: 256
+    use_attn: true
+    attention_heads: 4
+    attention_layer_type: rel_selfattn
+    pos_enc_layer_type: rel_pos
+    rel_pos_type: latest
+    use_cgmlp: true
+    cgmlp_linear_units: 2048
+    cgmlp_conv_kernel: 31
+    use_linear_after_conv: false
+    gate_activation: identity
+    merge_method: concat
+    cgmlp_weight: 0.5               # used only if merge_method is "fixed_ave"
+    attn_branch_drop_rate: 0.0      # used only if merge_method is "learned_ave"
+    num_blocks: 24
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: conv2d
+    stochastic_depth_rate: 0.0
+
+# decoder related
+decoder: transformer
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.
+    src_attention_dropout_rate: 0.
+
+# hybrid CTC/attention
+model_conf:
+    ctc_weight: 0.3
+    lsm_weight: 0.1     # label smoothing option
+    length_normalized_loss: false
+
+# minibatch related
+batch_type: numel
+batch_bins: 25000000
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 60
+val_scheduler_criterion:
+    - valid
+    - acc
+best_model_criterion:
+-   - valid
+    - acc
+    - max
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+   lr: 0.001
+   weight_decay: 0.000001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 35000
+
+num_workers: 4      # num of workers of data loader
+use_amp: true      # automatic mixed precision
+unused_parameters: false    # set as true if some params are unused in DDP
+
+specaug: specaug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 27
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_ratio_range:
+    - 0.
+    - 0.05
+    num_time_mask: 10
diff --git a/egs/aishell/branchformer/local/aishell_data_prep.sh b/egs/aishell/branchformer/local/aishell_data_prep.sh
new file mode 100755
index 0000000..83f489b
--- /dev/null
+++ b/egs/aishell/branchformer/local/aishell_data_prep.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+
+# Copyright 2017 Xingyu Na
+# Apache 2.0
+
+#. ./path.sh || exit 1;
+
+if [ $# != 3 ]; then
+  echo "Usage: $0 <audio-path> <text-path> <output-path>"
+  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
+  exit 1;
+fi
+
+aishell_audio_dir=$1
+aishell_text=$2/aishell_transcript_v0.8.txt
+output_dir=$3
+
+train_dir=$output_dir/data/local/train
+dev_dir=$output_dir/data/local/dev
+test_dir=$output_dir/data/local/test
+tmp_dir=$output_dir/data/local/tmp
+
+mkdir -p $train_dir
+mkdir -p $dev_dir
+mkdir -p $test_dir
+mkdir -p $tmp_dir
+
+# data directory check
+if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
+  echo "Error: $0 requires two directory arguments"
+  exit 1;
+fi
+
+# find wav audio file for train, dev and test resp.
+find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
+n=`cat $tmp_dir/wav.flist | wc -l`
+[ $n -ne 141925 ] && \
+  echo Warning: expected 141925 data data files, found $n
+
+grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
+grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
+grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
+
+rm -r $tmp_dir
+
+# Transcriptions preparation
+for dir in $train_dir $dev_dir $test_dir; do
+  echo Preparing $dir transcriptions
+  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
+  paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
+  utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
+  awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
+  utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
+  sort -u $dir/transcripts.txt > $dir/text
+done
+
+mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
+
+for f in wav.scp text; do
+  cp $train_dir/$f $output_dir/data/train/$f || exit 1;
+  cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
+  cp $test_dir/$f $output_dir/data/test/$f || exit 1;
+done
+
+echo "$0: AISHELL data preparation succeeded"
+exit 0;
diff --git a/egs/aishell/branchformer/local/download_and_untar.sh b/egs/aishell/branchformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/egs/aishell/branchformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+#             2017  Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1;
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tgz
+  else
+    echo "$data/$part.tgz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+  if ! command -v wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1;
+  fi
+  full_url=$url/$part.tgz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  cd $data || exit 1
+  if ! wget --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1;
+  fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+  echo "$0: error un-tarring archive $data/$part.tgz"
+  exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+  cd $data/$part/wav || exit 1
+  for wav in ./*.tar.gz; do
+    echo "Extracting wav from $wav"
+    tar -zxf $wav && rm $wav
+  done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+  rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/egs/aishell/branchformer/path.sh b/egs/aishell/branchformer/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/egs/aishell/branchformer/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/egs/aishell/branchformer/run.sh b/egs/aishell/branchformer/run.sh
new file mode 100755
index 0000000..b5aeae7
--- /dev/null
+++ b/egs/aishell/branchformer/run.sh
@@ -0,0 +1,225 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=2
+count=1
+gpu_inference=true  # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir="../DATA" #feature output dictionary
+exp_dir="."
+lang=zh
+token_type=char
+type=sound
+scp=wav.scp
+speed_perturb="0.9 1.0 1.1"
+stage=0
+stop_stage=5
+
+# feature configuration
+feats_dim=80
+nj=64
+
+# data
+raw_data=../raw_data
+data_url=www.openslr.org/resources/33
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev
+test_sets="dev test"
+
+asr_config=conf/train_asr_branchformer.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_asr_transformer.yaml
+inference_asr_model=valid.acc.ave_10best.pb
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES  # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+    inference_nj=$[${ngpu}*${njob}]
+    _ngpu=1
+else
+    inference_nj=$njob
+    _ngpu=0
+fi
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    echo "stage -1: Data Download"
+    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    echo "stage 0: Data preparation"
+    # Data preparation
+    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
+    for x in train dev test; do
+        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
+            > ${feats_dir}/data/${x}/text
+        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
+    done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "stage 1: Feature and CMVN Generation"
+    utils/compute_cmvn.sh --fbankdir ${feats_dir}/data/${train_set} --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --config_file "$asr_config" --scale 1.0
+fi
+
+token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "stage 2: Dictionary Preparation"
+    mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/
+
+    echo "make a dictionary"
+    echo "<blank>" > ${token_list}
+    echo "<s>" >> ${token_list}
+    echo "</s>" >> ${token_list}
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
+        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+    echo "<unk>" >> ${token_list}
+fi
+
+# LM Training Stage
+world_size=$gpu_num  # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+world_size=$gpu_num  # run on one machine
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+    echo "stage 4: ASR Training"
+    mkdir -p ${exp_dir}/exp/${model_dir}
+    mkdir -p ${exp_dir}/exp/${model_dir}/log
+    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+    if [ -f $INIT_FILE ];then
+        rm -f $INIT_FILE
+    fi 
+    init_method=file://$(readlink -f $INIT_FILE)
+    echo "$0: init method is $init_method"
+    for ((i = 0; i < $gpu_num; ++i)); do
+        {
+            rank=$i
+            local_rank=$i
+            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+            train.py \
+                --task_name asr \
+                --gpu_id $gpu_id \
+                --use_preprocessor true \
+                --token_type $token_type \
+                --token_list $token_list \
+                --data_dir ${feats_dir}/data \
+                --train_set ${train_set} \
+                --valid_set ${valid_set} \
+                --data_file_names "wav.scp,text" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
+                --speed_perturb ${speed_perturb} \
+                --resume true \
+                --output_dir ${exp_dir}/exp/${model_dir} \
+                --config $asr_config \
+                --ngpu $gpu_num \
+                --num_worker_count $count \
+                --dist_init_method $init_method \
+                --dist_world_size $world_size \
+                --dist_rank $rank \
+                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+        } &
+        done
+        wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+    echo "stage 5: Inference"
+    for dset in ${test_sets}; do
+        asr_exp=${exp_dir}/exp/${model_dir}
+        inference_tag="$(basename "${inference_config}" .yaml)"
+        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+        _logdir="${_dir}/logdir"
+        if [ -d ${_dir} ]; then
+            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+            exit 0
+        fi
+        mkdir -p "${_logdir}"
+        _data="${feats_dir}/data/${dset}"
+        key_file=${_data}/${scp}
+        num_scp_file="$(<${key_file} wc -l)"
+        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+        split_scps=
+        for n in $(seq "${_nj}"); do
+            split_scps+=" ${_logdir}/keys.${n}.scp"
+        done
+        # shellcheck disable=SC2086
+        utils/split_scp.pl "${key_file}" ${split_scps}
+        _opts=
+        if [ -n "${inference_config}" ]; then
+            _opts+="--config ${inference_config} "
+        fi
+        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+            python -m funasr.bin.asr_inference_launch \
+                --batch_size 1 \
+                --ngpu "${_ngpu}" \
+                --njob ${njob} \
+                --gpuid_list ${gpuid_list} \
+                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+                --cmvn_file ${feats_dir}/data/${train_set}/cmvn/am.mvn \
+                --key_file "${_logdir}"/keys.JOB.scp \
+                --asr_train_config "${asr_exp}"/config.yaml \
+                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+                --output_dir "${_logdir}"/output.JOB \
+                --mode asr \
+                ${_opts}
+
+        for f in token token_int score text; do
+            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+                for i in $(seq "${_nj}"); do
+                    cat "${_logdir}/output.${i}/1best_recog/${f}"
+                done | sort -k1 >"${_dir}/${f}"
+            fi
+        done
+        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+        python utils/proce_text.py ${_data}/text ${_data}/text.proc
+        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+        cat ${_dir}/text.cer.txt
+    done
+fi
+
+# Prepare files for ModelScope fine-tuning and inference
+if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+    echo "stage 6: ModelScope Preparation"
+    cp ${feats_dir}/data/${train_set}/cmvn/am.mvn ${exp_dir}/exp/${model_dir}/am.mvn
+    vocab_size=$(cat ${token_list} | wc -l)
+    python utils/gen_modelscope_configuration.py \
+        --am_model_name $inference_asr_model \
+        --mode asr \
+        --model_name conformer \
+        --dataset aishell \
+        --output_dir $exp_dir/exp/$model_dir \
+        --vocab_size $vocab_size \
+        --tag $tag
+fi
\ No newline at end of file
diff --git a/egs/aishell/branchformer/utils b/egs/aishell/branchformer/utils
new file mode 120000
index 0000000..4072eac
--- /dev/null
+++ b/egs/aishell/branchformer/utils
@@ -0,0 +1 @@
+../transformer/utils
\ No newline at end of file
diff --git a/funasr/models/encoder/branchformer_encoder.py b/funasr/models/encoder/branchformer_encoder.py
new file mode 100644
index 0000000..2035ce9
--- /dev/null
+++ b/funasr/models/encoder/branchformer_encoder.py
@@ -0,0 +1,547 @@
+# Copyright 2022 Yifan Peng (Carnegie Mellon University)
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Branchformer encoder definition.
+
+Reference:
+    Yifan Peng, Siddharth Dalmia, Ian Lane, and Shinji Watanabe,
+    鈥淏ranchformer: Parallel MLP-Attention Architectures to Capture
+    Local and Global Context for Speech Recognition and Understanding,鈥�
+    in Proceedings of ICML, 2022.
+
+"""
+
+import logging
+from typing import List, Optional, Tuple, Union
+
+import numpy
+import torch
+from typeguard import check_argument_types
+
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.modules.cgmlp import ConvolutionalGatingMLP
+from funasr.modules.fastformer import FastSelfAttention
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.attention import (  # noqa: H301
+    LegacyRelPositionMultiHeadedAttention,
+    MultiHeadedAttention,
+    RelPositionMultiHeadedAttention,
+)
+from funasr.modules.embedding import (  # noqa: H301
+    LegacyRelPositionalEncoding,
+    PositionalEncoding,
+    RelPositionalEncoding,
+    ScaledPositionalEncoding,
+)
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.repeat import repeat
+from funasr.modules.subsampling import (
+    Conv2dSubsampling,
+    Conv2dSubsampling2,
+    Conv2dSubsampling6,
+    Conv2dSubsampling8,
+    TooShortUttError,
+    check_short_utt,
+)
+
+
+class BranchformerEncoderLayer(torch.nn.Module):
+    """Branchformer encoder layer module.
+
+    Args:
+        size (int): model dimension
+        attn: standard self-attention or efficient attention, optional
+        cgmlp: ConvolutionalGatingMLP, optional
+        dropout_rate (float): dropout probability
+        merge_method (str): concat, learned_ave, fixed_ave
+        cgmlp_weight (float): weight of the cgmlp branch, between 0 and 1,
+            used if merge_method is fixed_ave
+        attn_branch_drop_rate (float): probability of dropping the attn branch,
+            used if merge_method is learned_ave
+        stochastic_depth_rate (float): stochastic depth probability
+    """
+
+    def __init__(
+        self,
+        size: int,
+        attn: Optional[torch.nn.Module],
+        cgmlp: Optional[torch.nn.Module],
+        dropout_rate: float,
+        merge_method: str,
+        cgmlp_weight: float = 0.5,
+        attn_branch_drop_rate: float = 0.0,
+        stochastic_depth_rate: float = 0.0,
+    ):
+        super().__init__()
+        assert (attn is not None) or (
+            cgmlp is not None
+        ), "At least one branch should be valid"
+
+        self.size = size
+        self.attn = attn
+        self.cgmlp = cgmlp
+        self.merge_method = merge_method
+        self.cgmlp_weight = cgmlp_weight
+        self.attn_branch_drop_rate = attn_branch_drop_rate
+        self.stochastic_depth_rate = stochastic_depth_rate
+        self.use_two_branches = (attn is not None) and (cgmlp is not None)
+
+        if attn is not None:
+            self.norm_mha = LayerNorm(size)  # for the MHA module
+        if cgmlp is not None:
+            self.norm_mlp = LayerNorm(size)  # for the MLP module
+        self.norm_final = LayerNorm(size)  # for the final output of the block
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+        if self.use_two_branches:
+            if merge_method == "concat":
+                self.merge_proj = torch.nn.Linear(size + size, size)
+
+            elif merge_method == "learned_ave":
+                # attention-based pooling for two branches
+                self.pooling_proj1 = torch.nn.Linear(size, 1)
+                self.pooling_proj2 = torch.nn.Linear(size, 1)
+
+                # linear projections for calculating merging weights
+                self.weight_proj1 = torch.nn.Linear(size, 1)
+                self.weight_proj2 = torch.nn.Linear(size, 1)
+
+                # linear projection after weighted average
+                self.merge_proj = torch.nn.Linear(size, size)
+
+            elif merge_method == "fixed_ave":
+                assert (
+                    0.0 <= cgmlp_weight <= 1.0
+                ), "cgmlp weight should be between 0.0 and 1.0"
+
+                # remove the other branch if only one branch is used
+                if cgmlp_weight == 0.0:
+                    self.use_two_branches = False
+                    self.cgmlp = None
+                    self.norm_mlp = None
+                elif cgmlp_weight == 1.0:
+                    self.use_two_branches = False
+                    self.attn = None
+                    self.norm_mha = None
+
+                # linear projection after weighted average
+                self.merge_proj = torch.nn.Linear(size, size)
+
+            else:
+                raise ValueError(f"unknown merge method: {merge_method}")
+
+        else:
+            self.merge_proj = torch.nn.Identity()
+
+    def forward(self, x_input, mask, cache=None):
+        """Compute encoded features.
+
+        Args:
+            x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb.
+                - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)].
+                - w/o pos emb: Tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, 1, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+        """
+
+        if cache is not None:
+            raise NotImplementedError("cache is not None, which is not tested")
+
+        if isinstance(x_input, tuple):
+            x, pos_emb = x_input[0], x_input[1]
+        else:
+            x, pos_emb = x_input, None
+
+        skip_layer = False
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        stoch_layer_coeff = 1.0
+        if self.training and self.stochastic_depth_rate > 0:
+            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+        if skip_layer:
+            if cache is not None:
+                x = torch.cat([cache, x], dim=1)
+            if pos_emb is not None:
+                return (x, pos_emb), mask
+            return x, mask
+
+        # Two branches
+        x1 = x
+        x2 = x
+
+        # Branch 1: multi-headed attention module
+        if self.attn is not None:
+            x1 = self.norm_mha(x1)
+
+            if isinstance(self.attn, FastSelfAttention):
+                x_att = self.attn(x1, mask)
+            else:
+                if pos_emb is not None:
+                    x_att = self.attn(x1, x1, x1, pos_emb, mask)
+                else:
+                    x_att = self.attn(x1, x1, x1, mask)
+
+            x1 = self.dropout(x_att)
+
+        # Branch 2: convolutional gating mlp
+        if self.cgmlp is not None:
+            x2 = self.norm_mlp(x2)
+
+            if pos_emb is not None:
+                x2 = (x2, pos_emb)
+            x2 = self.cgmlp(x2, mask)
+            if isinstance(x2, tuple):
+                x2 = x2[0]
+
+            x2 = self.dropout(x2)
+
+        # Merge two branches
+        if self.use_two_branches:
+            if self.merge_method == "concat":
+                x = x + stoch_layer_coeff * self.dropout(
+                    self.merge_proj(torch.cat([x1, x2], dim=-1))
+                )
+            elif self.merge_method == "learned_ave":
+                if (
+                    self.training
+                    and self.attn_branch_drop_rate > 0
+                    and torch.rand(1).item() < self.attn_branch_drop_rate
+                ):
+                    # Drop the attn branch
+                    w1, w2 = 0.0, 1.0
+                else:
+                    # branch1
+                    score1 = (
+                        self.pooling_proj1(x1).transpose(1, 2) / self.size**0.5
+                    )  # (batch, 1, time)
+                    if mask is not None:
+                        min_value = float(
+                            numpy.finfo(
+                                torch.tensor(0, dtype=score1.dtype).numpy().dtype
+                            ).min
+                        )
+                        score1 = score1.masked_fill(mask.eq(0), min_value)
+                        score1 = torch.softmax(score1, dim=-1).masked_fill(
+                            mask.eq(0), 0.0
+                        )
+                    else:
+                        score1 = torch.softmax(score1, dim=-1)
+                    pooled1 = torch.matmul(score1, x1).squeeze(1)  # (batch, size)
+                    weight1 = self.weight_proj1(pooled1)  # (batch, 1)
+
+                    # branch2
+                    score2 = (
+                        self.pooling_proj2(x2).transpose(1, 2) / self.size**0.5
+                    )  # (batch, 1, time)
+                    if mask is not None:
+                        min_value = float(
+                            numpy.finfo(
+                                torch.tensor(0, dtype=score2.dtype).numpy().dtype
+                            ).min
+                        )
+                        score2 = score2.masked_fill(mask.eq(0), min_value)
+                        score2 = torch.softmax(score2, dim=-1).masked_fill(
+                            mask.eq(0), 0.0
+                        )
+                    else:
+                        score2 = torch.softmax(score2, dim=-1)
+                    pooled2 = torch.matmul(score2, x2).squeeze(1)  # (batch, size)
+                    weight2 = self.weight_proj2(pooled2)  # (batch, 1)
+
+                    # normalize weights of two branches
+                    merge_weights = torch.softmax(
+                        torch.cat([weight1, weight2], dim=-1), dim=-1
+                    )  # (batch, 2)
+                    merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
+                        -1
+                    )  # (batch, 2, 1, 1)
+                    w1, w2 = merge_weights[:, 0], merge_weights[:, 1]  # (batch, 1, 1)
+
+                x = x + stoch_layer_coeff * self.dropout(
+                    self.merge_proj(w1 * x1 + w2 * x2)
+                )
+            elif self.merge_method == "fixed_ave":
+                x = x + stoch_layer_coeff * self.dropout(
+                    self.merge_proj(
+                        (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
+                    )
+                )
+            else:
+                raise RuntimeError(f"unknown merge method: {self.merge_method}")
+        else:
+            if self.attn is None:
+                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x2))
+            elif self.cgmlp is None:
+                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(x1))
+            else:
+                # This should not happen
+                raise RuntimeError("Both branches are not None, which is unexpected.")
+
+        x = self.norm_final(x)
+
+        if pos_emb is not None:
+            return (x, pos_emb), mask
+
+        return x, mask
+
+
+class BranchformerEncoder(AbsEncoder):
+    """Branchformer encoder module."""
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        use_attn: bool = True,
+        attention_heads: int = 4,
+        attention_layer_type: str = "rel_selfattn",
+        pos_enc_layer_type: str = "rel_pos",
+        rel_pos_type: str = "latest",
+        use_cgmlp: bool = True,
+        cgmlp_linear_units: int = 2048,
+        cgmlp_conv_kernel: int = 31,
+        use_linear_after_conv: bool = False,
+        gate_activation: str = "identity",
+        merge_method: str = "concat",
+        cgmlp_weight: Union[float, List[float]] = 0.5,
+        attn_branch_drop_rate: Union[float, List[float]] = 0.0,
+        num_blocks: int = 12,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        zero_triu: bool = False,
+        padding_idx: int = -1,
+        stochastic_depth_rate: Union[float, List[float]] = 0.0,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self._output_size = output_size
+
+        if rel_pos_type == "legacy":
+            if pos_enc_layer_type == "rel_pos":
+                pos_enc_layer_type = "legacy_rel_pos"
+            if attention_layer_type == "rel_selfattn":
+                attention_layer_type = "legacy_rel_selfattn"
+        elif rel_pos_type == "latest":
+            assert attention_layer_type != "legacy_rel_selfattn"
+            assert pos_enc_layer_type != "legacy_rel_pos"
+        else:
+            raise ValueError("unknown rel_pos_type: " + rel_pos_type)
+
+        if pos_enc_layer_type == "abs_pos":
+            pos_enc_class = PositionalEncoding
+        elif pos_enc_layer_type == "scaled_abs_pos":
+            pos_enc_class = ScaledPositionalEncoding
+        elif pos_enc_layer_type == "rel_pos":
+            assert attention_layer_type == "rel_selfattn"
+            pos_enc_class = RelPositionalEncoding
+        elif pos_enc_layer_type == "legacy_rel_pos":
+            assert attention_layer_type == "legacy_rel_selfattn"
+            pos_enc_class = LegacyRelPositionalEncoding
+            logging.warning(
+                "Using legacy_rel_pos and it will be deprecated in the future."
+            )
+        else:
+            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(
+                input_size,
+                output_size,
+                dropout_rate,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif isinstance(input_layer, torch.nn.Module):
+            self.embed = torch.nn.Sequential(
+                input_layer,
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+
+        if attention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+        elif attention_layer_type == "legacy_rel_selfattn":
+            assert pos_enc_layer_type == "legacy_rel_pos"
+            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+            logging.warning(
+                "Using legacy_rel_selfattn and it will be deprecated in the future."
+            )
+        elif attention_layer_type == "rel_selfattn":
+            assert pos_enc_layer_type == "rel_pos"
+            encoder_selfattn_layer = RelPositionMultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+                zero_triu,
+            )
+        elif attention_layer_type == "fast_selfattn":
+            assert pos_enc_layer_type in ["abs_pos", "scaled_abs_pos"]
+            encoder_selfattn_layer = FastSelfAttention
+            encoder_selfattn_layer_args = (
+                output_size,
+                attention_heads,
+                attention_dropout_rate,
+            )
+        else:
+            raise ValueError("unknown encoder_attn_layer: " + attention_layer_type)
+
+        cgmlp_layer = ConvolutionalGatingMLP
+        cgmlp_layer_args = (
+            output_size,
+            cgmlp_linear_units,
+            cgmlp_conv_kernel,
+            dropout_rate,
+            use_linear_after_conv,
+            gate_activation,
+        )
+
+        if isinstance(stochastic_depth_rate, float):
+            stochastic_depth_rate = [stochastic_depth_rate] * num_blocks
+        if len(stochastic_depth_rate) != num_blocks:
+            raise ValueError(
+                f"Length of stochastic_depth_rate ({len(stochastic_depth_rate)}) "
+                f"should be equal to num_blocks ({num_blocks})"
+            )
+
+        if isinstance(cgmlp_weight, float):
+            cgmlp_weight = [cgmlp_weight] * num_blocks
+        if len(cgmlp_weight) != num_blocks:
+            raise ValueError(
+                f"Length of cgmlp_weight ({len(cgmlp_weight)}) should be equal to "
+                f"num_blocks ({num_blocks})"
+            )
+
+        if isinstance(attn_branch_drop_rate, float):
+            attn_branch_drop_rate = [attn_branch_drop_rate] * num_blocks
+        if len(attn_branch_drop_rate) != num_blocks:
+            raise ValueError(
+                f"Length of attn_branch_drop_rate ({len(attn_branch_drop_rate)}) "
+                f"should be equal to num_blocks ({num_blocks})"
+            )
+
+        self.encoders = repeat(
+            num_blocks,
+            lambda lnum: BranchformerEncoderLayer(
+                output_size,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args)
+                if use_attn
+                else None,
+                cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
+                dropout_rate,
+                merge_method,
+                cgmlp_weight[lnum],
+                attn_branch_drop_rate[lnum],
+                stochastic_depth_rate[lnum],
+            ),
+        )
+        self.after_norm = LayerNorm(output_size)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Calculate forward propagation.
+
+        Args:
+            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
+            ilens (torch.Tensor): Input length (#batch).
+            prev_states (torch.Tensor): Not to be used now.
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, L, output_size).
+            torch.Tensor: Output length (#batch).
+            torch.Tensor: Not to be used now.
+
+        """
+
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+
+        if (
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling2)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        elif self.embed is not None:
+            xs_pad = self.embed(xs_pad)
+
+        xs_pad, masks = self.encoders(xs_pad, masks)
+
+        if isinstance(xs_pad, tuple):
+            xs_pad = xs_pad[0]
+
+        xs_pad = self.after_norm(xs_pad)
+        olens = masks.squeeze(1).sum(1)
+        return xs_pad, olens, None
diff --git a/funasr/modules/cgmlp.py b/funasr/modules/cgmlp.py
new file mode 100644
index 0000000..dcd085b
--- /dev/null
+++ b/funasr/modules/cgmlp.py
@@ -0,0 +1,124 @@
+"""MLP with convolutional gating (cgMLP) definition.
+
+References:
+    https://openreview.net/forum?id=RA-zVvZLYIy
+    https://arxiv.org/abs/2105.08050
+
+"""
+
+import torch
+
+from funasr.modules.nets_utils import get_activation
+from funasr.modules.layer_norm import LayerNorm
+
+
+class ConvolutionalSpatialGatingUnit(torch.nn.Module):
+    """Convolutional Spatial Gating Unit (CSGU)."""
+
+    def __init__(
+        self,
+        size: int,
+        kernel_size: int,
+        dropout_rate: float,
+        use_linear_after_conv: bool,
+        gate_activation: str,
+    ):
+        super().__init__()
+
+        n_channels = size // 2  # split input channels
+        self.norm = LayerNorm(n_channels)
+        self.conv = torch.nn.Conv1d(
+            n_channels,
+            n_channels,
+            kernel_size,
+            1,
+            (kernel_size - 1) // 2,
+            groups=n_channels,
+        )
+        if use_linear_after_conv:
+            self.linear = torch.nn.Linear(n_channels, n_channels)
+        else:
+            self.linear = None
+
+        if gate_activation == "identity":
+            self.act = torch.nn.Identity()
+        else:
+            self.act = get_activation(gate_activation)
+
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+    def espnet_initialization_fn(self):
+        torch.nn.init.normal_(self.conv.weight, std=1e-6)
+        torch.nn.init.ones_(self.conv.bias)
+        if self.linear is not None:
+            torch.nn.init.normal_(self.linear.weight, std=1e-6)
+            torch.nn.init.ones_(self.linear.bias)
+
+    def forward(self, x, gate_add=None):
+        """Forward method
+
+        Args:
+            x (torch.Tensor): (N, T, D)
+            gate_add (torch.Tensor): (N, T, D/2)
+
+        Returns:
+            out (torch.Tensor): (N, T, D/2)
+        """
+
+        x_r, x_g = x.chunk(2, dim=-1)
+
+        x_g = self.norm(x_g)  # (N, T, D/2)
+        x_g = self.conv(x_g.transpose(1, 2)).transpose(1, 2)  # (N, T, D/2)
+        if self.linear is not None:
+            x_g = self.linear(x_g)
+
+        if gate_add is not None:
+            x_g = x_g + gate_add
+
+        x_g = self.act(x_g)
+        out = x_r * x_g  # (N, T, D/2)
+        out = self.dropout(out)
+        return out
+
+
+class ConvolutionalGatingMLP(torch.nn.Module):
+    """Convolutional Gating MLP (cgMLP)."""
+
+    def __init__(
+        self,
+        size: int,
+        linear_units: int,
+        kernel_size: int,
+        dropout_rate: float,
+        use_linear_after_conv: bool,
+        gate_activation: str,
+    ):
+        super().__init__()
+
+        self.channel_proj1 = torch.nn.Sequential(
+            torch.nn.Linear(size, linear_units), torch.nn.GELU()
+        )
+        self.csgu = ConvolutionalSpatialGatingUnit(
+            size=linear_units,
+            kernel_size=kernel_size,
+            dropout_rate=dropout_rate,
+            use_linear_after_conv=use_linear_after_conv,
+            gate_activation=gate_activation,
+        )
+        self.channel_proj2 = torch.nn.Linear(linear_units // 2, size)
+
+    def forward(self, x, mask):
+        if isinstance(x, tuple):
+            xs_pad, pos_emb = x
+        else:
+            xs_pad, pos_emb = x, None
+
+        xs_pad = self.channel_proj1(xs_pad)  # size -> linear_units
+        xs_pad = self.csgu(xs_pad)  # linear_units -> linear_units/2
+        xs_pad = self.channel_proj2(xs_pad)  # linear_units/2 -> size
+
+        if pos_emb is not None:
+            out = (xs_pad, pos_emb)
+        else:
+            out = xs_pad
+        return out
diff --git a/funasr/modules/fastformer.py b/funasr/modules/fastformer.py
new file mode 100644
index 0000000..24ca947
--- /dev/null
+++ b/funasr/modules/fastformer.py
@@ -0,0 +1,153 @@
+"""Fastformer attention definition.
+
+Reference:
+    Wu et al., "Fastformer: Additive Attention Can Be All You Need"
+    https://arxiv.org/abs/2108.09084
+    https://github.com/wuch15/Fastformer
+
+"""
+
+import numpy
+import torch
+
+
+class FastSelfAttention(torch.nn.Module):
+    """Fast self-attention used in Fastformer."""
+
+    def __init__(
+        self,
+        size,
+        attention_heads,
+        dropout_rate,
+    ):
+        super().__init__()
+        if size % attention_heads != 0:
+            raise ValueError(
+                f"Hidden size ({size}) is not an integer multiple "
+                f"of attention heads ({attention_heads})"
+            )
+        self.attention_head_size = size // attention_heads
+        self.num_attention_heads = attention_heads
+
+        self.query = torch.nn.Linear(size, size)
+        self.query_att = torch.nn.Linear(size, attention_heads)
+        self.key = torch.nn.Linear(size, size)
+        self.key_att = torch.nn.Linear(size, attention_heads)
+        self.transform = torch.nn.Linear(size, size)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+
+    def espnet_initialization_fn(self):
+        self.apply(self.init_weights)
+
+    def init_weights(self, module):
+        if isinstance(module, torch.nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+        if isinstance(module, torch.nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+    def transpose_for_scores(self, x):
+        """Reshape and transpose to compute scores.
+
+        Args:
+            x: (batch, time, size = n_heads * attn_dim)
+
+        Returns:
+            (batch, n_heads, time, attn_dim)
+        """
+
+        new_x_shape = x.shape[:-1] + (
+            self.num_attention_heads,
+            self.attention_head_size,
+        )
+        return x.reshape(*new_x_shape).transpose(1, 2)
+
+    def forward(self, xs_pad, mask):
+        """Forward method.
+
+        Args:
+            xs_pad: (batch, time, size = n_heads * attn_dim)
+            mask: (batch, 1, time), nonpadding is 1, padding is 0
+
+        Returns:
+            torch.Tensor: (batch, time, size)
+        """
+
+        batch_size, seq_len, _ = xs_pad.shape
+        mixed_query_layer = self.query(xs_pad)  # (batch, time, size)
+        mixed_key_layer = self.key(xs_pad)  # (batch, time, size)
+
+        if mask is not None:
+            mask = mask.eq(0)  # padding is 1, nonpadding is 0
+
+        # (batch, n_heads, time)
+        query_for_score = (
+            self.query_att(mixed_query_layer).transpose(1, 2)
+            / self.attention_head_size**0.5
+        )
+        if mask is not None:
+            min_value = float(
+                numpy.finfo(
+                    torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
+                ).min
+            )
+            query_for_score = query_for_score.masked_fill(mask, min_value)
+            query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
+        else:
+            query_weight = torch.softmax(query_for_score, dim=-1)
+
+        query_weight = query_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
+        query_layer = self.transpose_for_scores(
+            mixed_query_layer
+        )  # (batch, n_heads, time, attn_dim)
+
+        pooled_query = (
+            torch.matmul(query_weight, query_layer)
+            .transpose(1, 2)
+            .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
+        )  # (batch, 1, size = n_heads * attn_dim)
+        pooled_query = self.dropout(pooled_query)
+        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)
+
+        mixed_query_key_layer = (
+            mixed_key_layer * pooled_query_repeat
+        )  # (batch, time, size)
+
+        # (batch, n_heads, time)
+        query_key_score = (
+            self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
+        ).transpose(1, 2)
+        if mask is not None:
+            min_value = float(
+                numpy.finfo(
+                    torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
+                ).min
+            )
+            query_key_score = query_key_score.masked_fill(mask, min_value)
+            query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
+                mask, 0.0
+            )
+        else:
+            query_key_weight = torch.softmax(query_key_score, dim=-1)
+
+        query_key_weight = query_key_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
+        key_layer = self.transpose_for_scores(
+            mixed_query_key_layer
+        )  # (batch, n_heads, time, attn_dim)
+        pooled_key = torch.matmul(
+            query_key_weight, key_layer
+        )  # (batch, n_heads, 1, attn_dim)
+        pooled_key = self.dropout(pooled_key)
+
+        # NOTE: value = query, due to param sharing
+        weighted_value = (pooled_key * query_layer).transpose(
+            1, 2
+        )  # (batch, time, n_heads, attn_dim)
+        weighted_value = weighted_value.reshape(
+            weighted_value.shape[:-2]
+            + (self.num_attention_heads * self.attention_head_size,)
+        )  # (batch, time, size)
+        weighted_value = (
+            self.dropout(self.transform(weighted_value)) + mixed_query_layer
+        )
+
+        return weighted_value

--
Gitblit v1.9.1