From b5ad7c81be2e24f255cac3d0ef0037bf88228366 Mon Sep 17 00:00:00 2001
From: Kun Zou <54402842+kenzoukun@users.noreply.github.com>
Date: 星期六, 21 十二月 2024 17:13:46 +0800
Subject: [PATCH] Support eparaformer model on aishell1 recipe (#2327)

---
 examples/aishell/e_paraformer/demo_train_or_finetune.sh                        |   51 
 examples/aishell/e_paraformer/utils/compute_wer.py                             |  197 ++
 examples/aishell/e_paraformer/demo_infer.sh                                    |   15 
 examples/aishell/e_paraformer/local/download_and_untar.sh                      |  105 +
 funasr/models/e_paraformer/search.py                                           |  451 +++++
 examples/aishell/e_paraformer/utils/parse_options.sh                           |   97 +
 examples/aishell/e_paraformer/utils/postprocess_text_zh.py                     |   30 
 funasr/models/e_paraformer/decoder.py                                          | 1193 ++++++++++++++
 examples/aishell/e_paraformer/utils/text_tokenize.sh                           |   35 
 funasr/models/e_paraformer/export_meta.py                                      |   86 +
 examples/aishell/e_paraformer/utils/filter_scp.pl                              |   87 +
 examples/aishell/e_paraformer/utils/textnorm_zh.py                             |  911 ++++++++++
 examples/aishell/e_paraformer/run.sh                                           |  201 ++
 examples/aishell/e_paraformer/utils/split_scp.pl                               |  246 ++
 examples/aishell/e_paraformer/local/aishell_data_prep.sh                       |   66 
 funasr/models/e_paraformer/pif_predictor.py                                    |  107 +
 examples/aishell/e_paraformer/utils/text_tokenize.py                           |  104 +
 funasr/models/e_paraformer/__init__.py                                         |    0 
 examples/aishell/e_paraformer/utils/fix_data.sh                                |   35 
 examples/aishell/e_paraformer/utils/shuffle_list.pl                            |   44 
 examples/aishell/e_paraformer/utils/extract_embeds.py                          |   49 
 funasr/models/e_paraformer/model.py                                            |  670 +++++++
 examples/aishell/e_paraformer/utils/fix_data_feat.sh                           |   52 
 examples/aishell/e_paraformer/utils/text2token.py                              |  141 +
 examples/aishell/e_paraformer/conf/e_paraformer_conformer_12e_6d_2048_256.yaml |  121 +
 25 files changed, 5,094 insertions(+), 0 deletions(-)

diff --git a/examples/aishell/e_paraformer/conf/e_paraformer_conformer_12e_6d_2048_256.yaml b/examples/aishell/e_paraformer/conf/e_paraformer_conformer_12e_6d_2048_256.yaml
new file mode 100644
index 0000000..14617e5
--- /dev/null
+++ b/examples/aishell/e_paraformer/conf/e_paraformer_conformer_12e_6d_2048_256.yaml
@@ -0,0 +1,121 @@
+
+# network architecture
+model: EParaformer
+model_conf:
+    ctc_weight: 0.0
+    lsm_weight: 0.1
+    length_normalized_loss: false
+    predictor_weight: 1.0
+    predictor_bias: 2
+    sampling_ratio: 0.4
+    use_1st_decoder_loss: true
+
+# encoder
+encoder: ConformerEncoder
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 2048  # the number of units of position-wise feed forward
+    num_blocks: 12      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: conv2d # encoder architecture type
+    normalize_before: true
+    pos_enc_layer_type: rel_pos
+    selfattention_layer_type: rel_selfattn
+    activation_type: swish
+    macaron_style: true
+    use_cnn_module: true
+    cnn_module_kernel: 15
+
+# decoder
+decoder: ParaformerSANDecoder
+decoder_conf:
+    attention_heads: 4
+    linear_units: 2048
+    num_blocks: 6
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    self_attention_dropout_rate: 0.0
+    src_attention_dropout_rate: 0.0
+
+# predictor
+predictor: PifPredictor
+predictor_conf:
+    idim: 256
+    threshold: 1.0
+    l_order: 1
+    r_order: 1
+    sigma: 0.5
+    bias: 0.0
+    sigma_heads: 4
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 1
+    lfr_n: 1
+
+specaug: SpecAug
+specaug_conf:
+    apply_time_warp: true
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    num_freq_mask: 2
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 40
+    num_time_mask: 2
+
+train_conf:
+  accum_grad: 4
+  grad_clip: 5
+  max_epoch: 150
+  keep_nbest_models: 20
+  avg_nbest_model: 15
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: EspnetStyleBatchSampler
+    batch_type: length # example or length
+    batch_size: 25000 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+    max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+    buffer_size: 1024
+    shuffle: True
+    num_workers: 4
+    preprocessor_speech: SpeechPreprocessSpeedPerturb
+    preprocessor_speech_conf:
+      speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+normalize: null
+
+
diff --git a/examples/aishell/e_paraformer/demo_infer.sh b/examples/aishell/e_paraformer/demo_infer.sh
new file mode 100644
index 0000000..b3b989f
--- /dev/null
+++ b/examples/aishell/e_paraformer/demo_infer.sh
@@ -0,0 +1,15 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+
+
+python -m funasr.bin.inference \
+--config-path="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3" \
+--config-name="config.yaml" \
+++init_param="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38" \
+++tokenizer_conf.token_list="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/zh_token_list/char/tokens.txt" \
+++frontend_conf.cmvn_file="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/train/am.mvn" \
+++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \
+++output_dir="./outputs/debug" \
+++device="cuda:0" \
+
diff --git a/examples/aishell/e_paraformer/demo_train_or_finetune.sh b/examples/aishell/e_paraformer/demo_train_or_finetune.sh
new file mode 100644
index 0000000..06607c7
--- /dev/null
+++ b/examples/aishell/e_paraformer/demo_train_or_finetune.sh
@@ -0,0 +1,51 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+
+# which gpu to train or finetune
+export CUDA_VISIBLE_DEVICES="0,1"
+gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+
+# data dir, which contains: train.json, val.json, tokens.jsonl/tokens.txt, am.mvn
+data_dir="/Users/zhifu/funasr1.0/data/list"
+
+## generate jsonl from wav.scp and text.txt
+#python -m funasr.datasets.audio_datasets.scp2jsonl \
+#++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
+#++data_type_list='["source", "target"]' \
+#++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl
+
+train_data="${data_dir}/train.jsonl"
+val_data="${data_dir}/val.jsonl"
+tokens="${data_dir}/tokens.json"
+cmvn_file="${data_dir}/am.mvn"
+
+# exp output dir
+output_dir="/Users/zhifu/exp"
+log_file="${output_dir}/log.txt"
+
+workspace=`pwd`
+config="paraformer_conformer_12e_6d_2048_256.yaml"
+
+init_param="${output_dir}/model.pt"
+
+mkdir -p ${output_dir}
+echo "log_file: ${log_file}"
+
+torchrun \
+--nnodes 1 \
+--nproc_per_node ${gpu_num} \
+../../../funasr/bin/train.py \
+--config-path "${workspace}/conf" \
+--config-name "${config}" \
+++train_data_set_list="${train_data}" \
+++valid_data_set_list="${val_data}" \
+++tokenizer_conf.token_list="${tokens}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++dataset_conf.batch_size=32 \
+++dataset_conf.batch_type="example" \
+++dataset_conf.num_workers=4 \
+++train_conf.max_epoch=150 \
+++optim_conf.lr=0.0002 \
+++init_param="${init_param}" \
+++output_dir="${output_dir}" &> ${log_file}
diff --git a/examples/aishell/e_paraformer/local/aishell_data_prep.sh b/examples/aishell/e_paraformer/local/aishell_data_prep.sh
new file mode 100755
index 0000000..83f489b
--- /dev/null
+++ b/examples/aishell/e_paraformer/local/aishell_data_prep.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+
+# Copyright 2017 Xingyu Na
+# Apache 2.0
+
+#. ./path.sh || exit 1;
+
+if [ $# != 3 ]; then
+  echo "Usage: $0 <audio-path> <text-path> <output-path>"
+  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
+  exit 1;
+fi
+
+aishell_audio_dir=$1
+aishell_text=$2/aishell_transcript_v0.8.txt
+output_dir=$3
+
+train_dir=$output_dir/data/local/train
+dev_dir=$output_dir/data/local/dev
+test_dir=$output_dir/data/local/test
+tmp_dir=$output_dir/data/local/tmp
+
+mkdir -p $train_dir
+mkdir -p $dev_dir
+mkdir -p $test_dir
+mkdir -p $tmp_dir
+
+# data directory check
+if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
+  echo "Error: $0 requires two directory arguments"
+  exit 1;
+fi
+
+# find wav audio file for train, dev and test resp.
+find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
+n=`cat $tmp_dir/wav.flist | wc -l`
+[ $n -ne 141925 ] && \
+  echo Warning: expected 141925 data data files, found $n
+
+grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
+grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
+grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
+
+rm -r $tmp_dir
+
+# Transcriptions preparation
+for dir in $train_dir $dev_dir $test_dir; do
+  echo Preparing $dir transcriptions
+  sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
+  paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
+  utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
+  awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
+  utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
+  sort -u $dir/transcripts.txt > $dir/text
+done
+
+mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
+
+for f in wav.scp text; do
+  cp $train_dir/$f $output_dir/data/train/$f || exit 1;
+  cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
+  cp $test_dir/$f $output_dir/data/test/$f || exit 1;
+done
+
+echo "$0: AISHELL data preparation succeeded"
+exit 0;
diff --git a/examples/aishell/e_paraformer/local/download_and_untar.sh b/examples/aishell/e_paraformer/local/download_and_untar.sh
new file mode 100755
index 0000000..d982559
--- /dev/null
+++ b/examples/aishell/e_paraformer/local/download_and_untar.sh
@@ -0,0 +1,105 @@
+#!/usr/bin/env bash
+
+# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
+#             2017  Xingyu Na
+# Apache 2.0
+
+remove_archive=false
+
+if [ "$1" == --remove-archive ]; then
+  remove_archive=true
+  shift
+fi
+
+if [ $# -ne 3 ]; then
+  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
+  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
+  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
+  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
+fi
+
+data=$1
+url=$2
+part=$3
+
+if [ ! -d "$data" ]; then
+  echo "$0: no such directory $data"
+  exit 1;
+fi
+
+part_ok=false
+list="data_aishell resource_aishell"
+for x in $list; do
+  if [ "$part" == $x ]; then part_ok=true; fi
+done
+if ! $part_ok; then
+  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
+  exit 1;
+fi
+
+if [ -z "$url" ]; then
+  echo "$0: empty URL base."
+  exit 1;
+fi
+
+if [ -f $data/$part/.complete ]; then
+  echo "$0: data part $part was already successfully extracted, nothing to do."
+  exit 0;
+fi
+
+# sizes of the archive files in bytes.
+sizes="15582913665 1246920"
+
+if [ -f $data/$part.tgz ]; then
+  size=$(/bin/ls -l $data/$part.tgz | awk '{print $5}')
+  size_ok=false
+  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
+  if ! $size_ok; then
+    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
+    echo "does not equal the size of one of the archives."
+    rm $data/$part.tgz
+  else
+    echo "$data/$part.tgz exists and appears to be complete."
+  fi
+fi
+
+if [ ! -f $data/$part.tgz ]; then
+  if ! command -v wget >/dev/null; then
+    echo "$0: wget is not installed."
+    exit 1;
+  fi
+  full_url=$url/$part.tgz
+  echo "$0: downloading data from $full_url.  This may take some time, please be patient."
+
+  cd $data || exit 1
+  if ! wget --no-check-certificate $full_url; then
+    echo "$0: error executing wget $full_url"
+    exit 1;
+  fi
+fi
+
+cd $data || exit 1
+
+if ! tar -xvzf $part.tgz; then
+  echo "$0: error un-tarring archive $data/$part.tgz"
+  exit 1;
+fi
+
+touch $data/$part/.complete
+
+if [ $part == "data_aishell" ]; then
+  cd $data/$part/wav || exit 1
+  for wav in ./*.tar.gz; do
+    echo "Extracting wav from $wav"
+    tar -zxf $wav && rm $wav
+  done
+fi
+
+echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"
+
+if $remove_archive; then
+  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
+  rm $data/$part.tgz
+fi
+
+exit 0;
diff --git a/examples/aishell/e_paraformer/run.sh b/examples/aishell/e_paraformer/run.sh
new file mode 100755
index 0000000..ecafc35
--- /dev/null
+++ b/examples/aishell/e_paraformer/run.sh
@@ -0,0 +1,201 @@
+#!/usr/bin/env bash
+
+
+CUDA_VISIBLE_DEVICES="0,1"
+
+# general configuration
+feats_dir="../DATA" #feature output dictionary
+exp_dir=`pwd`
+lang=zh
+token_type=char
+stage=0
+stop_stage=5
+
+# feature configuration
+nj=32
+
+inference_device="cuda" #"cpu"
+inference_checkpoint="model.pt.avg10"
+inference_scp="wav.scp"
+inference_batch_size=32
+
+# data
+raw_data=../raw_data
+data_url=www.openslr.org/resources/33
+
+# exp tag
+tag="exp1"
+workspace=`pwd`
+
+master_port=12345
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev
+test_sets="dev test"
+
+config=e_paraformer_conformer_12e_6d_2048_256.yaml
+model_dir="baseline_$(basename "${config}" .yaml)_${lang}_${token_type}_${tag}"
+
+
+if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
+    echo "stage -1: Data Download"
+    mkdir -p ${raw_data}
+    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
+    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+    echo "stage 0: Data preparation"
+    # Data preparation
+    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
+    for x in train dev test; do
+        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
+            > ${feats_dir}/data/${x}/text
+        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
+
+        # convert wav.scp text to jsonl
+        scp_file_list_arg="++scp_file_list='[\"${feats_dir}/data/${x}/wav.scp\",\"${feats_dir}/data/${x}/text\"]'"
+        python ../../../funasr/datasets/audio_datasets/scp2jsonl.py \
+        ++data_type_list='["source", "target"]' \
+        ++jsonl_file_out=${feats_dir}/data/${x}/audio_datasets.jsonl \
+        ${scp_file_list_arg}
+    done
+fi
+
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+    echo "stage 1: Feature and CMVN Generation"
+    python ../../../funasr/bin/compute_audio_cmvn.py \
+    --config-path "${workspace}/conf" \
+    --config-name "${config}" \
+    ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
+    ++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json"
+fi
+
+token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+    echo "stage 2: Dictionary Preparation"
+    mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/
+   
+    echo "make a dictionary"
+    echo "<blank>" > ${token_list}
+    echo "<s>" >> ${token_list}
+    echo "</s>" >> ${token_list}
+    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
+        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+    echo "<unk>" >> ${token_list}
+fi
+
+# LM Training Stage
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+    echo "stage 3: LM Training"
+fi
+
+# ASR Training Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+  echo "stage 4: ASR Training"
+
+  mkdir -p ${exp_dir}/exp/${model_dir}
+  current_time=$(date "+%Y-%m-%d_%H-%M")
+  log_file="${exp_dir}/exp/${model_dir}/train.log.txt.${current_time}"
+  echo "log_file: ${log_file}"
+
+  export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
+  gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+  torchrun \
+  --nnodes 1 \
+  --nproc_per_node ${gpu_num} \
+  --master_port ${master_port} \
+  ../../../funasr/bin/train.py \
+  --config-path "${workspace}/conf" \
+  --config-name "${config}" \
+  ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
+  ++valid_data_set_list="${feats_dir}/data/${valid_set}/audio_datasets.jsonl" \
+  ++tokenizer_conf.token_list="${token_list}" \
+  ++frontend_conf.cmvn_file="${feats_dir}/data/${train_set}/am.mvn" \
+  ++output_dir="${exp_dir}/exp/${model_dir}" &> ${log_file}
+fi
+
+
+
+# Testing Stage
+if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+  echo "stage 5: Inference"
+
+  if [ ${inference_device} == "cuda" ]; then
+      nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+  else
+      inference_batch_size=1
+      CUDA_VISIBLE_DEVICES=""
+      for JOB in $(seq ${nj}); do
+          CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+      done
+  fi
+
+  for dset in ${test_sets}; do
+
+    inference_dir="${exp_dir}/exp/${model_dir}/inference-${inference_checkpoint}/${dset}"
+    _logdir="${inference_dir}/logdir"
+    echo "inference_dir: ${inference_dir}"
+
+    mkdir -p "${_logdir}"
+    data_dir="${feats_dir}/data/${dset}"
+    key_file=${data_dir}/${inference_scp}
+
+    split_scps=
+    for JOB in $(seq "${nj}"); do
+        split_scps+=" ${_logdir}/keys.${JOB}.scp"
+    done
+    utils/split_scp.pl "${key_file}" ${split_scps}
+
+    gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+    for JOB in $(seq ${nj}); do
+        {
+          id=$((JOB-1))
+          gpuid=${gpuid_list_array[$id]}
+
+          export CUDA_VISIBLE_DEVICES=${gpuid}
+          python ../../../funasr/bin/inference.py \
+          --config-path="${exp_dir}/exp/${model_dir}" \
+          --config-name="config.yaml" \
+          ++init_param="${exp_dir}/exp/${model_dir}/${inference_checkpoint}" \
+          ++tokenizer_conf.token_list="${token_list}" \
+          ++frontend_conf.cmvn_file="${feats_dir}/data/${train_set}/am.mvn" \
+          ++input="${_logdir}/keys.${JOB}.scp" \
+          ++output_dir="${inference_dir}/${JOB}" \
+          ++device="${inference_device}" \
+          ++ncpu=1 \
+          ++disable_log=true \
+          ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
+        }&
+
+    done
+    wait
+
+    mkdir -p ${inference_dir}/1best_recog
+    for f in token score text; do
+        if [ -f "${inference_dir}/${JOB}/1best_recog/${f}" ]; then
+          for JOB in $(seq "${nj}"); do
+              cat "${inference_dir}/${JOB}/1best_recog/${f}"
+          done | sort -k1 >"${inference_dir}/1best_recog/${f}"
+        fi
+    done
+
+    echo "Computing WER ..."
+    python utils/postprocess_text_zh.py ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc
+    python utils/postprocess_text_zh.py  ${data_dir}/text ${inference_dir}/1best_recog/text.ref
+    python utils/compute_wer.py ${inference_dir}/1best_recog/text.ref ${inference_dir}/1best_recog/text.proc ${inference_dir}/1best_recog/text.cer
+    tail -n 3 ${inference_dir}/1best_recog/text.cer
+  done
+
+fi
diff --git a/examples/aishell/e_paraformer/utils/compute_wer.py b/examples/aishell/e_paraformer/utils/compute_wer.py
new file mode 100755
index 0000000..3d00010
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/compute_wer.py
@@ -0,0 +1,197 @@
+import os
+import numpy as np
+import sys
+
+
+def compute_wer(ref_file, hyp_file, cer_detail_file):
+    rst = {
+        "Wrd": 0,
+        "Corr": 0,
+        "Ins": 0,
+        "Del": 0,
+        "Sub": 0,
+        "Snt": 0,
+        "Err": 0.0,
+        "S.Err": 0.0,
+        "wrong_words": 0,
+        "wrong_sentences": 0,
+    }
+
+    hyp_dict = {}
+    ref_dict = {}
+    with open(hyp_file, "r") as hyp_reader:
+        for line in hyp_reader:
+            key = line.strip().split()[0]
+            value = line.strip().split()[1:]
+            hyp_dict[key] = value
+    with open(ref_file, "r") as ref_reader:
+        for line in ref_reader:
+            key = line.strip().split()[0]
+            value = line.strip().split()[1:]
+            ref_dict[key] = value
+
+    cer_detail_writer = open(cer_detail_file, "w")
+    for hyp_key in hyp_dict:
+        if hyp_key in ref_dict:
+            out_item = compute_wer_by_line(hyp_dict[hyp_key], ref_dict[hyp_key])
+            rst["Wrd"] += out_item["nwords"]
+            rst["Corr"] += out_item["cor"]
+            rst["wrong_words"] += out_item["wrong"]
+            rst["Ins"] += out_item["ins"]
+            rst["Del"] += out_item["del"]
+            rst["Sub"] += out_item["sub"]
+            rst["Snt"] += 1
+            if out_item["wrong"] > 0:
+                rst["wrong_sentences"] += 1
+            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + "\n")
+            cer_detail_writer.write(
+                "ref:" + "\t" + " ".join(list(map(lambda x: x.lower(), ref_dict[hyp_key]))) + "\n"
+            )
+            cer_detail_writer.write(
+                "hyp:" + "\t" + " ".join(list(map(lambda x: x.lower(), hyp_dict[hyp_key]))) + "\n"
+            )
+
+    if rst["Wrd"] > 0:
+        rst["Err"] = round(rst["wrong_words"] * 100 / rst["Wrd"], 2)
+    if rst["Snt"] > 0:
+        rst["S.Err"] = round(rst["wrong_sentences"] * 100 / rst["Snt"], 2)
+
+    cer_detail_writer.write("\n")
+    cer_detail_writer.write(
+        "%WER "
+        + str(rst["Err"])
+        + " [ "
+        + str(rst["wrong_words"])
+        + " / "
+        + str(rst["Wrd"])
+        + ", "
+        + str(rst["Ins"])
+        + " ins, "
+        + str(rst["Del"])
+        + " del, "
+        + str(rst["Sub"])
+        + " sub ]"
+        + "\n"
+    )
+    cer_detail_writer.write(
+        "%SER "
+        + str(rst["S.Err"])
+        + " [ "
+        + str(rst["wrong_sentences"])
+        + " / "
+        + str(rst["Snt"])
+        + " ]"
+        + "\n"
+    )
+    cer_detail_writer.write(
+        "Scored "
+        + str(len(hyp_dict))
+        + " sentences, "
+        + str(len(hyp_dict) - rst["Snt"])
+        + " not present in hyp."
+        + "\n"
+    )
+
+
+def compute_wer_by_line(hyp, ref):
+    hyp = list(map(lambda x: x.lower(), hyp))
+    ref = list(map(lambda x: x.lower(), ref))
+
+    len_hyp = len(hyp)
+    len_ref = len(ref)
+
+    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int16)
+
+    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)
+
+    for i in range(len_hyp + 1):
+        cost_matrix[i][0] = i
+    for j in range(len_ref + 1):
+        cost_matrix[0][j] = j
+
+    for i in range(1, len_hyp + 1):
+        for j in range(1, len_ref + 1):
+            if hyp[i - 1] == ref[j - 1]:
+                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
+            else:
+                substitution = cost_matrix[i - 1][j - 1] + 1
+                insertion = cost_matrix[i - 1][j] + 1
+                deletion = cost_matrix[i][j - 1] + 1
+
+                compare_val = [substitution, insertion, deletion]
+
+                min_val = min(compare_val)
+                operation_idx = compare_val.index(min_val) + 1
+                cost_matrix[i][j] = min_val
+                ops_matrix[i][j] = operation_idx
+
+    match_idx = []
+    i = len_hyp
+    j = len_ref
+    rst = {"nwords": len_ref, "cor": 0, "wrong": 0, "ins": 0, "del": 0, "sub": 0}
+    while i >= 0 or j >= 0:
+        i_idx = max(0, i)
+        j_idx = max(0, j)
+
+        if ops_matrix[i_idx][j_idx] == 0:  # correct
+            if i - 1 >= 0 and j - 1 >= 0:
+                match_idx.append((j - 1, i - 1))
+                rst["cor"] += 1
+
+            i -= 1
+            j -= 1
+
+        elif ops_matrix[i_idx][j_idx] == 2:  # insert
+            i -= 1
+            rst["ins"] += 1
+
+        elif ops_matrix[i_idx][j_idx] == 3:  # delete
+            j -= 1
+            rst["del"] += 1
+
+        elif ops_matrix[i_idx][j_idx] == 1:  # substitute
+            i -= 1
+            j -= 1
+            rst["sub"] += 1
+
+        if i < 0 and j >= 0:
+            rst["del"] += 1
+        elif j < 0 and i >= 0:
+            rst["ins"] += 1
+
+    match_idx.reverse()
+    wrong_cnt = cost_matrix[len_hyp][len_ref]
+    rst["wrong"] = wrong_cnt
+
+    return rst
+
+
+def print_cer_detail(rst):
+    return (
+        "("
+        + "nwords="
+        + str(rst["nwords"])
+        + ",cor="
+        + str(rst["cor"])
+        + ",ins="
+        + str(rst["ins"])
+        + ",del="
+        + str(rst["del"])
+        + ",sub="
+        + str(rst["sub"])
+        + ") corr:"
+        + "{:.2%}".format(rst["cor"] / rst["nwords"])
+        + ",cer:"
+        + "{:.2%}".format(rst["wrong"] / rst["nwords"])
+    )
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 4:
+        print("usage : python compute-wer.py test.ref test.hyp test.wer")
+        sys.exit(0)
+
+    ref_file = sys.argv[1]
+    hyp_file = sys.argv[2]
+    cer_detail_file = sys.argv[3]
+    compute_wer(ref_file, hyp_file, cer_detail_file)
diff --git a/examples/aishell/e_paraformer/utils/extract_embeds.py b/examples/aishell/e_paraformer/utils/extract_embeds.py
new file mode 100755
index 0000000..e0cf98d
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/extract_embeds.py
@@ -0,0 +1,49 @@
+from transformers import AutoTokenizer, AutoModel, pipeline
+import numpy as np
+import sys
+import os
+import torch
+from kaldiio import WriteHelper
+import re
+
+text_file_json = sys.argv[1]
+out_ark = sys.argv[2]
+out_scp = sys.argv[3]
+out_shape = sys.argv[4]
+device = int(sys.argv[5])
+model_path = sys.argv[6]
+
+model = AutoModel.from_pretrained(model_path)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)
+
+with open(text_file_json, "r") as f:
+    js = f.readlines()
+
+
+f_shape = open(out_shape, "w")
+with WriteHelper("ark,scp:{},{}".format(out_ark, out_scp)) as writer:
+    with torch.no_grad():
+        for idx, line in enumerate(js):
+            id, tokens = line.strip().split(" ", 1)
+            tokens = re.sub(" ", "", tokens.strip())
+            tokens = " ".join([j for j in tokens])
+            token_num = len(tokens.split(" "))
+            outputs = extractor(tokens)
+            outputs = np.array(outputs)
+            embeds = outputs[0, 1:-1, :]
+
+            token_num_embeds, dim = embeds.shape
+            if token_num == token_num_embeds:
+                writer(id, embeds)
+                shape_line = "{} {},{}\n".format(id, token_num_embeds, dim)
+                f_shape.write(shape_line)
+            else:
+                print(
+                    "{}, size has changed, {}, {}, {}".format(
+                        id, token_num, token_num_embeds, tokens
+                    )
+                )
+
+
+f_shape.close()
diff --git a/examples/aishell/e_paraformer/utils/filter_scp.pl b/examples/aishell/e_paraformer/utils/filter_scp.pl
new file mode 100755
index 0000000..003530d
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/filter_scp.pl
@@ -0,0 +1,87 @@
+#!/usr/bin/env perl
+# Copyright 2010-2012 Microsoft Corporation
+#                     Johns Hopkins University (author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This script takes a list of utterance-ids or any file whose first field
+# of each line is an utterance-id, and filters an scp
+# file (or any file whose "n-th" field is an utterance id), printing
+# out only those lines whose "n-th" field is in id_list. The index of
+# the "n-th" field is 1, by default, but can be changed by using
+# the -f <n> switch
+
+$exclude = 0;
+$field = 1;
+$shifted = 0;
+
+do {
+  $shifted=0;
+  if ($ARGV[0] eq "--exclude") {
+    $exclude = 1;
+    shift @ARGV;
+    $shifted=1;
+  }
+  if ($ARGV[0] eq "-f") {
+    $field = $ARGV[1];
+    shift @ARGV; shift @ARGV;
+    $shifted=1
+  }
+} while ($shifted);
+
+if(@ARGV < 1 || @ARGV > 2) {
+  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
+      "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
+      "Note: only the first field of each line in id_list matters.  With --exclude, prints\n" .
+      "only the lines that were *not* in id_list.\n" .
+      "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
+      "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
+      "-f option, add 1 to the argument.\n" .
+      "See also: scripts/filter_scp.pl .\n";
+}
+
+
+$idlist = shift @ARGV;
+open(F, "<$idlist") || die "Could not open id-list file $idlist";
+while(<F>) {
+  @A = split;
+  @A>=1 || die "Invalid id-list file line $_";
+  $seen{$A[0]} = 1;
+}
+
+if ($field == 1) { # Treat this as special case, since it is common.
+  while(<>) {
+    $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
+    # $1 is what we filter on.
+    if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
+      print $_;
+    }
+  }
+} else {
+  while(<>) {
+    @A = split;
+    @A > 0 || die "Invalid scp file line $_";
+    @A >= $field || die "Invalid scp file line $_";
+    if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
+      print $_;
+    }
+  }
+}
+
+# tests:
+# the following should print "foo 1"
+# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl <(echo foo)
+# the following should print "bar 2".
+# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl -f 2 <(echo 2)
diff --git a/examples/aishell/e_paraformer/utils/fix_data.sh b/examples/aishell/e_paraformer/utils/fix_data.sh
new file mode 100755
index 0000000..b1a2bb8
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/fix_data.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+
+echo "$0 $@"
+data_dir=$1
+
+if [ ! -f ${data_dir}/wav.scp ]; then
+  echo "$0: wav.scp is not found"
+  exit 1;
+fi
+
+if [ ! -f ${data_dir}/text ]; then
+  echo "$0: text is not found"
+  exit 1;
+fi
+
+
+
+mkdir -p ${data_dir}/.backup
+
+awk '{print $1}' ${data_dir}/wav.scp > ${data_dir}/.backup/wav_id
+awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
+
+sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
+
+cp ${data_dir}/wav.scp ${data_dir}/.backup/wav.scp
+cp ${data_dir}/text ${data_dir}/.backup/text
+
+mv ${data_dir}/wav.scp ${data_dir}/wav.scp.bak
+mv ${data_dir}/text ${data_dir}/text.bak
+
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/wav.scp.bak | sort -k1,1 -u > ${data_dir}/wav.scp
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
+
+rm ${data_dir}/wav.scp.bak
+rm ${data_dir}/text.bak
diff --git a/examples/aishell/e_paraformer/utils/fix_data_feat.sh b/examples/aishell/e_paraformer/utils/fix_data_feat.sh
new file mode 100755
index 0000000..84eea36
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/fix_data_feat.sh
@@ -0,0 +1,52 @@
+#!/usr/bin/env bash
+
+echo "$0 $@"
+data_dir=$1
+
+if [ ! -f ${data_dir}/feats.scp ]; then
+  echo "$0: feats.scp is not found"
+  exit 1;
+fi
+
+if [ ! -f ${data_dir}/text ]; then
+  echo "$0: text is not found"
+  exit 1;
+fi
+
+if [ ! -f ${data_dir}/speech_shape ]; then
+  echo "$0: feature lengths is not found"
+  exit 1;
+fi
+
+if [ ! -f ${data_dir}/text_shape ]; then
+  echo "$0: text lengths is not found"
+  exit 1;
+fi
+
+mkdir -p ${data_dir}/.backup
+
+awk '{print $1}' ${data_dir}/feats.scp > ${data_dir}/.backup/wav_id
+awk '{print $1}' ${data_dir}/text > ${data_dir}/.backup/text_id
+
+sort ${data_dir}/.backup/wav_id ${data_dir}/.backup/text_id | uniq -d > ${data_dir}/.backup/id
+
+cp ${data_dir}/feats.scp ${data_dir}/.backup/feats.scp
+cp ${data_dir}/text ${data_dir}/.backup/text
+cp ${data_dir}/speech_shape ${data_dir}/.backup/speech_shape
+cp ${data_dir}/text_shape ${data_dir}/.backup/text_shape
+
+mv ${data_dir}/feats.scp ${data_dir}/feats.scp.bak
+mv ${data_dir}/text ${data_dir}/text.bak
+mv ${data_dir}/speech_shape ${data_dir}/speech_shape.bak
+mv ${data_dir}/text_shape ${data_dir}/text_shape.bak
+
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/feats.scp.bak | sort -k1,1 -u > ${data_dir}/feats.scp
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text.bak | sort -k1,1 -u > ${data_dir}/text
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/speech_shape.bak | sort -k1,1 -u > ${data_dir}/speech_shape
+utils/filter_scp.pl -f 1 ${data_dir}/.backup/id ${data_dir}/text_shape.bak | sort -k1,1 -u > ${data_dir}/text_shape
+
+rm ${data_dir}/feats.scp.bak
+rm ${data_dir}/text.bak
+rm ${data_dir}/speech_shape.bak
+rm ${data_dir}/text_shape.bak
+
diff --git a/examples/aishell/e_paraformer/utils/parse_options.sh b/examples/aishell/e_paraformer/utils/parse_options.sh
new file mode 100755
index 0000000..71fb9e5
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/parse_options.sh
@@ -0,0 +1,97 @@
+#!/usr/bin/env bash
+
+# Copyright 2012  Johns Hopkins University (Author: Daniel Povey);
+#                 Arnab Ghoshal, Karel Vesely
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# Parse command-line options.
+# To be sourced by another script (as in ". parse_options.sh").
+# Option format is: --option-name arg
+# and shell variable "option_name" gets set to value "arg."
+# The exception is --help, which takes no arguments, but prints the
+# $help_message variable (if defined).
+
+
+###
+### The --config file options have lower priority to command line
+### options, so we need to import them first...
+###
+
+# Now import all the configs specified by command-line, in left-to-right order
+for ((argpos=1; argpos<$#; argpos++)); do
+  if [ "${!argpos}" == "--config" ]; then
+    argpos_plus1=$((argpos+1))
+    config=${!argpos_plus1}
+    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
+    . $config  # source the config file.
+  fi
+done
+
+
+###
+### Now we process the command line options
+###
+while true; do
+  [ -z "${1:-}" ] && break;  # break if there are no arguments
+  case "$1" in
+    # If the enclosing script is called with --help option, print the help
+    # message and exit.  Scripts should put help messages in $help_message
+    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
+      else printf "$help_message\n" 1>&2 ; fi;
+      exit 0 ;;
+    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
+      exit 1 ;;
+    # If the first command-line argument begins with "--" (e.g. --foo-bar),
+    # then work out the variable name as $name, which will equal "foo_bar".
+    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
+      # Next we test whether the variable in question is undefned-- if so it's
+      # an invalid option and we die.  Note: $0 evaluates to the name of the
+      # enclosing script.
+      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
+      # is undefined.  We then have to wrap this test inside "eval" because
+      # foo_bar is itself inside a variable ($name).
+      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+
+      oldval="`eval echo \\$$name`";
+      # Work out whether we seem to be expecting a Boolean argument.
+      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
+        was_bool=true;
+      else
+        was_bool=false;
+      fi
+
+      # Set the variable to the right value-- the escaped quotes make it work if
+      # the option had spaces, like --cmd "queue.pl -sync y"
+      eval $name=\"$2\";
+
+      # Check that Boolean-valued arguments are really Boolean.
+      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+        exit 1;
+      fi
+      shift 2;
+      ;;
+  *) break;
+  esac
+done
+
+
+# Check for an empty argument to the --cmd option, which can easily occur as a
+# result of scripting errors.
+[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
+
+
+true; # so this script returns exit code 0.
diff --git a/examples/aishell/e_paraformer/utils/postprocess_text_zh.py b/examples/aishell/e_paraformer/utils/postprocess_text_zh.py
new file mode 100755
index 0000000..d03febd
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/postprocess_text_zh.py
@@ -0,0 +1,30 @@
+import sys
+import re
+
+in_f = sys.argv[1]
+out_f = sys.argv[2]
+
+
+with open(in_f, "r", encoding="utf-8") as f:
+    lines = f.readlines()
+
+with open(out_f, "w", encoding="utf-8") as f:
+    for line in lines:
+        outs = line.strip().split(" ", 1)
+        if len(outs) == 2:
+            idx, text = outs
+            text = re.sub("</s>", "", text)
+            text = re.sub("<s>", "", text)
+            text = re.sub("@@", "", text)
+            text = re.sub("@", "", text)
+            text = re.sub("<unk>", "", text)
+            text = re.sub(" ", "", text)
+            text = text.lower()
+        else:
+            idx = outs[0]
+            text = " "
+
+        text = [x for x in text]
+        text = " ".join(text)
+        out = "{} {}\n".format(idx, text)
+        f.write(out)
diff --git a/examples/aishell/e_paraformer/utils/shuffle_list.pl b/examples/aishell/e_paraformer/utils/shuffle_list.pl
new file mode 100755
index 0000000..a116200
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/shuffle_list.pl
@@ -0,0 +1,44 @@
+#!/usr/bin/env perl
+
+# Copyright 2013  Johns Hopkins University (author: Daniel Povey)
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+if ($ARGV[0] eq "--srand") {
+  $n = $ARGV[1];
+  $n =~ m/\d+/ || die "Bad argument to --srand option: \"$n\"";
+  srand($ARGV[1]);
+  shift;
+  shift;
+} else {
+  srand(0); # Gives inconsistent behavior if we don't seed.
+}
+
+if (@ARGV > 1 || $ARGV[0] =~ m/^-.+/) { # >1 args, or an option we
+  # don't understand.
+  print "Usage: shuffle_list.pl [--srand N] [input file]  > output\n";
+  print "randomizes the order of lines of input.\n";
+  exit(1);
+}
+
+@lines;
+while (<>) {
+  push @lines, [ (rand(), $_)] ;
+}
+
+@lines = sort { $a->[0] cmp $b->[0] } @lines;
+foreach $l (@lines) {
+    print $l->[1];
+}
\ No newline at end of file
diff --git a/examples/aishell/e_paraformer/utils/split_scp.pl b/examples/aishell/e_paraformer/utils/split_scp.pl
new file mode 100755
index 0000000..0876dcb
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/split_scp.pl
@@ -0,0 +1,246 @@
+#!/usr/bin/env perl
+
+# Copyright 2010-2011 Microsoft Corporation
+
+# See ../../COPYING for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#  http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This program splits up any kind of .scp or archive-type file.
+# If there is no utt2spk option it will work on any text  file and
+# will split it up with an approximately equal number of lines in
+# each but.
+# With the --utt2spk option it will work on anything that has the
+# utterance-id as the first entry on each line; the utt2spk file is
+# of the form "utterance speaker" (on each line).
+# It splits it into equal size chunks as far as it can.  If you use the utt2spk
+# option it will make sure these chunks coincide with speaker boundaries.  In
+# this case, if there are more chunks than speakers (and in some other
+# circumstances), some of the resulting chunks will be empty and it will print
+# an error message and exit with nonzero status.
+# You will normally call this like:
+# split_scp.pl scp scp.1 scp.2 scp.3 ...
+# or
+# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
+# Note that you can use this script to split the utt2spk file itself,
+# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
+
+# You can also call the scripts like:
+# split_scp.pl -j 3 0 scp scp.0
+# [note: with this option, it assumes zero-based indexing of the split parts,
+# i.e. the second number must be 0 <= n < num-jobs.]
+
+use warnings;
+
+$num_jobs = 0;
+$job_id = 0;
+$utt2spk_file = "";
+$one_based = 0;
+
+for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
+    if ($ARGV[0] eq "-j") {
+        shift @ARGV;
+        $num_jobs = shift @ARGV;
+        $job_id = shift @ARGV;
+    }
+    if ($ARGV[0] =~ /--utt2spk=(.+)/) {
+        $utt2spk_file=$1;
+        shift;
+    }
+    if ($ARGV[0] eq '--one-based') {
+        $one_based = 1;
+        shift @ARGV;
+    }
+}
+
+if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
+                       $job_id - $one_based >= $num_jobs)) {
+  die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
+      ($one_based ? " --one-based" : "") . "'\n"
+}
+
+$one_based
+    and $job_id--;
+
+if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
+    die
+"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
+   or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
+ ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n";
+}
+
+$error = 0;
+$inscp = shift @ARGV;
+if ($num_jobs == 0) { # without -j option
+    @OUTPUTS = @ARGV;
+} else {
+    for ($j = 0; $j < $num_jobs; $j++) {
+        if ($j == $job_id) {
+            if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
+            else { push @OUTPUTS, "-"; }
+        } else {
+            push @OUTPUTS, "/dev/null";
+        }
+    }
+}
+
+if ($utt2spk_file ne "") {  # We have the --utt2spk option...
+    open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
+    while(<$u_fh>) {
+        @A = split;
+        @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
+        ($u,$s) = @A;
+        $utt2spk{$u} = $s;
+    }
+    close $u_fh;
+    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
+    @spkrs = ();
+    while(<$i_fh>) {
+        @A = split;
+        if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
+        $u = $A[0];
+        $s = $utt2spk{$u};
+        defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
+        if(!defined $spk_count{$s}) {
+            push @spkrs, $s;
+            $spk_count{$s} = 0;
+            $spk_data{$s} = [];  # ref to new empty array.
+        }
+        $spk_count{$s}++;
+        push @{$spk_data{$s}}, $_;
+    }
+    # Now split as equally as possible ..
+    # First allocate spks to files by allocating an approximately
+    # equal number of speakers.
+    $numspks = @spkrs;  # number of speakers.
+    $numscps = @OUTPUTS; # number of output files.
+    if ($numspks < $numscps) {
+      die "$0: Refusing to split data because number of speakers $numspks " .
+          "is less than the number of output .scp files $numscps\n";
+    }
+    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+        $scparray[$scpidx] = []; # [] is array reference.
+    }
+    for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
+        $scpidx = int(($spkidx*$numscps) / $numspks);
+        $spk = $spkrs[$spkidx];
+        push @{$scparray[$scpidx]}, $spk;
+        $scpcount[$scpidx] += $spk_count{$spk};
+    }
+
+    # Now will try to reassign beginning + ending speakers
+    # to different scp's and see if it gets more balanced.
+    # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
+    # We can show that if considering changing just 2 scp's, we minimize
+    # this by minimizing the squared difference in sizes.  This is
+    # equivalent to minimizing the absolute difference in sizes.  This
+    # shows this method is bound to converge.
+
+    $changed = 1;
+    while($changed) {
+        $changed = 0;
+        for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+            # First try to reassign ending spk of this scp.
+            if($scpidx < $numscps-1) {
+                $sz = @{$scparray[$scpidx]};
+                if($sz > 0) {
+                    $spk = $scparray[$scpidx]->[$sz-1];
+                    $count = $spk_count{$spk};
+                    $nutt1 = $scpcount[$scpidx];
+                    $nutt2 = $scpcount[$scpidx+1];
+                    if( abs( ($nutt2+$count) - ($nutt1-$count))
+                        < abs($nutt2 - $nutt1))  { # Would decrease
+                        # size-diff by reassigning spk...
+                        $scpcount[$scpidx+1] += $count;
+                        $scpcount[$scpidx] -= $count;
+                        pop @{$scparray[$scpidx]};
+                        unshift @{$scparray[$scpidx+1]}, $spk;
+                        $changed = 1;
+                    }
+                }
+            }
+            if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
+                $spk = $scparray[$scpidx]->[0];
+                $count = $spk_count{$spk};
+                $nutt1 = $scpcount[$scpidx-1];
+                $nutt2 = $scpcount[$scpidx];
+                if( abs( ($nutt2-$count) - ($nutt1+$count))
+                    < abs($nutt2 - $nutt1))  { # Would decrease
+                    # size-diff by reassigning spk...
+                    $scpcount[$scpidx-1] += $count;
+                    $scpcount[$scpidx] -= $count;
+                    shift @{$scparray[$scpidx]};
+                    push @{$scparray[$scpidx-1]}, $spk;
+                    $changed = 1;
+                }
+            }
+        }
+    }
+    # Now print out the files...
+    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
+        $scpfile = $OUTPUTS[$scpidx];
+        ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
+                         : open($f_fh, '>&', \*STDOUT)) ||
+            die "$0: Could not open scp file $scpfile for writing: $!\n";
+        $count = 0;
+        if(@{$scparray[$scpidx]} == 0) {
+            print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
+                         "$scpfile (too many splits and too few speakers?)\n";
+            $error = 1;
+        } else {
+            foreach $spk ( @{$scparray[$scpidx]} ) {
+                print $f_fh @{$spk_data{$spk}};
+                $count += $spk_count{$spk};
+            }
+            $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
+        }
+        close($f_fh);
+    }
+} else {
+   # This block is the "normal" case where there is no --utt2spk
+   # option and we just break into equal size chunks.
+
+    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
+
+    $numscps = @OUTPUTS;  # size of array.
+    @F = ();
+    while(<$i_fh>) {
+        push @F, $_;
+    }
+    $numlines = @F;
+    if($numlines == 0) {
+        print STDERR "$0: error: empty input scp file $inscp\n";
+        $error = 1;
+    }
+    $linesperscp = int( $numlines / $numscps); # the "whole part"..
+    $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
+    $remainder = $numlines - ($linesperscp * $numscps);
+    ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
+    # [just doing int() rounds down].
+    $n = 0;
+    for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
+        $scpfile = $OUTPUTS[$scpidx];
+        ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
+                         : open($o_fh, '>&', \*STDOUT)) ||
+            die "$0: Could not open scp file $scpfile for writing: $!\n";
+        for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
+            print $o_fh $F[$n++];
+        }
+        close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n";
+    }
+    $n == $numlines || die "$n != $numlines [code error]";
+}
+
+exit ($error);
diff --git a/examples/aishell/e_paraformer/utils/text2token.py b/examples/aishell/e_paraformer/utils/text2token.py
new file mode 100755
index 0000000..c39db1e
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/text2token.py
@@ -0,0 +1,141 @@
+#!/usr/bin/env python3
+
+# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+
+import argparse
+import codecs
+import re
+import sys
+import json
+
+is_python2 = sys.version_info[0] == 2
+
+
+def exist_or_not(i, match_pos):
+    start_pos = None
+    end_pos = None
+    for pos in match_pos:
+        if pos[0] <= i < pos[1]:
+            start_pos = pos[0]
+            end_pos = pos[1]
+            break
+
+    return start_pos, end_pos
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="convert raw text to tokenized text",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--nchar",
+        "-n",
+        default=1,
+        type=int,
+        help="number of characters to split, i.e., \
+                        aabb -> a a b b with -n 1 and aa bb with -n 2",
+    )
+    parser.add_argument("--skip-ncols", "-s", default=0, type=int, help="skip first n columns")
+    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
+    parser.add_argument(
+        "--non-lang-syms",
+        "-l",
+        default=None,
+        type=str,
+        help="list of non-linguistic symobles, e.g., <NOISE> etc.",
+    )
+    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
+    parser.add_argument(
+        "--trans_type",
+        "-t",
+        type=str,
+        default="char",
+        choices=["char", "phn"],
+        help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
+                        If trans_type is char,
+                        read from SI1279.WRD file -> "bricks are an alternative"
+                        Else if trans_type is phn,
+                        read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
+                        sil t er n ih sil t ih v sil" """,
+    )
+    parser.add_argument(
+        "--text_format",
+        default="text",
+        type=str,
+        help="text, jsonl",
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    rs = []
+    if args.non_lang_syms is not None:
+        with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
+            nls = [x.rstrip() for x in f.readlines()]
+            rs = [re.compile(re.escape(x)) for x in nls]
+
+    if args.text:
+        f = codecs.open(args.text, encoding="utf-8")
+    else:
+        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)
+
+    sys.stdout = codecs.getwriter("utf-8")(sys.stdout if is_python2 else sys.stdout.buffer)
+    line = f.readline()
+    n = args.nchar
+    while line:
+        if args.text_format == "jsonl":
+            data = json.loads(line.strip())
+            line = data["target"]
+        x = line.split()
+        print(" ".join(x[: args.skip_ncols]), end=" ")
+        a = " ".join(x[args.skip_ncols :])
+
+        # get all matched positions
+        match_pos = []
+        for r in rs:
+            i = 0
+            while i >= 0:
+                m = r.search(a, i)
+                if m:
+                    match_pos.append([m.start(), m.end()])
+                    i = m.end()
+                else:
+                    break
+
+        if args.trans_type == "phn":
+            a = a.split(" ")
+        else:
+            if len(match_pos) > 0:
+                chars = []
+                i = 0
+                while i < len(a):
+                    start_pos, end_pos = exist_or_not(i, match_pos)
+                    if start_pos is not None:
+                        chars.append(a[start_pos:end_pos])
+                        i = end_pos
+                    else:
+                        chars.append(a[i])
+                        i += 1
+                a = chars
+
+            a = [a[j : j + n] for j in range(0, len(a), n)]
+
+        a_flat = []
+        for z in a:
+            a_flat.append("".join(z))
+
+        a_chars = [z.replace(" ", args.space) for z in a_flat]
+        if args.trans_type == "phn":
+            a_chars = [z.replace("sil", args.space) for z in a_chars]
+        print(" ".join(a_chars))
+        line = f.readline()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/aishell/e_paraformer/utils/text_tokenize.py b/examples/aishell/e_paraformer/utils/text_tokenize.py
new file mode 100755
index 0000000..31500f1
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/text_tokenize.py
@@ -0,0 +1,104 @@
+import re
+import argparse
+
+
+def load_dict(seg_file):
+    seg_dict = {}
+    with open(seg_file, "r") as infile:
+        for line in infile:
+            s = line.strip().split()
+            key = s[0]
+            value = s[1:]
+            seg_dict[key] = " ".join(value)
+    return seg_dict
+
+
+def forward_segment(text, dic):
+    word_list = []
+    i = 0
+    while i < len(text):
+        longest_word = text[i]
+        for j in range(i + 1, len(text) + 1):
+            word = text[i:j]
+            if word in dic:
+                if len(word) > len(longest_word):
+                    longest_word = word
+        word_list.append(longest_word)
+        i += len(longest_word)
+    return word_list
+
+
+def tokenize(txt, seg_dict):
+    out_txt = ""
+    pattern = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
+    for word in txt:
+        if pattern.match(word):
+            if word in seg_dict:
+                out_txt += seg_dict[word] + " "
+            else:
+                out_txt += "<unk>" + " "
+        else:
+            continue
+    return out_txt.strip()
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        description="text tokenize",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--text-file",
+        "-t",
+        default=False,
+        required=True,
+        type=str,
+        help="input text",
+    )
+    parser.add_argument(
+        "--seg-file",
+        "-s",
+        default=False,
+        required=True,
+        type=str,
+        help="seg file",
+    )
+    parser.add_argument(
+        "--txt-index",
+        "-i",
+        default=1,
+        required=True,
+        type=int,
+        help="txt index",
+    )
+    parser.add_argument(
+        "--output-dir",
+        "-o",
+        default=False,
+        required=True,
+        type=str,
+        help="output dir",
+    )
+    return parser
+
+
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    txt_writer = open("{}/text.{}.txt".format(args.output_dir, args.txt_index), "w")
+    shape_writer = open("{}/len.{}".format(args.output_dir, args.txt_index), "w")
+    seg_dict = load_dict(args.seg_file)
+    with open(args.text_file, "r") as infile:
+        for line in infile:
+            s = line.strip().split()
+            text_id = s[0]
+            text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
+            text = tokenize(text_list, seg_dict)
+            lens = len(text.strip().split())
+            txt_writer.write(text_id + " " + text + "\n")
+            shape_writer.write(text_id + " " + str(lens) + "\n")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/aishell/e_paraformer/utils/text_tokenize.sh b/examples/aishell/e_paraformer/utils/text_tokenize.sh
new file mode 100755
index 0000000..6b74fef
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/text_tokenize.sh
@@ -0,0 +1,35 @@
+#!/usr/bin/env bash
+
+
+# Begin configuration section.
+nj=32
+cmd=utils/run.pl
+
+echo "$0 $@"
+
+. utils/parse_options.sh || exit 1;
+
+# tokenize configuration
+text_dir=$1
+seg_file=$2
+logdir=$3
+output_dir=$4
+
+txt_dir=${output_dir}/txt; mkdir -p ${output_dir}/txt
+mkdir -p ${logdir}
+
+$cmd JOB=1:$nj $logdir/text_tokenize.JOB.log \
+  python utils/text_tokenize.py -t ${text_dir}/txt/text.JOB.txt \
+      -s ${seg_file} -i JOB -o ${txt_dir} \
+      || exit 1;
+
+# concatenate the text files together.
+for n in $(seq $nj); do
+  cat ${txt_dir}/text.$n.txt || exit 1
+done > ${output_dir}/text || exit 1
+
+for n in $(seq $nj); do
+  cat ${txt_dir}/len.$n || exit 1
+done > ${output_dir}/text_shape || exit 1
+
+echo "$0: Succeeded text tokenize"
diff --git a/examples/aishell/e_paraformer/utils/textnorm_zh.py b/examples/aishell/e_paraformer/utils/textnorm_zh.py
new file mode 100755
index 0000000..9de8e81
--- /dev/null
+++ b/examples/aishell/e_paraformer/utils/textnorm_zh.py
@@ -0,0 +1,911 @@
+#!/usr/bin/env python3
+# coding=utf-8
+
+# Authors:
+#   2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
+#   2019.9 Jiayu DU
+#
+# requirements:
+#   - python 3.X
+# notes: python 2.X WILL fail or produce misleading results
+
+import sys, os, argparse, codecs, string, re
+
+# ================================================================================ #
+#                                    basic constant
+# ================================================================================ #
+CHINESE_DIGIS = "闆朵竴浜屼笁鍥涗簲鍏竷鍏節"
+BIG_CHINESE_DIGIS_SIMPLIFIED = "闆跺9璐板弫鑲嗕紞闄嗘煉鎹岀帠"
+BIG_CHINESE_DIGIS_TRADITIONAL = "闆跺9璨冲弮鑲嗕紞闄告煉鎹岀帠"
+SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "鍗佺櫨鍗冧竾"
+SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "鎷句桨浠熻惉"
+LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "浜垮厗浜灀绉┌娌熸锭姝h浇"
+LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "鍎勫厗浜灀绉┌婧濇緱姝h級"
+SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "鍗佺櫨鍗冧竾"
+SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "鎷句桨浠熻惉"
+
+ZERO_ALT = "銆�"
+ONE_ALT = "骞�"
+TWO_ALTS = ["涓�", "鍏�"]
+
+POSITIVE = ["姝�", "姝�"]
+NEGATIVE = ["璐�", "璨�"]
+POINT = ["鐐�", "榛�"]
+# PLUS = [u'鍔�', u'鍔�']
+# SIL = [u'鏉�', u'妲�']
+
+FILLER_CHARS = ["鍛�", "鍟�"]
+ER_WHITELIST = (
+    "(鍎垮コ|鍎垮瓙|鍎垮瓩|濂冲効|鍎垮|濡诲効|"
+    "鑳庡効|濠村効|鏂扮敓鍎縷濠村辜鍎縷骞煎効|灏戝効|灏忓効|鍎挎瓕|鍎跨|鍎跨|鎵樺効鎵�|瀛ゅ効|"
+    "鍎挎垙|鍎垮寲|鍙板効搴剕楣垮効宀泑姝e効鍏粡|鍚婂効閮庡綋|鐢熷効鑲插コ|鎵樺効甯﹀コ|鍏诲効闃茶�亅鐥村効鍛嗗コ|"
+    "浣冲効浣冲|鍎挎�滃吔鎵皘鍎挎棤甯哥埗|鍎夸笉瀚屾瘝涓憒鍎胯鍗冮噷姣嶆媴蹇鍎垮ぇ涓嶇敱鐖穦鑻忎篂鍎�)"
+)
+
+# 涓枃鏁板瓧绯荤粺绫诲瀷
+NUMBERING_TYPES = ["low", "mid", "high"]
+
+CURRENCY_NAMES = (
+    "(浜烘皯甯亅缇庡厓|鏃ュ厓|鑻遍晳|娆у厓|椹厠|娉曢儙|鍔犳嬁澶у厓|婢冲厓|娓竵|鍏堜护|鑺叞椹厠|鐖卞皵鍏伴晳|"
+    "閲屾媺|鑽峰叞鐩緗鍩冩柉搴撳|姣斿濉攟鍗板凹鐩緗鏋楀悏鐗箌鏂拌タ鍏板厓|姣旂储|鍗㈠竷|鏂板姞鍧″厓|闊╁厓|娉伴摙)"
+)
+CURRENCY_UNITS = (
+    "((浜縷鍗冧竾|鐧句竾|涓噟鍗億鐧�)|(浜縷鍗冧竾|鐧句竾|涓噟鍗億鐧緗)鍏億(浜縷鍗冧竾|鐧句竾|涓噟鍗億鐧緗)鍧梶瑙抾姣泑鍒�)"
+)
+COM_QUANTIFIERS = (
+    "(鍖箌寮爘搴鍥瀨鍦簗灏緗鏉涓獆棣東闃檤闃祙缃憒鐐畖椤秥涓榺妫祙鍙獆鏀瘄琚瓅杈唡鎸憒鎷厊棰梶澹硘绐爘鏇瞸澧檤缇鑵攟"
+    "鐮搴瀹璐瘄鎵巪鎹唡鍒�|浠鎵搢鎵媩缃梶鍧灞眧宀瓅姹焲婧獆閽焲闃焲鍗晐鍙寍瀵箌鍑簗鍙澶磡鑴殀鏉縷璺硘鏋潀浠秥璐磡"
+    "閽坾绾縷绠鍚峾浣峾韬珅鍫倈璇緗鏈瑋椤祙瀹秥鎴穦灞倈涓潀姣珅鍘榺鍒唡閽眧涓鏂鎷厊閾鐭硘閽閿眧蹇絴(鍗億姣珅寰�)鍏媩"
+    "姣珅鍘榺鍒唡瀵竱灏簗涓坾閲寍瀵粅甯竱閾簗绋媩(鍗億鍒唡鍘榺姣珅寰�)绫硘鎾畖鍕簗鍚坾鍗噟鏂梶鐭硘鐩榺纰梶纰焲鍙爘妗秥绗紎鐩唡"
+    "鐩抾鏉瘄閽焲鏂泑閿厊绨媩绡畖鐩榺妗秥缃恷鐡秥澹秥鍗畖鐩弢绠﹟绠眧鐓瞸鍟東琚媩閽祙骞磡鏈坾鏃瀛鍒粅鏃秥鍛▅澶﹟绉抾鍒唡鏃瑋"
+    "绾獆宀亅涓東鏇磡澶渱鏄澶弢绉媩鍐瑋浠浼弢杈坾涓竱娉绮抾棰梶骞鍫唡鏉鏍箌鏀瘄閬搢闈鐗噟寮爘棰梶鍧�)"
+)
+
+# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git)
+CHINESE_PUNC_STOP = "锛侊紵锝°��"
+CHINESE_PUNC_NON_STOP = "锛傦純锛勶紖锛嗭紘锛堬級锛婏紜锛岋紞锛忥細锛涳紲锛濓紴锛狅蓟锛硷冀锛撅伎锝�锝涳綔锝濓綖锝燂綘锝剑锝ゃ�併�冦�嬨�屻�嶃�庛�忋�愩�戙�斻�曘�栥�椼�樸�欍�氥�涖�溿�濄�炪�熴�般�俱�库�撯�斺�樷�欌�涒�溾�濃�炩�熲�︹�э箯"
+CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
+
+
+# ================================================================================ #
+#                                    basic class
+# ================================================================================ #
+class ChineseChar(object):
+    """
+    涓枃瀛楃
+    姣忎釜瀛楃瀵瑰簲绠�浣撳拰绻佷綋,
+    e.g. 绠�浣� = '璐�', 绻佷綋 = '璨�'
+    杞崲鏃跺彲杞崲涓虹畝浣撴垨绻佷綋
+    """
+
+    def __init__(self, simplified, traditional):
+        self.simplified = simplified
+        self.traditional = traditional
+        # self.__repr__ = self.__str__
+
+    def __str__(self):
+        return self.simplified or self.traditional or None
+
+    def __repr__(self):
+        return self.__str__()
+
+
+class ChineseNumberUnit(ChineseChar):
+    """
+    涓枃鏁板瓧/鏁颁綅瀛楃
+    姣忎釜瀛楃闄ょ箒绠�浣撳杩樻湁涓�涓澶栫殑澶у啓瀛楃
+    e.g. '闄�' 鍜� '闄�'
+    """
+
+    def __init__(self, power, simplified, traditional, big_s, big_t):
+        super(ChineseNumberUnit, self).__init__(simplified, traditional)
+        self.power = power
+        self.big_s = big_s
+        self.big_t = big_t
+
+    def __str__(self):
+        return "10^{}".format(self.power)
+
+    @classmethod
+    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
+
+        if small_unit:
+            return ChineseNumberUnit(
+                power=index + 1,
+                simplified=value[0],
+                traditional=value[1],
+                big_s=value[1],
+                big_t=value[1],
+            )
+        elif numbering_type == NUMBERING_TYPES[0]:
+            return ChineseNumberUnit(
+                power=index + 8,
+                simplified=value[0],
+                traditional=value[1],
+                big_s=value[0],
+                big_t=value[1],
+            )
+        elif numbering_type == NUMBERING_TYPES[1]:
+            return ChineseNumberUnit(
+                power=(index + 2) * 4,
+                simplified=value[0],
+                traditional=value[1],
+                big_s=value[0],
+                big_t=value[1],
+            )
+        elif numbering_type == NUMBERING_TYPES[2]:
+            return ChineseNumberUnit(
+                power=pow(2, index + 3),
+                simplified=value[0],
+                traditional=value[1],
+                big_s=value[0],
+                big_t=value[1],
+            )
+        else:
+            raise ValueError(
+                "Counting type should be in {0} ({1} provided).".format(
+                    NUMBERING_TYPES, numbering_type
+                )
+            )
+
+
+class ChineseNumberDigit(ChineseChar):
+    """
+    涓枃鏁板瓧瀛楃
+    """
+
+    def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
+        super(ChineseNumberDigit, self).__init__(simplified, traditional)
+        self.value = value
+        self.big_s = big_s
+        self.big_t = big_t
+        self.alt_s = alt_s
+        self.alt_t = alt_t
+
+    def __str__(self):
+        return str(self.value)
+
+    @classmethod
+    def create(cls, i, v):
+        return ChineseNumberDigit(i, v[0], v[1], v[2], v[3])
+
+
+class ChineseMath(ChineseChar):
+    """
+    涓枃鏁颁綅瀛楃
+    """
+
+    def __init__(self, simplified, traditional, symbol, expression=None):
+        super(ChineseMath, self).__init__(simplified, traditional)
+        self.symbol = symbol
+        self.expression = expression
+        self.big_s = simplified
+        self.big_t = traditional
+
+
+CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
+
+
+class NumberSystem(object):
+    """
+    涓枃鏁板瓧绯荤粺
+    """
+
+    pass
+
+
+class MathSymbol(object):
+    """
+    鐢ㄤ簬涓枃鏁板瓧绯荤粺鐨勬暟瀛︾鍙� (绻�/绠�浣�), e.g.
+    positive = ['姝�', '姝�']
+    negative = ['璐�', '璨�']
+    point = ['鐐�', '榛�']
+    """
+
+    def __init__(self, positive, negative, point):
+        self.positive = positive
+        self.negative = negative
+        self.point = point
+
+    def __iter__(self):
+        for v in self.__dict__.values():
+            yield v
+
+
+# class OtherSymbol(object):
+#     """
+#     鍏朵粬绗﹀彿
+#     """
+#
+#     def __init__(self, sil):
+#         self.sil = sil
+#
+#     def __iter__(self):
+#         for v in self.__dict__.values():
+#             yield v
+
+
+# ================================================================================ #
+#                                    basic utils
+# ================================================================================ #
+def create_system(numbering_type=NUMBERING_TYPES[1]):
+    """
+    鏍规嵁鏁板瓧绯荤粺绫诲瀷杩斿洖鍒涘缓鐩稿簲鐨勬暟瀛楃郴缁燂紝榛樿涓� mid
+    NUMBERING_TYPES = ['low', 'mid', 'high']: 涓枃鏁板瓧绯荤粺绫诲瀷
+        low:  '鍏�' = '浜�' * '鍗�' = $10^{9}$,  '浜�' = '鍏�' * '鍗�', etc.
+        mid:  '鍏�' = '浜�' * '涓�' = $10^{12}$, '浜�' = '鍏�' * '涓�', etc.
+        high: '鍏�' = '浜�' * '浜�' = $10^{16}$, '浜�' = '鍏�' * '鍏�', etc.
+    杩斿洖瀵瑰簲鐨勬暟瀛楃郴缁�
+    """
+
+    # chinese number units of '浜�' and larger
+    all_larger_units = zip(
+        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
+    )
+    larger_units = [CNU.create(i, v, numbering_type, False) for i, v in enumerate(all_larger_units)]
+    # chinese number units of '鍗�, 鐧�, 鍗�, 涓�'
+    all_smaller_units = zip(
+        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
+    )
+    smaller_units = [CNU.create(i, v, small_unit=True) for i, v in enumerate(all_smaller_units)]
+    # digis
+    chinese_digis = zip(
+        CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL
+    )
+    digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)]
+    digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT
+    digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT
+    digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1]
+
+    # symbols
+    positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
+    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
+    point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
+    # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))
+    system = NumberSystem()
+    system.units = smaller_units + larger_units
+    system.digits = digits
+    system.math = MathSymbol(positive_cn, negative_cn, point_cn)
+    # system.symbols = OtherSymbol(sil_cn)
+    return system
+
+
+def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
+
+    def get_symbol(char, system):
+        for u in system.units:
+            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
+                return u
+        for d in system.digits:
+            if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
+                return d
+        for m in system.math:
+            if char in [m.traditional, m.simplified]:
+                return m
+
+    def string2symbols(chinese_string, system):
+        int_string, dec_string = chinese_string, ""
+        for p in [system.math.point.simplified, system.math.point.traditional]:
+            if p in chinese_string:
+                int_string, dec_string = chinese_string.split(p)
+                break
+        return [get_symbol(c, system) for c in int_string], [
+            get_symbol(c, system) for c in dec_string
+        ]
+
+    def correct_symbols(integer_symbols, system):
+        """
+        涓�鐧惧叓 to 涓�鐧惧叓鍗�
+        涓�浜夸竴鍗冧笁鐧句竾 to 涓�浜� 涓�鍗冧竾 涓夌櫨涓�
+        """
+
+        if integer_symbols and isinstance(integer_symbols[0], CNU):
+            if integer_symbols[0].power == 1:
+                integer_symbols = [system.digits[1]] + integer_symbols
+
+        if len(integer_symbols) > 1:
+            if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
+                integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))
+
+        result = []
+        unit_count = 0
+        for s in integer_symbols:
+            if isinstance(s, CND):
+                result.append(s)
+                unit_count = 0
+            elif isinstance(s, CNU):
+                current_unit = CNU(s.power, None, None, None, None)
+                unit_count += 1
+
+            if unit_count == 1:
+                result.append(current_unit)
+            elif unit_count > 1:
+                for i in range(len(result)):
+                    if (
+                        isinstance(result[-i - 1], CNU)
+                        and result[-i - 1].power < current_unit.power
+                    ):
+                        result[-i - 1] = CNU(
+                            result[-i - 1].power + current_unit.power, None, None, None, None
+                        )
+        return result
+
+    def compute_value(integer_symbols):
+        """
+        Compute the value.
+        When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
+        e.g. '涓ゅ崈涓�' = 2000 * 10000 not 2000 + 10000
+        """
+        value = [0]
+        last_power = 0
+        for s in integer_symbols:
+            if isinstance(s, CND):
+                value[-1] = s.value
+            elif isinstance(s, CNU):
+                value[-1] *= pow(10, s.power)
+                if s.power > last_power:
+                    value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
+                    last_power = s.power
+                value.append(0)
+        return sum(value)
+
+    system = create_system(numbering_type)
+    int_part, dec_part = string2symbols(chinese_string, system)
+    int_part = correct_symbols(int_part, system)
+    int_str = str(compute_value(int_part))
+    dec_str = "".join([str(d.value) for d in dec_part])
+    if dec_part:
+        return "{0}.{1}".format(int_str, dec_str)
+    else:
+        return int_str
+
+
+def num2chn(
+    number_string,
+    numbering_type=NUMBERING_TYPES[1],
+    big=False,
+    traditional=False,
+    alt_zero=False,
+    alt_one=False,
+    alt_two=True,
+    use_zeros=True,
+    use_units=True,
+):
+
+    def get_value(value_string, use_zeros=True):
+
+        striped_string = value_string.lstrip("0")
+
+        # record nothing if all zeros
+        if not striped_string:
+            return []
+
+        # record one digits
+        elif len(striped_string) == 1:
+            if use_zeros and len(value_string) != len(striped_string):
+                return [system.digits[0], system.digits[int(striped_string)]]
+            else:
+                return [system.digits[int(striped_string)]]
+
+        # recursively record multiple digits
+        else:
+            result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
+            result_string = value_string[: -result_unit.power]
+            return (
+                get_value(result_string)
+                + [result_unit]
+                + get_value(striped_string[-result_unit.power :])
+            )
+
+    system = create_system(numbering_type)
+
+    int_dec = number_string.split(".")
+    if len(int_dec) == 1:
+        int_string = int_dec[0]
+        dec_string = ""
+    elif len(int_dec) == 2:
+        int_string = int_dec[0]
+        dec_string = int_dec[1]
+    else:
+        raise ValueError(
+            "invalid input num string with more than one dot: {}".format(number_string)
+        )
+
+    if use_units and len(int_string) > 1:
+        result_symbols = get_value(int_string)
+    else:
+        result_symbols = [system.digits[int(c)] for c in int_string]
+    dec_symbols = [system.digits[int(c)] for c in dec_string]
+    if dec_string:
+        result_symbols += [system.math.point] + dec_symbols
+
+    if alt_two:
+        liang = CND(
+            2,
+            system.digits[2].alt_s,
+            system.digits[2].alt_t,
+            system.digits[2].big_s,
+            system.digits[2].big_t,
+        )
+        for i, v in enumerate(result_symbols):
+            if isinstance(v, CND) and v.value == 2:
+                next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
+                previous_symbol = result_symbols[i - 1] if i > 0 else None
+                if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
+                    if next_symbol.power != 1 and (
+                        (previous_symbol is None) or (previous_symbol.power != 1)
+                    ):
+                        result_symbols[i] = liang
+
+    # if big is True, '涓�' will not be used and `alt_two` has no impact on output
+    if big:
+        attr_name = "big_"
+        if traditional:
+            attr_name += "t"
+        else:
+            attr_name += "s"
+    else:
+        if traditional:
+            attr_name = "traditional"
+        else:
+            attr_name = "simplified"
+
+    result = "".join([getattr(s, attr_name) for s in result_symbols])
+
+    # if not use_zeros:
+    #     result = result.strip(getattr(system.digits[0], attr_name))
+
+    if alt_zero:
+        result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)
+
+    if alt_one:
+        result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)
+
+    for i, p in enumerate(POINT):
+        if result.startswith(p):
+            return CHINESE_DIGIS[0] + result
+
+    # ^10, 11, .., 19
+    if (
+        len(result) >= 2
+        and result[1]
+        in [
+            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
+            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
+        ]
+        and result[0]
+        in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
+    ):
+        result = result[1:]
+
+    return result
+
+
+# ================================================================================ #
+#                          different types of rewriters
+# ================================================================================ #
+class Cardinal:
+    """
+    CARDINAL绫�
+    """
+
+    def __init__(self, cardinal=None, chntext=None):
+        self.cardinal = cardinal
+        self.chntext = chntext
+
+    def chntext2cardinal(self):
+        return chn2num(self.chntext)
+
+    def cardinal2chntext(self):
+        return num2chn(self.cardinal)
+
+
+class Digit:
+    """
+    DIGIT绫�
+    """
+
+    def __init__(self, digit=None, chntext=None):
+        self.digit = digit
+        self.chntext = chntext
+
+    # def chntext2digit(self):
+    #     return chn2num(self.chntext)
+
+    def digit2chntext(self):
+        return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+class TelePhone:
+    """
+    TELEPHONE绫�
+    """
+
+    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+        self.telephone = telephone
+        self.raw_chntext = raw_chntext
+        self.chntext = chntext
+
+    # def chntext2telephone(self):
+    #     sil_parts = self.raw_chntext.split('<SIL>')
+    #     self.telephone = '-'.join([
+    #         str(chn2num(p)) for p in sil_parts
+    #     ])
+    #     return self.telephone
+
+    def telephone2chntext(self, fixed=False):
+
+        if fixed:
+            sil_parts = self.telephone.split("-")
+            self.raw_chntext = "<SIL>".join(
+                [num2chn(part, alt_two=False, use_units=False) for part in sil_parts]
+            )
+            self.chntext = self.raw_chntext.replace("<SIL>", "")
+        else:
+            sp_parts = self.telephone.strip("+").split()
+            self.raw_chntext = "<SP>".join(
+                [num2chn(part, alt_two=False, use_units=False) for part in sp_parts]
+            )
+            self.chntext = self.raw_chntext.replace("<SP>", "")
+        return self.chntext
+
+
+class Fraction:
+    """
+    FRACTION绫�
+    """
+
+    def __init__(self, fraction=None, chntext=None):
+        self.fraction = fraction
+        self.chntext = chntext
+
+    def chntext2fraction(self):
+        denominator, numerator = self.chntext.split("鍒嗕箣")
+        return chn2num(numerator) + "/" + chn2num(denominator)
+
+    def fraction2chntext(self):
+        numerator, denominator = self.fraction.split("/")
+        return num2chn(denominator) + "鍒嗕箣" + num2chn(numerator)
+
+
+class Date:
+    """
+    DATE绫�
+    """
+
+    def __init__(self, date=None, chntext=None):
+        self.date = date
+        self.chntext = chntext
+
+    # def chntext2date(self):
+    #     chntext = self.chntext
+    #     try:
+    #         year, other = chntext.strip().split('骞�', maxsplit=1)
+    #         year = Digit(chntext=year).digit2chntext() + '骞�'
+    #     except ValueError:
+    #         other = chntext
+    #         year = ''
+    #     if other:
+    #         try:
+    #             month, day = other.strip().split('鏈�', maxsplit=1)
+    #             month = Cardinal(chntext=month).chntext2cardinal() + '鏈�'
+    #         except ValueError:
+    #             day = chntext
+    #             month = ''
+    #         if day:
+    #             day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+    #     else:
+    #         month = ''
+    #         day = ''
+    #     date = year + month + day
+    #     self.date = date
+    #     return self.date
+
+    def date2chntext(self):
+        date = self.date
+        try:
+            year, other = date.strip().split("骞�", 1)
+            year = Digit(digit=year).digit2chntext() + "骞�"
+        except ValueError:
+            other = date
+            year = ""
+        if other:
+            try:
+                month, day = other.strip().split("鏈�", 1)
+                month = Cardinal(cardinal=month).cardinal2chntext() + "鏈�"
+            except ValueError:
+                day = date
+                month = ""
+            if day:
+                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+        else:
+            month = ""
+            day = ""
+        chntext = year + month + day
+        self.chntext = chntext
+        return self.chntext
+
+
+class Money:
+    """
+    MONEY绫�
+    """
+
+    def __init__(self, money=None, chntext=None):
+        self.money = money
+        self.chntext = chntext
+
+    # def chntext2money(self):
+    #     return self.money
+
+    def money2chntext(self):
+        money = self.money
+        pattern = re.compile(r"(\d+(\.\d+)?)")
+        matchers = pattern.findall(money)
+        if matchers:
+            for matcher in matchers:
+                money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext())
+        self.chntext = money
+        return self.chntext
+
+
+class Percentage:
+    """
+    PERCENTAGE绫�
+    """
+
+    def __init__(self, percentage=None, chntext=None):
+        self.percentage = percentage
+        self.chntext = chntext
+
+    def chntext2percentage(self):
+        return chn2num(self.chntext.strip().strip("鐧惧垎涔�")) + "%"
+
+    def percentage2chntext(self):
+        return "鐧惧垎涔�" + num2chn(self.percentage.strip().strip("%"))
+
+
+def remove_erhua(text, er_whitelist):
+    """
+    鍘婚櫎鍎垮寲闊宠瘝涓殑鍎�:
+    浠栧コ鍎垮湪閭h竟鍎� -> 浠栧コ鍎垮湪閭h竟
+    """
+
+    er_pattern = re.compile(er_whitelist)
+    new_str = ""
+    while re.search("鍎�", text):
+        a = re.search("鍎�", text).span()
+        remove_er_flag = 0
+
+        if er_pattern.search(text):
+            b = er_pattern.search(text).span()
+            if b[0] <= a[0]:
+                remove_er_flag = 1
+
+        if remove_er_flag == 0:
+            new_str = new_str + text[0 : a[0]]
+            text = text[a[1] :]
+        else:
+            new_str = new_str + text[0 : b[1]]
+            text = text[b[1] :]
+
+    text = new_str + text
+    return text
+
+
+# ================================================================================ #
+#                            NSW Normalizer
+# ================================================================================ #
+class NSWNormalizer:
+    def __init__(self, raw_text):
+        self.raw_text = "^" + raw_text + "$"
+        self.norm_text = ""
+
+    def _particular(self):
+        text = self.norm_text
+        pattern = re.compile(r"(([a-zA-Z]+)浜�([a-zA-Z]+))")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('particular')
+            for matcher in matchers:
+                text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
+        self.norm_text = text
+        return self.norm_text
+
+    def normalize(self):
+        text = self.raw_text
+
+        # 瑙勮寖鍖栨棩鏈�
+        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})骞�)?(\d{1,2}鏈�(\d{1,2}[鏃ュ彿])?)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('date')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
+
+        # 瑙勮寖鍖栭噾閽�
+        pattern = re.compile(
+            r"\D+((\d+(\.\d+)?)[澶氫綑鍑燷?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)"
+        )
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('money')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
+
+        # 瑙勮寖鍖栧浐璇�/鎵嬫満鍙风爜
+        # 鎵嬫満
+        # http://www.jihaoba.com/news/show/13680
+        # 绉诲姩锛�139銆�138銆�137銆�136銆�135銆�134銆�159銆�158銆�157銆�150銆�151銆�152銆�188銆�187銆�182銆�183銆�184銆�178銆�198
+        # 鑱旈�氾細130銆�131銆�132銆�156銆�155銆�186銆�185銆�176
+        # 鐢典俊锛�133銆�153銆�189銆�180銆�181銆�177
+        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('telephone')
+            for matcher in matchers:
+                text = text.replace(
+                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
+                )
+        # 鍥鸿瘽
+        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fixed telephone')
+            for matcher in matchers:
+                text = text.replace(
+                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1
+                )
+
+        # 瑙勮寖鍖栧垎鏁�
+        pattern = re.compile(r"(\d+/\d+)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('fraction')
+            for matcher in matchers:
+                text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
+
+        # 瑙勮寖鍖栫櫨鍒嗘暟
+        text = text.replace("锛�", "%")
+        pattern = re.compile(r"(\d+(\.\d+)?%)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('percentage')
+            for matcher in matchers:
+                text = text.replace(
+                    matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1
+                )
+
+        # 瑙勮寖鍖栫函鏁�+閲忚瘝
+        pattern = re.compile(r"(\d+(\.\d+)?)[澶氫綑鍑燷?" + COM_QUANTIFIERS)
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('cardinal+quantifier')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        # 瑙勮寖鍖栨暟瀛楃紪鍙�
+        pattern = re.compile(r"(\d{4,32})")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('digit')
+            for matcher in matchers:
+                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
+
+        # 瑙勮寖鍖栫函鏁�
+        pattern = re.compile(r"(\d+(\.\d+)?)")
+        matchers = pattern.findall(text)
+        if matchers:
+            # print('cardinal')
+            for matcher in matchers:
+                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
+
+        self.norm_text = text
+        self._particular()
+
+        return self.norm_text.lstrip("^").rstrip("$")
+
+
+def nsw_test_case(raw_text):
+    print("I:" + raw_text)
+    print("O:" + NSWNormalizer(raw_text).normalize())
+    print("")
+
+
+def nsw_test():
+    nsw_test_case("鍥鸿瘽锛�0595-23865596鎴�23880880銆�")
+    nsw_test_case("鍥鸿瘽锛�0595-23865596鎴�23880880銆�")
+    nsw_test_case("鎵嬫満锛�+86 19859213959鎴�15659451527銆�")
+    nsw_test_case("鍒嗘暟锛�32477/76391銆�")
+    nsw_test_case("鐧惧垎鏁帮細80.03%銆�")
+    nsw_test_case("缂栧彿锛�31520181154418銆�")
+    nsw_test_case("绾暟锛�2983.07鍏嬫垨12345.60绫炽��")
+    nsw_test_case("鏃ユ湡锛�1999骞�2鏈�20鏃ユ垨09骞�3鏈�15鍙枫��")
+    nsw_test_case("閲戦挶锛�12鍧�5锛�34.5鍏冿紝20.1涓�")
+    nsw_test_case("鐗规畩锛歄2O鎴朆2C銆�")
+    nsw_test_case("3456涓囧惃")
+    nsw_test_case("2938涓�")
+    nsw_test_case("938")
+    nsw_test_case("浠婂ぉ鍚冧簡115涓皬绗煎寘231涓澶�")
+    nsw_test_case("鏈�62锛呯殑姒傜巼")
+
+
+if __name__ == "__main__":
+    # nsw_test()
+
+    p = argparse.ArgumentParser()
+    p.add_argument("ifile", help="input filename, assume utf-8 encoding")
+    p.add_argument("ofile", help="output filename")
+    p.add_argument("--to_upper", action="store_true", help="convert to upper case")
+    p.add_argument("--to_lower", action="store_true", help="convert to lower case")
+    p.add_argument(
+        "--has_key", action="store_true", help="input text has Kaldi's key as first field."
+    )
+    p.add_argument(
+        "--remove_fillers", type=bool, default=True, help='remove filler chars such as "鍛�, 鍟�"'
+    )
+    p.add_argument(
+        "--remove_erhua", type=bool, default=True, help='remove erhua chars such as "杩欏効"'
+    )
+    p.add_argument(
+        "--log_interval", type=int, default=10000, help="log interval in number of processed lines"
+    )
+    args = p.parse_args()
+
+    ifile = codecs.open(args.ifile, "r", "utf8")
+    ofile = codecs.open(args.ofile, "w+", "utf8")
+
+    n = 0
+    for l in ifile:
+        key = ""
+        text = ""
+        if args.has_key:
+            cols = l.split(maxsplit=1)
+            key = cols[0]
+            if len(cols) == 2:
+                text = cols[1].strip()
+            else:
+                text = ""
+        else:
+            text = l.strip()
+
+        # cases
+        if args.to_upper and args.to_lower:
+            sys.stderr.write("text norm: to_upper OR to_lower?")
+            exit(1)
+        if args.to_upper:
+            text = text.upper()
+        if args.to_lower:
+            text = text.lower()
+
+        # Filler chars removal
+        if args.remove_fillers:
+            for ch in FILLER_CHARS:
+                text = text.replace(ch, "")
+
+        if args.remove_erhua:
+            text = remove_erhua(text, ER_WHITELIST)
+
+        # NSW(Non-Standard-Word) normalization
+        text = NSWNormalizer(text).normalize()
+
+        # Punctuations removal
+        old_chars = CHINESE_PUNC_LIST + string.punctuation  # includes all CN and EN punctuations
+        new_chars = " " * len(old_chars)
+        del_chars = ""
+        text = text.translate(str.maketrans(old_chars, new_chars, del_chars))
+
+        #
+        if args.has_key:
+            ofile.write(key + "\t" + text + "\n")
+        else:
+            ofile.write(text + "\n")
+
+        n += 1
+        if n % args.log_interval == 0:
+            sys.stderr.write("text norm: {} lines done.\n".format(n))
+
+    sys.stderr.write("text norm: {} lines done in total.\n".format(n))
+
+    ifile.close()
+    ofile.close()
diff --git a/funasr/models/e_paraformer/__init__.py b/funasr/models/e_paraformer/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/e_paraformer/__init__.py
diff --git a/funasr/models/e_paraformer/decoder.py b/funasr/models/e_paraformer/decoder.py
new file mode 100644
index 0000000..7edd91a
--- /dev/null
+++ b/funasr/models/e_paraformer/decoder.py
@@ -0,0 +1,1193 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import torch
+from typing import List, Tuple
+
+from funasr.register import tables
+from funasr.models.scama import utils as myutils
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.decoder import DecoderLayer
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.decoder import BaseTransformerDecoder
+from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.sanm.attention import (
+    MultiHeadedAttentionSANMDecoder,
+    MultiHeadedAttentionCrossAtt,
+)
+
+
+class DecoderLayerSANM(torch.nn.Module):
+    """Single decoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+    """
+
+    def __init__(
+        self,
+        size,
+        self_attn,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+    ):
+        """Construct an DecoderLayer object."""
+        super(DecoderLayerSANM, self).__init__()
+        self.size = size
+        self.self_attn = self_attn
+        self.src_attn = src_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        if self_attn is not None:
+            self.norm2 = LayerNorm(size)
+        if src_attn is not None:
+            self.norm3 = LayerNorm(size)
+        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = torch.nn.Linear(size + size, size)
+            self.concat_linear2 = torch.nn.Linear(size + size, size)
+        self.reserve_attn = False
+        self.attn_mat = []
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, _ = self.self_attn(tgt, tgt_mask)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+            if self.reserve_attn:
+                x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
+                self.attn_mat.append(attn_mat)
+            else:
+                x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False)
+            x = residual + self.dropout(x_src_attn)
+            # x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        residual = tgt
+        tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn is not None:
+            tgt = self.norm2(tgt)
+            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x = residual + x
+
+        residual = x
+        x = self.norm3(x)
+        x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
+        return attn_mat
+
+    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        # tgt = self.dropout(tgt)
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            if self.training:
+                cache = None
+            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def forward_chunk(
+        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
+    ):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn:
+            if self.normalize_before:
+                tgt = self.norm2(tgt)
+            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
+            x = residual + self.dropout(x)
+
+        if self.src_attn is not None:
+            residual = x
+            if self.normalize_before:
+                x = self.norm3(x)
+
+            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
+            x = residual + x
+
+        return x, memory, fsmn_cache, opt_cache
+
+
+@tables.register("decoder_classes", "ParaformerSANMDecoder")
+class ParaformerSANMDecoder(BaseTransformerDecoder):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2006.01713
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        wo_input_layer: bool = False,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        att_layer_num: int = 6,
+        kernel_size: int = 21,
+        sanm_shfit: int = 0,
+        lora_list: List[str] = None,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
+        chunk_multiply_factor: tuple = (1,),
+        tf2torch_tensor_name_prefix_torch: str = "decoder",
+        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+    ):
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+        if wo_input_layer:
+            self.embed = None
+        else:
+            if input_layer == "embed":
+                self.embed = torch.nn.Sequential(
+                    torch.nn.Embedding(vocab_size, attention_dim),
+                    # pos_enc_class(attention_dim, positional_dropout_rate),
+                )
+            elif input_layer == "linear":
+                self.embed = torch.nn.Sequential(
+                    torch.nn.Linear(vocab_size, attention_dim),
+                    torch.nn.LayerNorm(attention_dim),
+                    torch.nn.Dropout(dropout_rate),
+                    torch.nn.ReLU(),
+                    pos_enc_class(attention_dim, positional_dropout_rate),
+                )
+            else:
+                raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+        if use_output_layer:
+            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.output_layer = None
+
+        self.att_layer_num = att_layer_num
+        self.num_blocks = num_blocks
+        if sanm_shfit is None:
+            sanm_shfit = (kernel_size - 1) // 2
+        self.decoders = repeat(
+            att_layer_num,
+            lambda lnum: DecoderLayerSANM(
+                attention_dim,
+                MultiHeadedAttentionSANMDecoder(
+                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                ),
+                MultiHeadedAttentionCrossAtt(
+                    attention_heads,
+                    attention_dim,
+                    src_attention_dropout_rate,
+                    lora_list,
+                    lora_rank,
+                    lora_alpha,
+                    lora_dropout,
+                ),
+                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if num_blocks - att_layer_num <= 0:
+            self.decoders2 = None
+        else:
+            self.decoders2 = repeat(
+                num_blocks - att_layer_num,
+                lambda lnum: DecoderLayerSANM(
+                    attention_dim,
+                    MultiHeadedAttentionSANMDecoder(
+                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+                    ),
+                    None,
+                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                ),
+            )
+
+        self.decoders3 = repeat(
+            1,
+            lambda lnum: DecoderLayerSANM(
+                attention_dim,
+                None,
+                None,
+                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
+        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
+        self.chunk_multiply_factor = chunk_multiply_factor
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        chunk_mask: torch.Tensor = None,
+        return_hidden: bool = False,
+        return_both: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        tgt = ys_in_pad
+        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        if chunk_mask is not None:
+            memory_mask = memory_mask * chunk_mask
+            if tgt_mask.size(1) != memory_mask.size(1):
+                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
+
+        x = tgt
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
+        if self.decoders2 is not None:
+            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
+        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
+        if self.normalize_before:
+            hidden = self.after_norm(x)
+
+        olens = tgt_mask.sum(1)
+        if self.output_layer is not None and return_hidden is False:
+            x = self.output_layer(hidden)
+            return x, olens
+        if return_both:
+            x = self.output_layer(hidden)
+            return x, hidden, olens
+        return hidden, olens
+
+    def score(self, ys, state, x):
+        """Score."""
+        ys_mask = myutils.sequence_mask(
+            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
+        )[:, :, None]
+        logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
+        return logp.squeeze(0), state
+
+    def forward_asf2(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ):
+
+        tgt = ys_in_pad
+        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
+        attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+        return attn_mat
+
+    def forward_asf6(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ):
+
+        tgt = ys_in_pad
+        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
+        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[1](tgt, tgt_mask, memory, memory_mask)
+        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[2](tgt, tgt_mask, memory, memory_mask)
+        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[3](tgt, tgt_mask, memory, memory_mask)
+        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[4](tgt, tgt_mask, memory, memory_mask)
+        attn_mat = self.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+        return attn_mat
+
+    def forward_chunk(
+        self,
+        memory: torch.Tensor,
+        tgt: torch.Tensor,
+        cache: dict = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        x = tgt
+        if cache["decode_fsmn"] is None:
+            cache_layer_num = len(self.decoders)
+            if self.decoders2 is not None:
+                cache_layer_num += len(self.decoders2)
+            fsmn_cache = [None] * cache_layer_num
+        else:
+            fsmn_cache = cache["decode_fsmn"]
+
+        if cache["opt"] is None:
+            cache_layer_num = len(self.decoders)
+            opt_cache = [None] * cache_layer_num
+        else:
+            opt_cache = cache["opt"]
+
+        for i in range(self.att_layer_num):
+            decoder = self.decoders[i]
+            x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
+                x,
+                memory,
+                fsmn_cache=fsmn_cache[i],
+                opt_cache=opt_cache[i],
+                chunk_size=cache["chunk_size"],
+                look_back=cache["decoder_chunk_look_back"],
+            )
+
+        if self.num_blocks - self.att_layer_num > 1:
+            for i in range(self.num_blocks - self.att_layer_num):
+                j = i + self.att_layer_num
+                decoder = self.decoders2[i]
+                x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
+                    x, memory, fsmn_cache=fsmn_cache[j]
+                )
+
+        for decoder in self.decoders3:
+            x, memory, _, _ = decoder.forward_chunk(x, memory)
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+
+        cache["decode_fsmn"] = fsmn_cache
+        if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
+            cache["opt"] = opt_cache
+        return x
+
+    def forward_one_step(
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        memory: torch.Tensor,
+        cache: List[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Forward one step.
+
+        Args:
+            tgt: input token ids, int64 (batch, maxlen_out)
+            tgt_mask: input token mask,  (batch, maxlen_out)
+                      dtype=torch.uint8 in PyTorch 1.2-
+                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
+            memory: encoded memory, float32  (batch, maxlen_in, feat)
+            cache: cached output list of (batch, max_time_out-1, size)
+        Returns:
+            y, cache: NN output value and cache per `self.decoders`.
+            y.shape` is (batch, maxlen_out, token)
+        """
+        x = self.embed(tgt)
+        if cache is None:
+            cache_layer_num = len(self.decoders)
+            if self.decoders2 is not None:
+                cache_layer_num += len(self.decoders2)
+            cache = [None] * cache_layer_num
+        new_cache = []
+        # for c, decoder in zip(cache, self.decoders):
+        for i in range(self.att_layer_num):
+            decoder = self.decoders[i]
+            c = cache[i]
+            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                x, tgt_mask, memory, None, cache=c
+            )
+            new_cache.append(c_ret)
+
+        if self.num_blocks - self.att_layer_num > 1:
+            for i in range(self.num_blocks - self.att_layer_num):
+                j = i + self.att_layer_num
+                decoder = self.decoders2[i]
+                c = cache[j]
+                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
+                    x, tgt_mask, memory, None, cache=c
+                )
+                new_cache.append(c_ret)
+
+        for decoder in self.decoders3:
+
+            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
+                x, tgt_mask, memory, None, cache=None
+            )
+
+        if self.normalize_before:
+            y = self.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+        if self.output_layer is not None:
+            y = torch.log_softmax(self.output_layer(y), dim=-1)
+
+        return y, new_cache
+
+
+class DecoderLayerSANMExport(torch.nn.Module):
+
+    def __init__(self, model):
+        super().__init__()
+        self.self_attn = model.self_attn
+        self.src_attn = model.src_attn
+        self.feed_forward = model.feed_forward
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2 if hasattr(model, "norm2") else None
+        self.norm3 = model.norm3 if hasattr(model, "norm3") else None
+        self.size = model.size
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+
+        residual = tgt
+        tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn is not None:
+            tgt = self.norm2(tgt)
+            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x = residual + x
+
+        if self.src_attn is not None:
+            residual = x
+            x = self.norm3(x)
+            x = residual + self.src_attn(x, memory, memory_mask)
+
+        return x, tgt_mask, memory, memory_mask, cache
+
+    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+        residual = tgt
+        tgt = self.norm1(tgt)
+        tgt = self.feed_forward(tgt)
+
+        x = tgt
+        if self.self_attn is not None:
+            tgt = self.norm2(tgt)
+            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+            x = residual + x
+
+        residual = x
+        x = self.norm3(x)
+        x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
+        return attn_mat
+
+
+@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
+class ParaformerSANMDecoderExport(torch.nn.Module):
+    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
+        super().__init__()
+        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+
+        from funasr.utils.torch_function import sequence_mask
+
+        self.model = model
+
+        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
+        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
+
+        for i, d in enumerate(self.model.decoders):
+            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
+            self.model.decoders[i] = DecoderLayerSANMExport(d)
+
+        if self.model.decoders2 is not None:
+            for i, d in enumerate(self.model.decoders2):
+                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+                self.model.decoders2[i] = DecoderLayerSANMExport(d)
+
+        for i, d in enumerate(self.model.decoders3):
+            self.model.decoders3[i] = DecoderLayerSANMExport(d)
+
+        self.output_layer = model.output_layer
+        self.after_norm = model.after_norm
+        self.model_name = model_name
+
+    def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        return_hidden: bool = False,
+        return_both: bool = False,
+    ):
+
+        tgt = ys_in_pad
+        tgt_mask = self.make_pad_mask(ys_in_lens)
+        tgt_mask, _ = self.prepare_mask(tgt_mask)
+        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = self.make_pad_mask(hlens)
+        _, memory_mask = self.prepare_mask(memory_mask)
+        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+        x = tgt
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(x, tgt_mask, memory, memory_mask)
+        if self.model.decoders2 is not None:
+            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
+                x, tgt_mask, memory, memory_mask
+            )
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
+        hidden = self.after_norm(x)
+        # x = self.output_layer(x)
+
+        if self.output_layer is not None and return_hidden is False:
+            x = self.output_layer(hidden)
+            return x, ys_in_lens
+        if return_both:
+            x = self.output_layer(hidden)
+            return x, hidden, ys_in_lens
+        return hidden, ys_in_lens
+
+    def forward_asf2(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ):
+
+        tgt = ys_in_pad
+        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        _, memory_mask = self.prepare_mask(memory_mask)
+
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
+            tgt, tgt_mask, memory, memory_mask
+        )
+        attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+        return attn_mat
+
+    def forward_asf6(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ):
+
+        tgt = ys_in_pad
+        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+        _, memory_mask = self.prepare_mask(memory_mask)
+
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
+            tgt, tgt_mask, memory, memory_mask
+        )
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](
+            tgt, tgt_mask, memory, memory_mask
+        )
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](
+            tgt, tgt_mask, memory, memory_mask
+        )
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](
+            tgt, tgt_mask, memory, memory_mask
+        )
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](
+            tgt, tgt_mask, memory, memory_mask
+        )
+        attn_mat = self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+        return attn_mat
+
+    """
+    def get_dummy_inputs(self, enc_size):
+        tgt = torch.LongTensor([0]).unsqueeze(0)
+        memory = torch.randn(1, 100, enc_size)
+        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        cache = [
+            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
+            for _ in range(cache_num)
+        ]
+        return (tgt, memory, pre_acoustic_embeds, cache)
+    
+    def is_optimizable(self):
+        return True
+    
+    def get_input_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+               + ['cache_%d' % i for i in range(cache_num)]
+    
+    def get_output_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ['y'] \
+               + ['out_cache_%d' % i for i in range(cache_num)]
+    
+    def get_dynamic_axes(self):
+        ret = {
+            'tgt': {
+                0: 'tgt_batch',
+                1: 'tgt_length'
+            },
+            'memory': {
+                0: 'memory_batch',
+                1: 'memory_length'
+            },
+            'pre_acoustic_embeds': {
+                0: 'acoustic_embeds_batch',
+                1: 'acoustic_embeds_length',
+            }
+        }
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        ret.update({
+            'cache_%d' % d: {
+                0: 'cache_%d_batch' % d,
+                2: 'cache_%d_length' % d
+            }
+            for d in range(cache_num)
+        })
+        return ret
+    """
+
+
+@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
+class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
+    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
+        super().__init__()
+        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+        self.model = model
+
+        from funasr.utils.torch_function import sequence_mask
+
+        self.model = model
+
+        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
+        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
+
+        for i, d in enumerate(self.model.decoders):
+            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
+            self.model.decoders[i] = DecoderLayerSANMExport(d)
+
+        if self.model.decoders2 is not None:
+            for i, d in enumerate(self.model.decoders2):
+                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+                self.model.decoders2[i] = DecoderLayerSANMExport(d)
+
+        for i, d in enumerate(self.model.decoders3):
+            self.model.decoders3[i] = DecoderLayerSANMExport(d)
+
+        self.output_layer = model.output_layer
+        self.after_norm = model.after_norm
+        self.model_name = model_name
+
+    def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        *args,
+    ):
+
+        tgt = ys_in_pad
+        tgt_mask = self.make_pad_mask(ys_in_lens)
+        tgt_mask, _ = self.prepare_mask(tgt_mask)
+        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = self.make_pad_mask(hlens)
+        _, memory_mask = self.prepare_mask(memory_mask)
+        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+        x = tgt
+        out_caches = list()
+        for i, decoder in enumerate(self.model.decoders):
+            in_cache = args[i]
+            x, tgt_mask, memory, memory_mask, out_cache = decoder(
+                x, tgt_mask, memory, memory_mask, cache=in_cache
+            )
+            out_caches.append(out_cache)
+        if self.model.decoders2 is not None:
+            for i, decoder in enumerate(self.model.decoders2):
+                in_cache = args[i + len(self.model.decoders)]
+                x, tgt_mask, memory, memory_mask, out_cache = decoder(
+                    x, tgt_mask, memory, memory_mask, cache=in_cache
+                )
+                out_caches.append(out_cache)
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
+        x = self.after_norm(x)
+        x = self.output_layer(x)
+
+        return x, out_caches
+
+    def get_dummy_inputs(self, enc_size):
+        enc = torch.randn(2, 100, enc_size).type(torch.float32)
+        enc_len = torch.tensor([30, 100], dtype=torch.int32)
+        acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
+        acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
+        cache_num = len(self.model.decoders)
+        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
+            cache_num += len(self.model.decoders2)
+        cache = [
+            torch.zeros(
+                (2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
+                dtype=torch.float32,
+            )
+            for _ in range(cache_num)
+        ]
+        return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
+
+    def get_input_names(self):
+        cache_num = len(self.model.decoders)
+        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
+            cache_num += len(self.model.decoders2)
+        return ["enc", "enc_len", "acoustic_embeds", "acoustic_embeds_len"] + [
+            "in_cache_%d" % i for i in range(cache_num)
+        ]
+
+    def get_output_names(self):
+        cache_num = len(self.model.decoders)
+        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
+            cache_num += len(self.model.decoders2)
+        return ["logits", "sample_ids"] + ["out_cache_%d" % i for i in range(cache_num)]
+
+    def get_dynamic_axes(self):
+        ret = {
+            "enc": {0: "batch_size", 1: "enc_length"},
+            "acoustic_embeds": {0: "batch_size", 1: "token_length"},
+            "enc_len": {
+                0: "batch_size",
+            },
+            "acoustic_embeds_len": {
+                0: "batch_size",
+            },
+        }
+        cache_num = len(self.model.decoders)
+        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
+            cache_num += len(self.model.decoders2)
+        ret.update(
+            {
+                "in_cache_%d"
+                % d: {
+                    0: "batch_size",
+                }
+                for d in range(cache_num)
+            }
+        )
+        ret.update(
+            {
+                "out_cache_%d"
+                % d: {
+                    0: "batch_size",
+                }
+                for d in range(cache_num)
+            }
+        )
+        return ret
+
+
+@tables.register("decoder_classes", "ParaformerSANDecoder")
+class ParaformerSANDecoder(BaseTransformerDecoder):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2006.01713
+    """
+
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        embeds_id: int = -1,
+    ):
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+        self.decoders = repeat(
+            num_blocks,
+            lambda lnum: DecoderLayer(
+                attention_dim,
+                MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
+                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
+                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        self.embeds_id = embeds_id
+        self.attention_dim = attention_dim
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        tgt = ys_in_pad
+        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
+
+        memory = hs_pad
+        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
+        # Padding for Longformer
+        if memory_mask.shape[-1] != memory.shape[1]:
+            padlen = memory.shape[1] - memory_mask.shape[-1]
+            memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False)
+
+        # x = self.embed(tgt)
+        x = tgt
+        embeds_outputs = None
+        for layer_id, decoder in enumerate(self.decoders):
+            x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, memory_mask)
+            if layer_id == self.embeds_id:
+                embeds_outputs = x
+        if self.normalize_before:
+            x = self.after_norm(x)
+        if self.output_layer is not None:
+            x = self.output_layer(x)
+
+        olens = tgt_mask.sum(1)
+        if embeds_outputs is not None:
+            return x, olens, embeds_outputs
+        else:
+            return x, olens
+
+
+@tables.register("decoder_classes", "ParaformerDecoderSANExport")
+class ParaformerDecoderSANExport(torch.nn.Module):
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        model_name="decoder",
+        onnx: bool = True,
+    ):
+        super().__init__()
+        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+        self.model = model
+
+        from funasr.utils.torch_function import sequence_mask
+
+        self.model = model
+
+        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        from funasr.models.transformer.decoder import DecoderLayerExport
+        from funasr.models.transformer.attention import MultiHeadedAttentionExport
+
+        for i, d in enumerate(self.model.decoders):
+            if isinstance(d.src_attn, MultiHeadedAttention):
+                d.src_attn = MultiHeadedAttentionExport(d.src_attn)
+            self.model.decoders[i] = DecoderLayerExport(d)
+
+        self.output_layer = model.output_layer
+        self.after_norm = model.after_norm
+        self.model_name = model_name
+
+    def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+    ):
+
+        tgt = ys_in_pad
+        tgt_mask = self.make_pad_mask(ys_in_lens)
+        tgt_mask, _ = self.prepare_mask(tgt_mask)
+        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = self.make_pad_mask(hlens)
+        _, memory_mask = self.prepare_mask(memory_mask)
+        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+        x = tgt
+        x, tgt_mask, memory, memory_mask = self.model.decoders(x, tgt_mask, memory, memory_mask)
+        x = self.after_norm(x)
+        x = self.output_layer(x)
+
+        return x, ys_in_lens
+
+    def get_dummy_inputs(self, enc_size):
+        tgt = torch.LongTensor([0]).unsqueeze(0)
+        memory = torch.randn(1, 100, enc_size)
+        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        cache = [
+            torch.zeros(
+                (1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size)
+            )
+            for _ in range(cache_num)
+        ]
+        return (tgt, memory, pre_acoustic_embeds, cache)
+
+    def is_optimizable(self):
+        return True
+
+    def get_input_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ["tgt", "memory", "pre_acoustic_embeds"] + ["cache_%d" % i for i in range(cache_num)]
+
+    def get_output_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]
+
+    def get_dynamic_axes(self):
+        ret = {
+            "tgt": {0: "tgt_batch", 1: "tgt_length"},
+            "memory": {0: "memory_batch", 1: "memory_length"},
+            "pre_acoustic_embeds": {
+                0: "acoustic_embeds_batch",
+                1: "acoustic_embeds_length",
+            },
+        }
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        ret.update(
+            {
+                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
+                for d in range(cache_num)
+            }
+        )
+        return ret
diff --git a/funasr/models/e_paraformer/export_meta.py b/funasr/models/e_paraformer/export_meta.py
new file mode 100644
index 0000000..db93855
--- /dev/null
+++ b/funasr/models/e_paraformer/export_meta.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import types
+import torch
+from funasr.register import tables
+
+
+def export_rebuild_model(model, **kwargs):
+    model.device = kwargs.get("device")
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
+
+    predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
+    model.predictor = predictor_class(model.predictor, onnx=is_onnx)
+
+    decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
+    model.decoder = decoder_class(model.decoder, onnx=is_onnx)
+
+    from funasr.utils.torch_function import sequence_mask
+
+    model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
+
+    model.forward = types.MethodType(export_forward, model)
+    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+    model.export_input_names = types.MethodType(export_input_names, model)
+    model.export_output_names = types.MethodType(export_output_names, model)
+    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+    model.export_name = types.MethodType(export_name, model)
+
+    model.export_name = 'model'
+    return model
+
+
+def export_forward(
+    self,
+    speech: torch.Tensor,
+    speech_lengths: torch.Tensor,
+):
+    # a. To device
+    batch = {"speech": speech, "speech_lengths": speech_lengths}
+    # batch = to_device(batch, device=self.device)
+
+    enc, enc_len = self.encoder(**batch)
+    mask = self.make_pad_mask(enc_len)[:, None, :]
+    pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+    pre_token_length = pre_token_length.floor().type(torch.int32)
+
+    decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+    decoder_out = torch.log_softmax(decoder_out, dim=-1)
+    # sample_ids = decoder_out.argmax(dim=-1)
+
+    return decoder_out, pre_token_length
+
+
+def export_dummy_inputs(self):
+    speech = torch.randn(2, 30, 560)
+    speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+    return (speech, speech_lengths)
+
+
+def export_input_names(self):
+    return ["speech", "speech_lengths"]
+
+
+def export_output_names(self):
+    return ["logits", "token_num"]
+
+
+def export_dynamic_axes(self):
+    return {
+        "speech": {0: "batch_size", 1: "feats_length"},
+        "speech_lengths": {
+            0: "batch_size",
+        },
+        "logits": {0: "batch_size", 1: "logits_length"},
+    }
+
+
+def export_name(
+    self,
+):
+    return "model.onnx"
diff --git a/funasr/models/e_paraformer/model.py b/funasr/models/e_paraformer/model.py
new file mode 100644
index 0000000..47b0190
--- /dev/null
+++ b/funasr/models/e_paraformer/model.py
@@ -0,0 +1,670 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# Copyright 2024 Kun Zou (chinazoukun@gmail.com). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import time
+import copy
+import torch
+import logging
+from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
+
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import to_device
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos, add_sos_and_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+
+@tables.register("model_classes", "EParaformer")
+class EParaformer(torch.nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    Author: Kun Zou, chinazoukun@gmail.com
+    E-Paraformer: A Faster and Better Parallel Transformer for Non-autoregressive End-to-End Mandarin Speech Recognition
+    https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf
+    """
+
+    def __init__(
+        self,
+        specaug: Optional[str] = None,
+        specaug_conf: Optional[Dict] = None,
+        normalize: str = None,
+        normalize_conf: Optional[Dict] = None,
+        encoder: str = None,
+        encoder_conf: Optional[Dict] = None,
+        decoder: str = None,
+        decoder_conf: Optional[Dict] = None,
+        ctc: str = None,
+        ctc_conf: Optional[Dict] = None,
+        predictor: str = None,
+        predictor_conf: Optional[Dict] = None,
+        ctc_weight: float = 0.5,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        # report_cer: bool = True,
+        # report_wer: bool = True,
+        # sym_space: str = "<space>",
+        # sym_blank: str = "<blank>",
+        # extract_feats_in_collect_stats: bool = True,
+        # predictor=None,
+        predictor_weight: float = 0.0,
+        predictor_bias: int = 2,
+        sampling_ratio: float = 0.2,
+        share_embedding: bool = False,
+        # preencoder: Optional[AbsPreEncoder] = None,
+        # postencoder: Optional[AbsPostEncoder] = None,
+        use_1st_decoder_loss: bool = True,
+        **kwargs,
+    ):
+
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        if decoder is not None:
+            decoder_class = tables.decoder_classes.get(decoder)
+            decoder = decoder_class(
+                vocab_size=vocab_size,
+                encoder_output_size=encoder_output_size,
+                **decoder_conf,
+            )
+        if ctc_weight > 0.0:
+
+            if ctc_conf is None:
+                ctc_conf = {}
+
+            ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
+        if predictor is not None:
+            predictor_class = tables.predictor_classes.get(predictor)
+            predictor = predictor_class(**predictor_conf)
+
+        # note that eos is the same as sos (equivalent ID)
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.ctc_weight = ctc_weight
+        # self.token_list = token_list.copy()
+        #
+        # self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        # self.preencoder = preencoder
+        # self.postencoder = postencoder
+        self.encoder = encoder
+        #
+        # if not hasattr(self.encoder, "interctc_use_conditioning"):
+        #     self.encoder.interctc_use_conditioning = False
+        # if self.encoder.interctc_use_conditioning:
+        #     self.encoder.conditioning_layer = torch.nn.Linear(
+        #         vocab_size, self.encoder.output_size()
+        #     )
+        #
+        # self.error_calculator = None
+        #
+        if ctc_weight == 1.0:
+            self.decoder = None
+        else:
+            self.decoder = decoder
+
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+
+        if use_1st_decoder_loss:
+            self.criterion_att_1st = LabelSmoothingLoss(
+                size=vocab_size,
+                padding_idx=ignore_id,
+                smoothing=lsm_weight,
+                normalize_length=length_normalized_loss,
+            )
+
+
+        #
+        # if report_cer or report_wer:
+        #     self.error_calculator = ErrorCalculator(
+        #         token_list, sym_space, sym_blank, report_cer, report_wer
+        #     )
+        #
+        if ctc_weight == 0.0:
+            self.ctc = None
+        else:
+            self.ctc = ctc
+        #
+        # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+        self.predictor = predictor
+        self.predictor_weight = predictor_weight
+        self.predictor_bias = predictor_bias
+        self.sampling_ratio = sampling_ratio
+        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
+
+        self.share_embedding = share_embedding
+        if self.share_embedding:
+            self.decoder.embed = None
+
+        self.use_1st_decoder_loss = use_1st_decoder_loss
+        self.length_normalized_loss = length_normalized_loss
+        self.beam_search = None
+        self.error_calculator = None
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size = speech.shape[0]
+
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        loss_ctc, cer_ctc = None, None
+        loss_pre = None
+        stats = dict()
+
+        # decoder: CTC branch
+        if self.ctc_weight != 0.0:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+            # Collect CTC branch stats
+            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+            stats["cer_ctc"] = cer_ctc
+
+        # decoder: Attention decoder branch
+        loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
+            encoder_out, encoder_out_lens, text, text_lengths
+        )
+
+        # 3. CTC-Att loss definition
+        if self.ctc_weight == 0.0:
+            loss = loss_att + loss_pre * self.predictor_weight
+        else:
+            loss = (
+                self.ctc_weight * loss_ctc
+                + (1 - self.ctc_weight) * loss_att
+                + loss_pre * self.predictor_weight
+            )
+        if pre_loss_att is not None:
+            loss += pre_loss_att
+        # Collect Attn branch stats
+        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
+        stats["acc"] = acc_att
+        stats["cer"] = cer_att
+        stats["wer"] = wer_att
+        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
+
+        stats["loss"] = torch.clone(loss.detach())
+        stats["batch_size"] = batch_size
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = (text_lengths + self.predictor_bias).sum()
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+        # Forward encoder
+        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        return encoder_out, encoder_out_lens
+
+    def calc_predictor(self, encoder_out, encoder_out_lens):
+
+        encoder_out_mask = (
+            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
+        ).to(encoder_out.device)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
+            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
+        )
+        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
+
+    def cal_decoder_with_predictor(
+        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
+    ):
+
+        decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
+        decoder_out = decoder_outs[0]
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out, ys_pad_lens
+
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        encoder_out_mask = (
+            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
+        ).to(encoder_out.device)
+        if self.predictor_bias == 1:
+            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+            ys_pad_lens = ys_pad_lens + self.predictor_bias
+        if self.predictor_bias == 2:
+            _, ys_pad = add_sos_and_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+            ys_pad_lens = ys_pad_lens + self.predictor_bias
+
+        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
+            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
+        )
+
+        # 0. sampler
+        decoder_out_1st = None
+        pre_loss_att = None
+        if self.sampling_ratio > 0.0:
+            if self.use_1st_decoder_loss:
+                sematic_embeds, decoder_out_1st = self.sampler_with_grad(
+                    encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
+                )
+            else:
+
+                sematic_embeds, decoder_out_1st = self.sampler(
+                    encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
+                )
+        else:
+            sematic_embeds = pre_acoustic_embeds
+
+        # 1. Forward decoder
+        decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+        if decoder_out_1st is None:
+            decoder_out_1st = decoder_out
+        # 2. Compute attention loss
+        if self.use_1st_decoder_loss:
+            pre_loss_att = self.criterion_att_1st(decoder_out_1st, ys_pad)
+        loss_att = self.criterion_att(decoder_out, ys_pad)
+        acc_att = th_accuracy(
+            decoder_out_1st.view(-1, self.vocab_size),
+            ys_pad,
+            ignore_label=self.ignore_id,
+        )
+        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out_1st.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
+
+    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
+
+        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
+            ys_pad.device
+        )
+        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+        if self.share_embedding:
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+        else:
+            ys_pad_embed = self.decoder.embed(ys_pad_masked)
+        with torch.no_grad():
+            decoder_outs = self.decoder(
+                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
+            )
+            decoder_out, _ = decoder_outs[0], decoder_outs[1]
+            pred_tokens = decoder_out.argmax(-1)
+            nonpad_positions = ys_pad.ne(self.ignore_id)
+            seq_lens = (nonpad_positions).sum(1)
+            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+            input_mask = torch.ones_like(nonpad_positions)
+            bsz, seq_len = ys_pad.size()
+            for li in range(bsz):
+                target_num = (
+                    ((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio
+                ).long()
+                if target_num > 0:
+                    input_mask[li].scatter_(
+                        dim=0,
+                        index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
+                        value=0,
+                    )
+            input_mask = input_mask.eq(1)
+            input_mask = input_mask.masked_fill(~nonpad_positions, False)
+            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+        sematic_embeds = pre_acoustic_embeds.masked_fill(
+            ~input_mask_expand_dim, 0
+        ) + ys_pad_embed.masked_fill(input_mask_expand_dim, 0)
+        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
+
+        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
+            ys_pad.device
+        )
+        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
+        if self.share_embedding:
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
+        else:
+            ys_pad_embed = self.decoder.embed(ys_pad_masked)
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
+        )
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+        pred_tokens = decoder_out.argmax(-1)
+        nonpad_positions = ys_pad.ne(self.ignore_id)
+        seq_lens = (nonpad_positions).sum(1)
+        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+        input_mask = torch.ones_like(nonpad_positions)
+        bsz, seq_len = ys_pad.size()
+        for li in range(bsz):
+            target_num = (
+                ((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio
+            ).long()
+            if target_num > 0:
+                input_mask[li].scatter_(
+                    dim=0,
+                    index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
+                    value=0,
+                )
+        input_mask = input_mask.eq(1)
+        input_mask = input_mask.masked_fill(~nonpad_positions, False)
+        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+        sematic_embeds = pre_acoustic_embeds.masked_fill(
+            ~input_mask_expand_dim, 0
+        ) + ys_pad_embed.masked_fill(input_mask_expand_dim, 0)
+        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
+
+    def init_beam_search(
+        self,
+        **kwargs,
+    ):
+        from funasr.models.paraformer.search import BeamSearchPara
+        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+        # 1. Build ASR model
+        scorers = {}
+
+        if self.ctc != None:
+            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+            scorers.update(ctc=ctc)
+        token_list = kwargs.get("token_list")
+        scorers.update(
+            length_bonus=LengthBonus(len(token_list)),
+        )
+
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+
+        weights = dict(
+            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
+            ctc=kwargs.get("decoding_ctc_weight", 0.0),
+            lm=kwargs.get("lm_weight", 0.0),
+            ngram=kwargs.get("ngram_weight", 0.0),
+            length_bonus=kwargs.get("penalty", 0.0),
+        )
+        beam_search = BeamSearchPara(
+            beam_size=kwargs.get("beam_size", 2),
+            weights=weights,
+            scorers=scorers,
+            sos=self.sos,
+            eos=self.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+        )
+        # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+        # for scorer in scorers.values():
+        #     if isinstance(scorer, torch.nn.Module):
+        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
+        self.beam_search = beam_search
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+        # init beamsearch
+        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
+        is_use_lm = (
+            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+        )
+        pred_timestamp = kwargs.get("pred_timestamp", False)
+        if self.beam_search is None and (is_use_lm or is_use_ctc):
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+
+        meta_data = {}
+        if (
+            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+        ):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is not None:
+                speech_lengths = speech_lengths.squeeze(-1)
+            else:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = (
+                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+            )
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+        # Encoder
+        if kwargs.get("fp16", False):
+            speech = speech.half()
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        # predictor
+        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
+            predictor_outs[0],
+            predictor_outs[1],
+            predictor_outs[2],
+            predictor_outs[3],
+        )
+        
+        pre_token_length = pre_token_length.round().long()
+        if torch.max(pre_token_length) < 1:
+            return []
+        decoder_outs = self.cal_decoder_with_predictor(
+            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
+        )
+        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+
+        results = []
+        b, n, d = decoder_out.size()
+        if isinstance(key[0], (list, tuple)):
+            key = key[0]
+        if len(key) < b:
+            key = key * b
+        for i in range(b):
+            x = encoder_out[i, : encoder_out_lens[i], :]
+            am_scores = decoder_out[i, : pre_token_length[i], :]
+            if self.beam_search is not None:
+                nbest_hyps = self.beam_search(
+                    x=x,
+                    am_scores=am_scores,
+                    maxlenratio=kwargs.get("maxlenratio", 0.0),
+                    minlenratio=kwargs.get("minlenratio", 0.0),
+                )
+
+                nbest_hyps = nbest_hyps[: self.nbest]
+            else:
+
+                yseq = am_scores.argmax(dim=-1)
+                score = am_scores.max(dim=-1)[0]
+                score = torch.sum(score, dim=-1)
+                # pad with mask tokens to ensure compatibility with sos/eos tokens
+                yseq = torch.tensor([self.sos] + yseq.tolist() + [self.eos], device=yseq.device)
+                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx+1}best_recog"]
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+
+                # remove blank symbol id, which is assumed to be 0
+                token_int = list(
+                    filter(
+                        lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
+                    )
+                )
+
+                if tokenizer is not None:
+                    # Change integer-ids to tokens
+                    token = tokenizer.ids2tokens(token_int)
+                    text_postprocessed = tokenizer.tokens2text(token)
+                    
+                    if pred_timestamp:
+                        timestamp_str, timestamp = ts_prediction_lfr6_standard(
+                            pre_peak_index[i],
+                            alphas[i],
+                            copy.copy(token),
+                            vad_offset=kwargs.get("begin_time", 0),
+                            upsample_rate=1,
+                        )
+                        if not hasattr(tokenizer, "bpemodel"):
+                            text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
+                        result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed,}
+                    else:
+                        if not hasattr(tokenizer, "bpemodel"):
+                            text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                        result_i = {"key": key[i], "text": text_postprocessed}
+
+                    if ibest_writer is not None:
+                        ibest_writer["token"][key[i]] = " ".join(token)
+                        # ibest_writer["text"][key[i]] = text
+                        ibest_writer["text"][key[i]] = text_postprocessed
+                else:
+                    result_i = {"key": key[i], "token_int": token_int}
+                results.append(result_i)
+
+        return results, meta_data
+
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+
+        if "max_seq_len" not in kwargs:
+            kwargs["max_seq_len"] = 512
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
diff --git a/funasr/models/e_paraformer/pif_predictor.py b/funasr/models/e_paraformer/pif_predictor.py
new file mode 100644
index 0000000..8d350f6
--- /dev/null
+++ b/funasr/models/e_paraformer/pif_predictor.py
@@ -0,0 +1,107 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# Copyright 2024 Kun Zou (chinazoukun@gmail.com). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import torch
+import logging
+import numpy as np
+
+from funasr.register import tables
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from torch.cuda.amp import autocast
+
+
+@tables.register("predictor_classes", "PifPredictor")
+class PifPredictor(torch.nn.Module):
+    """
+    Author: Kun Zou, chinazoukun@gmail.com
+    E-Paraformer: A Faster and Better Parallel Transformer for Non-autoregressive End-to-End Mandarin Speech Recognition
+    https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf
+    """
+    def __init__(
+        self,
+        idim,
+        l_order,
+        r_order,
+        threshold=1.0,
+        dropout=0.1,
+        smooth_factor=1.0,
+        noise_threshold=0,
+        sigma=0.5,
+        bias=0.0,
+        sigma_heads=4,
+    ):
+        super().__init__()
+
+        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
+        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
+        self.cif_output = torch.nn.Linear(idim, 1)
+        self.dropout = torch.nn.Dropout(p=dropout)
+        self.threshold = threshold
+        self.smooth_factor = smooth_factor
+        self.noise_threshold = noise_threshold
+        self.sigma = torch.nn.Parameter(torch.tensor([sigma]*sigma_heads))
+        self.bias = torch.nn.Parameter(torch.tensor([bias]*sigma_heads))
+        self.sigma_heads = sigma_heads
+
+    def forward(
+        self,
+        hidden,
+        target_label=None,
+        mask=None,
+        ignore_id=-1,
+        mask_chunk_predictor=None,
+        target_label_length=None,
+    ):
+
+        with autocast(False):
+            h = hidden
+            context = h.transpose(1, 2)
+            queries = self.pad(context)
+            memory = self.cif_conv1d(queries)
+            output = memory + context
+            output = self.dropout(output)
+            output = output.transpose(1, 2)
+            output = torch.relu(output)
+            output = self.cif_output(output)
+            alphas = torch.sigmoid(output)
+            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
+            if mask is not None:
+                mask = mask.transpose(-1, -2).float()
+                alphas = alphas * mask
+            if mask_chunk_predictor is not None:
+                alphas = alphas * mask_chunk_predictor
+            alphas = alphas.squeeze(-1)
+            mask = mask.squeeze(-1)
+            if target_label_length is not None:
+                target_length = target_label_length
+            elif target_label is not None:
+                target_mask = (target_label != ignore_id).float()
+                target_length = target_mask.sum(-1)
+            else:
+                target_mask = None
+                target_length = None
+            token_num = alphas.sum(-1)
+            if target_length is not None:
+                alphas *= (target_length / token_num)[:, None].repeat(1, alphas.size(1))
+                max_token_num = torch.max(target_length)
+            else:
+                token_num_int = token_num.round()
+                alphas *=(token_num_int / token_num)[:, None]
+                max_token_num = torch.max(token_num_int)
+            alignment = torch.cumsum(alphas, dim=-1)
+            fire_positions = (torch.arange(max_token_num) + 0.5).type_as(alphas).unsqueeze(0)
+            scores = - ((fire_positions[:, None, :, None] - alignment[:, None, None, :]) * self.sigma[None, :, None, None]) **2 + self.bias[None, :, None, None]
+            scores = scores.masked_fill(~(mask[:, None, None, :].to(torch.bool)), float("-inf"))
+            weights = torch.softmax(scores, dim=-1)
+            n_hidden = hidden.view(hidden.size(0), -1, self.sigma_heads, hidden.size(-1) // self.sigma_heads).transpose(1, 2)
+            acoustic_embeds = torch.matmul(weights, n_hidden).transpose(1,2).contiguous().view(hidden.size(0), -1, hidden.size(-1))
+
+            if target_mask is not None:
+                acoustic_embeds *= target_mask[:, :, None]
+            cif_peak = None
+        return acoustic_embeds, token_num, alphas, cif_peak
+
diff --git a/funasr/models/e_paraformer/search.py b/funasr/models/e_paraformer/search.py
new file mode 100644
index 0000000..16e13dd
--- /dev/null
+++ b/funasr/models/e_paraformer/search.py
@@ -0,0 +1,451 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import torch
+import logging
+from itertools import chain
+from typing import Any, Dict, List, NamedTuple, Tuple, Union
+
+from funasr.metrics.common import end_detect
+from funasr.models.transformer.scorers.scorer_interface import (
+    PartialScorerInterface,
+    ScorerInterface,
+)
+
+
+class Hypothesis(NamedTuple):
+    """Hypothesis data type."""
+
+    yseq: torch.Tensor
+    score: Union[float, torch.Tensor] = 0
+    scores: Dict[str, Union[float, torch.Tensor]] = dict()
+    states: Dict[str, Any] = dict()
+
+    def asdict(self) -> dict:
+        """Convert data to JSON-friendly dict."""
+        return self._replace(
+            yseq=self.yseq.tolist(),
+            score=float(self.score),
+            scores={k: float(v) for k, v in self.scores.items()},
+        )._asdict()
+
+
+class BeamSearchPara(torch.nn.Module):
+    """Beam search implementation."""
+
+    def __init__(
+        self,
+        scorers: Dict[str, ScorerInterface],
+        weights: Dict[str, float],
+        beam_size: int,
+        vocab_size: int,
+        sos: int,
+        eos: int,
+        token_list: List[str] = None,
+        pre_beam_ratio: float = 1.5,
+        pre_beam_score_key: str = None,
+    ):
+        """Initialize beam search.
+
+        Args:
+            scorers (dict[str, ScorerInterface]): Dict of decoder modules
+                e.g., Decoder, CTCPrefixScorer, LM
+                The scorer will be ignored if it is `None`
+            weights (dict[str, float]): Dict of weights for each scorers
+                The scorer will be ignored if its weight is 0
+            beam_size (int): The number of hypotheses kept during search
+            vocab_size (int): The number of vocabulary
+            sos (int): Start of sequence id
+            eos (int): End of sequence id
+            token_list (list[str]): List of tokens for debug log
+            pre_beam_score_key (str): key of scores to perform pre-beam search
+            pre_beam_ratio (float): beam size in the pre-beam search
+                will be `int(pre_beam_ratio * beam_size)`
+
+        """
+        super().__init__()
+        # set scorers
+        self.weights = weights
+        self.scorers = dict()
+        self.full_scorers = dict()
+        self.part_scorers = dict()
+        # this module dict is required for recursive cast
+        # `self.to(device, dtype)` in `recog.py`
+        self.nn_dict = torch.nn.ModuleDict()
+        for k, v in scorers.items():
+            w = weights.get(k, 0)
+            if w == 0 or v is None:
+                continue
+            assert isinstance(
+                v, ScorerInterface
+            ), f"{k} ({type(v)}) does not implement ScorerInterface"
+            self.scorers[k] = v
+            if isinstance(v, PartialScorerInterface):
+                self.part_scorers[k] = v
+            else:
+                self.full_scorers[k] = v
+            if isinstance(v, torch.nn.Module):
+                self.nn_dict[k] = v
+
+        # set configurations
+        self.sos = sos
+        self.eos = eos
+        self.token_list = token_list
+        self.pre_beam_size = int(pre_beam_ratio * beam_size)
+        self.beam_size = beam_size
+        self.n_vocab = vocab_size
+        if (
+            pre_beam_score_key is not None
+            and pre_beam_score_key != "full"
+            and pre_beam_score_key not in self.full_scorers
+        ):
+            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
+        self.pre_beam_score_key = pre_beam_score_key
+        self.do_pre_beam = (
+            self.pre_beam_score_key is not None
+            and self.pre_beam_size < self.n_vocab
+            and len(self.part_scorers) > 0
+        )
+
+    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
+        """Get an initial hypothesis data.
+
+        Args:
+            x (torch.Tensor): The encoder output feature
+
+        Returns:
+            Hypothesis: The initial hypothesis.
+
+        """
+        init_states = dict()
+        init_scores = dict()
+        for k, d in self.scorers.items():
+            init_states[k] = d.init_state(x)
+            init_scores[k] = 0.0
+        return [
+            Hypothesis(
+                score=0.0,
+                scores=init_scores,
+                states=init_states,
+                yseq=torch.tensor([self.sos], device=x.device),
+            )
+        ]
+
+    @staticmethod
+    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
+        """Append new token to prefix tokens.
+
+        Args:
+            xs (torch.Tensor): The prefix token
+            x (int): The new token to append
+
+        Returns:
+            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
+
+        """
+        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
+        return torch.cat((xs, x))
+
+    def score_full(
+        self, hyp: Hypothesis, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.full_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.full_scorers`
+                and tensor score values of shape: `(self.n_vocab,)`,
+                and state dict that has string keys
+                and state values of `self.full_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.full_scorers.items():
+            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
+        return scores, states
+
+    def score_partial(
+        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
+    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
+        """Score new hypothesis by `self.part_scorers`.
+
+        Args:
+            hyp (Hypothesis): Hypothesis with prefix tokens to score
+            ids (torch.Tensor): 1D tensor of new partial tokens to score
+            x (torch.Tensor): Corresponding input feature
+
+        Returns:
+            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
+                score dict of `hyp` that has string keys of `self.part_scorers`
+                and tensor score values of shape: `(len(ids),)`,
+                and state dict that has string keys
+                and state values of `self.part_scorers`
+
+        """
+        scores = dict()
+        states = dict()
+        for k, d in self.part_scorers.items():
+            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
+        return scores, states
+
+    def beam(
+        self, weighted_scores: torch.Tensor, ids: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute topk full token ids and partial token ids.
+
+        Args:
+            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
+            Its shape is `(self.n_vocab,)`.
+            ids (torch.Tensor): The partial token ids to compute topk
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]:
+                The topk full token ids and partial token ids.
+                Their shapes are `(self.beam_size,)`
+
+        """
+        # no pre beam performed
+        if weighted_scores.size(0) == ids.size(0):
+            top_ids = weighted_scores.topk(self.beam_size)[1]
+            return top_ids, top_ids
+
+        # mask pruned in pre-beam not to select in topk
+        tmp = weighted_scores[ids]
+        weighted_scores[:] = -float("inf")
+        weighted_scores[ids] = tmp
+        top_ids = weighted_scores.topk(self.beam_size)[1]
+        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
+        return top_ids, local_ids
+
+    @staticmethod
+    def merge_scores(
+        prev_scores: Dict[str, float],
+        next_full_scores: Dict[str, torch.Tensor],
+        full_idx: int,
+        next_part_scores: Dict[str, torch.Tensor],
+        part_idx: int,
+    ) -> Dict[str, torch.Tensor]:
+        """Merge scores for new hypothesis.
+
+        Args:
+            prev_scores (Dict[str, float]):
+                The previous hypothesis scores by `self.scorers`
+            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
+            full_idx (int): The next token id for `next_full_scores`
+            next_part_scores (Dict[str, torch.Tensor]):
+                scores of partial tokens by `self.part_scorers`
+            part_idx (int): The new token id for `next_part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are scalar tensors by the scorers.
+
+        """
+        new_scores = dict()
+        for k, v in next_full_scores.items():
+            new_scores[k] = prev_scores[k] + v[full_idx]
+        for k, v in next_part_scores.items():
+            new_scores[k] = prev_scores[k] + v[part_idx]
+        return new_scores
+
+    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
+        """Merge states for new hypothesis.
+
+        Args:
+            states: states of `self.full_scorers`
+            part_states: states of `self.part_scorers`
+            part_idx (int): The new token id for `part_scores`
+
+        Returns:
+            Dict[str, torch.Tensor]: The new score dict.
+                Its keys are names of `self.full_scorers` and `self.part_scorers`.
+                Its values are states of the scorers.
+
+        """
+        new_states = dict()
+        for k, v in states.items():
+            new_states[k] = v
+        for k, d in self.part_scorers.items():
+            new_states[k] = d.select_state(part_states[k], part_idx)
+        return new_states
+
+    def search(
+        self, running_hyps: List[Hypothesis], x: torch.Tensor, am_score: torch.Tensor
+    ) -> List[Hypothesis]:
+        """Search new tokens for running hypotheses and encoded speech x.
+
+        Args:
+            running_hyps (List[Hypothesis]): Running hypotheses on beam
+            x (torch.Tensor): Encoded speech feature (T, D)
+
+        Returns:
+            List[Hypotheses]: Best sorted hypotheses
+
+        """
+        best_hyps = []
+        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
+        for hyp in running_hyps:
+            # scoring
+            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
+            weighted_scores += am_score
+            scores, states = self.score_full(hyp, x)
+            for k in self.full_scorers:
+                weighted_scores += self.weights[k] * scores[k]
+            # partial scoring
+            if self.do_pre_beam:
+                pre_beam_scores = (
+                    weighted_scores
+                    if self.pre_beam_score_key == "full"
+                    else scores[self.pre_beam_score_key]
+                )
+                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
+            part_scores, part_states = self.score_partial(hyp, part_ids, x)
+            for k in self.part_scorers:
+                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
+            # add previous hyp score
+            weighted_scores += hyp.score
+
+            # update hyps
+            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
+                # will be (2 x beam at most)
+                best_hyps.append(
+                    Hypothesis(
+                        score=weighted_scores[j],
+                        yseq=self.append_token(hyp.yseq, j),
+                        scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
+                        states=self.merge_states(states, part_states, part_j),
+                    )
+                )
+
+            # sort and prune 2 x beam -> beam
+            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
+                : min(len(best_hyps), self.beam_size)
+            ]
+        return best_hyps
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        am_scores: torch.Tensor,
+        maxlenratio: float = 0.0,
+        minlenratio: float = 0.0,
+    ) -> List[Hypothesis]:
+        """Perform beam search.
+
+        Args:
+            x (torch.Tensor): Encoded speech feature (T, D)
+            maxlenratio (float): Input length ratio to obtain max output length.
+                If maxlenratio=0.0 (default), it uses a end-detect function
+                to automatically find maximum hypothesis lengths
+                If maxlenratio<0.0, its absolute value is interpreted
+                as a constant max output length.
+            minlenratio (float): Input length ratio to obtain min output length.
+
+        Returns:
+            list[Hypothesis]: N-best decoding results
+
+        """
+        # set length bounds
+        maxlen = am_scores.shape[0]
+        logging.info("decoder input length: " + str(x.shape[0]))
+        logging.info("max output length: " + str(maxlen))
+
+        # main loop of prefix search
+        running_hyps = self.init_hyp(x)
+        ended_hyps = []
+        for i in range(maxlen):
+            logging.debug("position " + str(i))
+            best = self.search(running_hyps, x, am_scores[i])
+            # post process of one iteration
+            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
+            # end detection
+            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
+                logging.info(f"end detected at {i}")
+                break
+            if len(running_hyps) == 0:
+                logging.info("no hypothesis. Finish decoding.")
+                break
+            else:
+                logging.debug(f"remained hypotheses: {len(running_hyps)}")
+
+        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
+        # check the number of hypotheses reaching to eos
+        if len(nbest_hyps) == 0:
+            logging.warning(
+                "there is no N-best results, perform recognition " "again with smaller minlenratio."
+            )
+            return (
+                []
+                if minlenratio < 0.1
+                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
+            )
+
+        # report the best result
+        best = nbest_hyps[0]
+        for k, v in best.scores.items():
+            logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
+        logging.info(f"total log probability: {best.score:.2f}")
+        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
+        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
+        if self.token_list is not None:
+            logging.info(
+                "best hypo: " + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]]) + "\n"
+            )
+        return nbest_hyps
+
+    def post_process(
+        self,
+        i: int,
+        maxlen: int,
+        maxlenratio: float,
+        running_hyps: List[Hypothesis],
+        ended_hyps: List[Hypothesis],
+    ) -> List[Hypothesis]:
+        """Perform post-processing of beam search iterations.
+
+        Args:
+            i (int): The length of hypothesis tokens.
+            maxlen (int): The maximum length of tokens in beam search.
+            maxlenratio (int): The maximum length ratio in beam search.
+            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
+            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
+
+        Returns:
+            List[Hypothesis]: The new running hypotheses.
+
+        """
+        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
+        if self.token_list is not None:
+            logging.debug(
+                "best hypo: "
+                + "".join([self.token_list[x.item()] for x in running_hyps[0].yseq[1:]])
+            )
+        # add eos in the final loop to avoid that there are no ended hyps
+        if i == maxlen - 1:
+            logging.info("adding <eos> in the last position in the loop")
+            running_hyps = [
+                h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
+            ]
+
+        # add ended hypotheses to a final list, and removed them from current hypotheses
+        # (this will be a problem, number of hyps < beam)
+        remained_hyps = []
+        for hyp in running_hyps:
+            if hyp.yseq[-1] == self.eos:
+                # e.g., Word LM needs to add final <eos> score
+                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
+                    s = d.final_score(hyp.states[k])
+                    hyp.scores[k] += s
+                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
+                ended_hyps.append(hyp)
+            else:
+                remained_hyps.append(hyp)
+        return remained_hyps

--
Gitblit v1.9.1