Kun Zou
2024-12-21 b5ad7c81be2e24f255cac3d0ef0037bf88228366
Support eparaformer model on aishell1 recipe (#2327)

25个文件已添加
5094 ■■■■■ 已修改文件
examples/aishell/e_paraformer/conf/e_paraformer_conformer_12e_6d_2048_256.yaml 121 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/demo_infer.sh 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/demo_train_or_finetune.sh 51 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/local/aishell_data_prep.sh 66 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/local/download_and_untar.sh 105 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/run.sh 201 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/compute_wer.py 197 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/extract_embeds.py 49 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/filter_scp.pl 87 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/fix_data.sh 35 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/fix_data_feat.sh 52 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/parse_options.sh 97 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/postprocess_text_zh.py 30 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/shuffle_list.pl 44 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/split_scp.pl 246 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/text2token.py 141 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/text_tokenize.py 104 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/text_tokenize.sh 35 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/utils/textnorm_zh.py 911 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e_paraformer/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e_paraformer/decoder.py 1193 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e_paraformer/export_meta.py 86 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e_paraformer/model.py 670 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e_paraformer/pif_predictor.py 107 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e_paraformer/search.py 451 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/aishell/e_paraformer/conf/e_paraformer_conformer_12e_6d_2048_256.yaml
New file
@@ -0,0 +1,121 @@
# E-Paraformer recipe configuration: Conformer encoder with 12 blocks, decoder with
# 6 blocks, 2048 feed-forward units and 256-dimensional attention.
# network architecture
model: EParaformer
model_conf:
    ctc_weight: 0.0
    lsm_weight: 0.1
    length_normalized_loss: false
    predictor_weight: 1.0
    predictor_bias: 2
    sampling_ratio: 0.4
    use_1st_decoder_loss: true
# encoder
encoder: ConformerEncoder
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder architecture type
    normalize_before: true
    pos_enc_layer_type: rel_pos
    selfattention_layer_type: rel_selfattn
    activation_type: swish
    macaron_style: true
    use_cnn_module: true
    cnn_module_kernel: 15
# decoder
decoder: ParaformerSANDecoder
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0
# predictor
predictor: PifPredictor
predictor_conf:
    idim: 256
    threshold: 1.0
    l_order: 1
    r_order: 1
    sigma: 0.5
    bias: 0.0
    sigma_heads: 4
# frontend related
frontend: WavFrontend
frontend_conf:
    fs: 16000
    window: hamming
    n_mels: 80
    frame_length: 25
    frame_shift: 10
    lfr_m: 1
    lfr_n: 1
# data augmentation applied to the input features
specaug: SpecAug
specaug_conf:
    apply_time_warp: true
    time_warp_window: 5
    time_warp_mode: bicubic
    apply_freq_mask: true
    freq_mask_width_range:
    - 0
    - 30
    num_freq_mask: 2
    apply_time_mask: true
    time_mask_width_range:
    - 0
    - 40
    num_time_mask: 2
# training loop
train_conf:
  accum_grad: 4
  grad_clip: 5
  max_epoch: 150
  keep_nbest_models: 20
  avg_nbest_model: 15
  log_interval: 50
# optimizer and learning-rate schedule
optim: adam
optim_conf:
   lr: 0.0005
scheduler: warmuplr
scheduler_conf:
   warmup_steps: 30000
# dataset and batching
dataset: AudioDataset
dataset_conf:
    index_ds: IndexDSJsonl
    batch_sampler: EspnetStyleBatchSampler
    batch_type: length # example or length
    batch_size: 25000 # with batch_type "example", batch_size is the number of samples; with "length", it is source_token_len + target_token_len
    max_token_length: 2048 # samples with source_token_len + target_token_len > max_token_length are filtered out
    buffer_size: 1024
    shuffle: True
    num_workers: 4
    preprocessor_speech: SpeechPreprocessSpeedPerturb
    preprocessor_speech_conf:
      speed_perturb: [0.9, 1.0, 1.1]
# tokenizer
tokenizer: CharTokenizer
tokenizer_conf:
  unk_symbol: <unk>
# CTC settings (model_conf.ctc_weight above is 0.0)
ctc_conf:
    dropout_rate: 0.0
    ctc_type: builtin
    reduce: true
    ignore_nan_grad: true
normalize: null
examples/aishell/e_paraformer/demo_infer.sh
New file
@@ -0,0 +1,15 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

# Demo: decode a single wav file with a trained checkpoint.
# NOTE: every path below is a machine-specific placeholder. Point --config-path,
# init_param, token_list, cmvn_file and input at your own experiment directory,
# token list, CMVN file and audio before running.
python -m funasr.bin.inference \
--config-path="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3" \
--config-name="config.yaml" \
++init_param="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38" \
++tokenizer_conf.token_list="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/zh_token_list/char/tokens.txt" \
++frontend_conf.cmvn_file="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/train/am.mvn" \
++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \
++output_dir="./outputs/debug" \
++device="cuda:0"
examples/aishell/e_paraformer/demo_train_or_finetune.sh
New file
@@ -0,0 +1,51 @@
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

# Demo: train or finetune the E-Paraformer model with torchrun on the GPUs listed in
# CUDA_VISIBLE_DEVICES. All paths below are placeholders; adapt them to your machine.

# which gpu to train or finetune
export CUDA_VISIBLE_DEVICES="0,1"
# one training process per visible GPU
gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')

# data dir, which contains: train.jsonl, val.jsonl, tokens.json, am.mvn
data_dir="/Users/zhifu/funasr1.0/data/list"

## generate jsonl from wav.scp and text.txt
#python -m funasr.datasets.audio_datasets.scp2jsonl \
#++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
#++data_type_list='["source", "target"]' \
#++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl

train_data="${data_dir}/train.jsonl"
val_data="${data_dir}/val.jsonl"
tokens="${data_dir}/tokens.json"
cmvn_file="${data_dir}/am.mvn"

# exp output dir
output_dir="/Users/zhifu/exp"
log_file="${output_dir}/log.txt"

workspace=$(pwd)
# Must name a file inside ${workspace}/conf; this recipe ships only the e_paraformer config.
config="e_paraformer_conformer_12e_6d_2048_256.yaml"

# checkpoint used to initialise the model
init_param="${output_dir}/model.pt"

mkdir -p "${output_dir}"
echo "log_file: ${log_file}"

# The ++key=value overrides below take the place of the matching entries in the YAML config.
torchrun \
--nnodes 1 \
--nproc_per_node ${gpu_num} \
../../../funasr/bin/train.py \
--config-path "${workspace}/conf" \
--config-name "${config}" \
++train_data_set_list="${train_data}" \
++valid_data_set_list="${val_data}" \
++tokenizer_conf.token_list="${tokens}" \
++frontend_conf.cmvn_file="${cmvn_file}" \
++dataset_conf.batch_size=32 \
++dataset_conf.batch_type="example" \
++dataset_conf.num_workers=4 \
++train_conf.max_epoch=150 \
++optim_conf.lr=0.0002 \
++init_param="${init_param}" \
++output_dir="${output_dir}" &> "${log_file}"
examples/aishell/e_paraformer/local/aishell_data_prep.sh
New file
@@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2017 Xingyu Na
# Apache 2.0

# Build wav.scp and text for the AISHELL-1 train/dev/test splits.
#
# Usage: aishell_data_prep.sh <audio-path> <text-path> <output-path>
#   <audio-path>   directory searched recursively for *.wav files
#   <text-path>    directory holding aishell_transcript_v0.8.txt
#   <output-path>  results go to <output-path>/data/{train,dev,test}/{wav.scp,text}
#
# Must be run from the recipe directory, because utils/filter_scp.pl is called by
# relative path.

#. ./path.sh || exit 1;

if [ $# != 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi

aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3

train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp

mkdir -p "$train_dir"
mkdir -p "$dev_dir"
mkdir -p "$test_dir"
mkdir -p "$tmp_dir"

# data directory check
if [ ! -d "$aishell_audio_dir" ] || [ ! -f "$aishell_text" ]; then
  echo "Error: $0 requires an existing audio directory ($aishell_audio_dir) and transcript file ($aishell_text)"
  exit 1;
fi

# find wav audio file for train, dev and test resp.
find "$aishell_audio_dir" -iname "*.wav" > "$tmp_dir/wav.flist"
n=$(wc -l < "$tmp_dir/wav.flist")
# $n is left unquoted on purpose: some wc implementations pad the count with spaces.
[ $n -ne 141925 ] && \
  echo "Warning: expected 141925 data files, found $n"

# The split is decided by the wav/<split> component of each path.
grep -i "wav/train" "$tmp_dir/wav.flist" > "$train_dir/wav.flist" || exit 1;
grep -i "wav/dev" "$tmp_dir/wav.flist" > "$dev_dir/wav.flist" || exit 1;
grep -i "wav/test" "$tmp_dir/wav.flist" > "$test_dir/wav.flist" || exit 1;

rm -r "$tmp_dir"

# Transcriptions preparation
for dir in "$train_dir" "$dev_dir" "$test_dir"; do
  echo Preparing $dir transcriptions
  # utterance id = wav file name without directory and extension
  sed -e 's/\.wav//' "$dir/wav.flist" | awk -F '/' '{print $NF}' > "$dir/utt.list"
  paste -d' ' "$dir/utt.list" "$dir/wav.flist" > "$dir/wav.scp_all"
  # keep only transcripts that have audio, then only audio that has a transcript
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$aishell_text" > "$dir/transcripts.txt"
  awk '{print $1}' "$dir/transcripts.txt" > "$dir/utt.list"
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$dir/wav.scp_all" | sort -u > "$dir/wav.scp"
  sort -u "$dir/transcripts.txt" > "$dir/text"
done

mkdir -p "$output_dir/data/train" "$output_dir/data/dev" "$output_dir/data/test"

for f in wav.scp text; do
  cp "$train_dir/$f" "$output_dir/data/train/$f" || exit 1;
  cp "$dev_dir/$f" "$output_dir/data/dev/$f" || exit 1;
  cp "$test_dir/$f" "$output_dir/data/test/$f" || exit 1;
done

echo "$0: AISHELL data preparation succeeded"
exit 0;
examples/aishell/e_paraformer/local/download_and_untar.sh
New file
@@ -0,0 +1,105 @@
#!/usr/bin/env bash
# Copyright   2014  Johns Hopkins University (author: Daniel Povey)
#             2017  Xingyu Na
# Apache 2.0

# Download <url-base>/<corpus-part>.tgz into <data-base> (unless an archive of the
# expected size is already there), extract it, and mark the part as done with a
# <data-base>/<corpus-part>/.complete file so later runs can return immediately.

remove_archive=false

if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi

if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
  # Without this exit the script would carry on with empty positional arguments.
  exit 1;
fi

data=$1
url=$2
part=$3

if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi
# The script changes directory into $data below and keeps using $data as a path,
# so a relative <data-base> must be made absolute first.
data=$(cd "$data" && pwd) || exit 1

part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
  if [ "$part" == $x ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1;
fi

if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1;
fi

if [ -f "$data/$part/.complete" ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0;
fi

# sizes of the archive files in bytes.
sizes="15582913665 1246920"

if [ -f "$data/$part.tgz" ]; then
  size=$(/bin/ls -l "$data/$part.tgz" | awk '{print $5}')
  size_ok=false
  for s in $sizes; do if [ $s == $size ]; then size_ok=true; fi; done
  if ! $size_ok; then
    # An archive of unexpected size is treated as a broken download and fetched again.
    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm "$data/$part.tgz"
  else
    echo "$data/$part.tgz exists and appears to be complete."
  fi
fi

if [ ! -f "$data/$part.tgz" ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1;
  fi
  full_url=$url/$part.tgz
  echo "$0: downloading data from $full_url.  This may take some time, please be patient."

  cd "$data" || exit 1
  if ! wget --no-check-certificate $full_url; then
    echo "$0: error executing wget $full_url"
    exit 1;
  fi
fi

cd "$data" || exit 1

if ! tar -xvzf "$part.tgz"; then
  echo "$0: error un-tarring archive $data/$part.tgz"
  exit 1;
fi

touch "$data/$part/.complete"

# data_aishell contains further *.tar.gz archives under wav/; unpack and delete each.
if [ $part == "data_aishell" ]; then
  cd "$data/$part/wav" || exit 1
  for wav in ./*.tar.gz; do
    echo "Extracting wav from $wav"
    tar -zxf "$wav" && rm "$wav"
  done
fi

echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"

if $remove_archive; then
  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
  rm "$data/$part.tgz"
fi

exit 0;
examples/aishell/e_paraformer/run.sh
New file
@@ -0,0 +1,201 @@
#!/usr/bin/env bash
# AISHELL-1 E-Paraformer recipe: data download (-1), data preparation (0), CMVN (1),
# dictionary (2), LM placeholder (3), ASR training (4) and inference/scoring (5).
# Stages numbered from ${stage} to ${stop_stage} inclusive are executed.
# The variables below are defined before utils/parse_options.sh is sourced.
CUDA_VISIBLE_DEVICES="0,1"
# general configuration
feats_dir="../DATA" # feature output directory
exp_dir=`pwd`
lang=zh
token_type=char
stage=0
stop_stage=5
# feature configuration
nj=32
inference_device="cuda" #"cpu"
# NOTE(review): the recipe config sets avg_nbest_model: 15 - confirm that training
# really produces a checkpoint named model.pt.avg10.
inference_checkpoint="model.pt.avg10"
inference_scp="wav.scp"
inference_batch_size=32
# data
raw_data=../raw_data
data_url=www.openslr.org/resources/33
# exp tag
tag="exp1"
workspace=`pwd`
master_port=12345
. utils/parse_options.sh || exit 1;
# Strict mode: abort on
# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'.
set -e
set -u
set -o pipefail
train_set=train
valid_set=dev
test_sets="dev test"
config=e_paraformer_conformer_12e_6d_2048_256.yaml
# experiment directory name: baseline_<config name without .yaml>_<lang>_<token_type>_<tag>
model_dir="baseline_$(basename "${config}" .yaml)_${lang}_${token_type}_${tag}"
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
    echo "stage -1: Data Download"
    mkdir -p ${raw_data}
    local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
    local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    # Data preparation
    local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
    for x in train dev test; do
        # Rewrite text as "<utt-id> <transcript with all spaces removed>", then pass it
        # through utils/text2token.py (presumably re-splitting it into single-character
        # tokens - confirm against that script).
        cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
        paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
            > ${feats_dir}/data/${x}/text
        utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
        mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
        # convert wav.scp text to jsonl
        scp_file_list_arg="++scp_file_list='[\"${feats_dir}/data/${x}/wav.scp\",\"${feats_dir}/data/${x}/text\"]'"
        python ../../../funasr/datasets/audio_datasets/scp2jsonl.py \
        ++data_type_list='["source", "target"]' \
        ++jsonl_file_out=${feats_dir}/data/${x}/audio_datasets.jsonl \
        ${scp_file_list_arg}
    done
fi
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature and CMVN Generation"
    # NOTE(review): stages 4 and 5 read ${train_set}/am.mvn, while only cmvn.json is
    # named here - presumably compute_audio_cmvn.py also writes am.mvn; confirm.
    python ../../../funasr/bin/compute_audio_cmvn.py \
    --config-path "${workspace}/conf" \
    --config-name "${config}" \
    ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
    ++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json"
fi
token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Dictionary Preparation"
    mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/
    echo "make a dictionary"
    # Token list layout: <blank>, <s>, </s>, the sorted unique training tokens, <unk>.
    echo "<blank>" > ${token_list}
    echo "<s>" >> ${token_list}
    echo "</s>" >> ${token_list}
    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
        | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
    echo "<unk>" >> ${token_list}
fi
# LM Training Stage
# (placeholder: this stage only prints its banner)
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: LM Training"
fi
# ASR Training Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "stage 4: ASR Training"
  mkdir -p ${exp_dir}/exp/${model_dir}
  # a timestamp in the log name keeps logs of earlier runs
  current_time=$(date "+%Y-%m-%d_%H-%M")
  log_file="${exp_dir}/exp/${model_dir}/train.log.txt.${current_time}"
  echo "log_file: ${log_file}"
  export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
  # one training process per visible GPU
  gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
  torchrun \
  --nnodes 1 \
  --nproc_per_node ${gpu_num} \
  --master_port ${master_port} \
  ../../../funasr/bin/train.py \
  --config-path "${workspace}/conf" \
  --config-name "${config}" \
  ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
  ++valid_data_set_list="${feats_dir}/data/${valid_set}/audio_datasets.jsonl" \
  ++tokenizer_conf.token_list="${token_list}" \
  ++frontend_conf.cmvn_file="${feats_dir}/data/${train_set}/am.mvn" \
  ++output_dir="${exp_dir}/exp/${model_dir}" &> ${log_file}
fi
# Testing Stage
# For every test set: split the key file into ${nj} shards, decode each shard in a
# background job, merge the per-shard outputs and score them with utils/compute_wer.py.
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "stage 5: Inference"

  if [ "${inference_device}" == "cuda" ]; then
      # one decoding job per visible GPU
      nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
  else
      inference_batch_size=1
      # CPU decoding: give each of the ${nj} jobs the device id -1
      CUDA_VISIBLE_DEVICES=""
      for JOB in $(seq ${nj}); do
          CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
      done
  fi

  for dset in ${test_sets}; do

    inference_dir="${exp_dir}/exp/${model_dir}/inference-${inference_checkpoint}/${dset}"
    _logdir="${inference_dir}/logdir"
    echo "inference_dir: ${inference_dir}"

    mkdir -p "${_logdir}"
    data_dir="${feats_dir}/data/${dset}"
    key_file=${data_dir}/${inference_scp}

    split_scps=
    for JOB in $(seq "${nj}"); do
        split_scps+=" ${_logdir}/keys.${JOB}.scp"
    done
    utils/split_scp.pl "${key_file}" ${split_scps}

    gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
    for JOB in $(seq ${nj}); do
        {
          # job JOB uses the (JOB-1)-th entry of the device list
          id=$((JOB-1))
          gpuid=${gpuid_list_array[$id]}

          export CUDA_VISIBLE_DEVICES=${gpuid}
          python ../../../funasr/bin/inference.py \
          --config-path="${exp_dir}/exp/${model_dir}" \
          --config-name="config.yaml" \
          ++init_param="${exp_dir}/exp/${model_dir}/${inference_checkpoint}" \
          ++tokenizer_conf.token_list="${token_list}" \
          ++frontend_conf.cmvn_file="${feats_dir}/data/${train_set}/am.mvn" \
          ++input="${_logdir}/keys.${JOB}.scp" \
          ++output_dir="${inference_dir}/${JOB}" \
          ++device="${inference_device}" \
          ++ncpu=1 \
          ++disable_log=true \
          ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
        }&

    done
    wait

    mkdir -p ${inference_dir}/1best_recog
    for f in token score text; do
        # Probe shard 1 explicitly instead of whatever value JOB kept after the
        # decoding loop above.
        if [ -f "${inference_dir}/1/1best_recog/${f}" ]; then
          for JOB in $(seq "${nj}"); do
              cat "${inference_dir}/${JOB}/1best_recog/${f}"
          done | sort -k1 >"${inference_dir}/1best_recog/${f}"
        fi
    done

    echo "Computing WER ..."
    python utils/postprocess_text_zh.py ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc
    python utils/postprocess_text_zh.py  ${data_dir}/text ${inference_dir}/1best_recog/text.ref
    python utils/compute_wer.py ${inference_dir}/1best_recog/text.ref ${inference_dir}/1best_recog/text.proc ${inference_dir}/1best_recog/text.cer
    tail -n 3 ${inference_dir}/1best_recog/text.cer
  done

fi
examples/aishell/e_paraformer/utils/compute_wer.py
New file
@@ -0,0 +1,197 @@
import os
import numpy as np
import sys
def compute_wer(ref_file, hyp_file, cer_detail_file):
    """Score hypotheses against references and write a detailed report.

    Both input files hold one utterance per line: ``<utt-id> <token> <token> ...``.
    Only utterance ids found in both files are scored. For each of them the report
    gets a summary line plus the lower-cased reference and hypothesis, followed by
    overall ``%WER`` / ``%SER`` lines and a final ``Scored ...`` line.

    Args:
        ref_file: Path to the reference transcripts.
        hyp_file: Path to the hypothesis transcripts.
        cer_detail_file: Path of the report to write (overwritten).
    """

    def _read_transcripts(path):
        """Return {utt-id: token list} for every non-blank line of ``path``."""
        table = {}
        with open(path, "r") as reader:
            for line in reader:
                fields = line.split()
                if not fields:  # a blank line carries no utterance id
                    continue
                table[fields[0]] = fields[1:]
        return table

    rst = {
        "Wrd": 0,
        "Corr": 0,
        "Ins": 0,
        "Del": 0,
        "Sub": 0,
        "Snt": 0,
        "Err": 0.0,
        "S.Err": 0.0,
        "wrong_words": 0,
        "wrong_sentences": 0,
    }

    hyp_dict = _read_transcripts(hyp_file)
    ref_dict = _read_transcripts(ref_file)

    # The context manager closes (and flushes) the report even if scoring raises.
    with open(cer_detail_file, "w") as cer_detail_writer:
        for hyp_key, hyp_tokens in hyp_dict.items():
            if hyp_key not in ref_dict:
                continue
            ref_tokens = ref_dict[hyp_key]
            out_item = compute_wer_by_line(hyp_tokens, ref_tokens)
            # Convert to a Python int so the running total can never become a
            # fixed-width NumPy integer that overflows.
            wrong = int(out_item["wrong"])

            rst["Wrd"] += out_item["nwords"]
            rst["Corr"] += out_item["cor"]
            rst["wrong_words"] += wrong
            rst["Ins"] += out_item["ins"]
            rst["Del"] += out_item["del"]
            rst["Sub"] += out_item["sub"]
            rst["Snt"] += 1
            if wrong > 0:
                rst["wrong_sentences"] += 1

            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + "\n")
            cer_detail_writer.write(
                "ref:" + "\t" + " ".join(tok.lower() for tok in ref_tokens) + "\n"
            )
            cer_detail_writer.write(
                "hyp:" + "\t" + " ".join(tok.lower() for tok in hyp_tokens) + "\n"
            )

        if rst["Wrd"] > 0:
            rst["Err"] = round(rst["wrong_words"] * 100 / rst["Wrd"], 2)
        if rst["Snt"] > 0:
            rst["S.Err"] = round(rst["wrong_sentences"] * 100 / rst["Snt"], 2)

        cer_detail_writer.write("\n")
        cer_detail_writer.write(
            "%WER "
            + str(rst["Err"])
            + " [ "
            + str(rst["wrong_words"])
            + " / "
            + str(rst["Wrd"])
            + ", "
            + str(rst["Ins"])
            + " ins, "
            + str(rst["Del"])
            + " del, "
            + str(rst["Sub"])
            + " sub ]"
            + "\n"
        )
        cer_detail_writer.write(
            "%SER "
            + str(rst["S.Err"])
            + " [ "
            + str(rst["wrong_sentences"])
            + " / "
            + str(rst["Snt"])
            + " ]"
            + "\n"
        )
        # The second number counts hypothesis ids that have no reference entry.
        cer_detail_writer.write(
            "Scored "
            + str(len(hyp_dict))
            + " sentences, "
            + str(len(hyp_dict) - rst["Snt"])
            + " not present in hyp."
            + "\n"
        )
def compute_wer_by_line(hyp, ref):
    """Align one hypothesis with one reference by edit distance (case-insensitive).

    Args:
        hyp: Hypothesis tokens (list of str).
        ref: Reference tokens (list of str).

    Returns:
        dict with plain Python ints: ``nwords`` (reference length), ``cor``
        (matches), ``ins``, ``del``, ``sub`` (edit operations found by the
        backtrace) and ``wrong`` (the edit distance).
    """
    hyp = [tok.lower() for tok in hyp]
    ref = [tok.lower() for tok in ref]

    len_hyp = len(hyp)
    len_ref = len(ref)

    # int32 leaves ample headroom; an int16 matrix would wrap on very long inputs.
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int32)

    # Operation chosen for each cell: 0 = match (also the value of the untouched first
    # row/column), 1 = substitution, 2 = insertion, 3 = deletion.
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)

    for i in range(len_hyp + 1):
        cost_matrix[i][0] = i
    for j in range(len_ref + 1):
        cost_matrix[0][j] = j

    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i][j] = cost_matrix[i - 1][j - 1]
            else:
                substitution = cost_matrix[i - 1][j - 1] + 1
                insertion = cost_matrix[i - 1][j] + 1
                deletion = cost_matrix[i][j - 1] + 1

                # Ties resolve in list order: substitution, insertion, deletion.
                compare_val = [substitution, insertion, deletion]
                min_val = min(compare_val)
                cost_matrix[i][j] = min_val
                ops_matrix[i][j] = compare_val.index(min_val) + 1

    i = len_hyp
    j = len_ref
    rst = {"nwords": len_ref, "cor": 0, "wrong": 0, "ins": 0, "del": 0, "sub": 0}

    # Backtrace from the bottom-right cell. Once one index has run off the matrix edge
    # it is clamped to 0 for the lookup, and each further step along the other axis is
    # counted as a deletion (hypothesis exhausted) or an insertion (reference exhausted).
    while i >= 0 or j >= 0:
        i_idx = max(0, i)
        j_idx = max(0, j)
        op = ops_matrix[i_idx][j_idx]

        if op == 0:  # correct
            if i - 1 >= 0 and j - 1 >= 0:
                rst["cor"] += 1
            i -= 1
            j -= 1
        elif op == 2:  # insert
            i -= 1
            rst["ins"] += 1
        elif op == 3:  # delete
            j -= 1
            rst["del"] += 1
        elif op == 1:  # substitute
            i -= 1
            j -= 1
            rst["sub"] += 1

        if i < 0 and j >= 0:
            rst["del"] += 1
        elif j < 0 and i >= 0:
            rst["ins"] += 1

    # Return a Python int so that callers summing over many utterances cannot overflow
    # a fixed-width NumPy integer.
    rst["wrong"] = int(cost_matrix[len_hyp][len_ref])

    return rst
def print_cer_detail(rst):
    """Format the per-utterance counts as one summary string.

    Args:
        rst: dict with the keys ``nwords``, ``cor``, ``ins``, ``del``, ``sub`` and
            ``wrong``.

    Returns:
        ``"(nwords=N,cor=C,ins=I,del=D,sub=S) corr:P%,cer:Q%"``, where both
        percentages are relative to ``nwords``. For an empty reference
        (``nwords == 0``) a denominator of 1 is used so that formatting never
        divides by zero.
    """
    denominator = max(rst["nwords"], 1)
    return (
        "("
        + "nwords="
        + str(rst["nwords"])
        + ",cor="
        + str(rst["cor"])
        + ",ins="
        + str(rst["ins"])
        + ",del="
        + str(rst["del"])
        + ",sub="
        + str(rst["sub"])
        + ") corr:"
        + "{:.2%}".format(rst["cor"] / denominator)
        + ",cer:"
        + "{:.2%}".format(rst["wrong"] / denominator)
    )
if __name__ == "__main__":
    # Command line: compute_wer.py <ref file> <hyp file> <report file>
    if len(sys.argv) != 4:
        print("usage : python compute_wer.py test.ref test.hyp test.wer", file=sys.stderr)
        # A wrong invocation is an error: exit non-zero so calling scripts notice.
        sys.exit(1)

    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    cer_detail_file = sys.argv[3]
    compute_wer(ref_file, hyp_file, cer_detail_file)
examples/aishell/e_paraformer/utils/extract_embeds.py
New file
@@ -0,0 +1,49 @@
from transformers import AutoTokenizer, AutoModel, pipeline
import numpy as np
import sys
import os
import torch
from kaldiio import WriteHelper
import re
# Command line:
#   extract_embeds.py <text file> <out ark> <out scp> <out shape file> <device> <model path>
# Each line of the text file is "<utt-id> <text>". The text is fed, one character at a
# time, to a feature-extraction pipeline; the resulting matrix is written to the ark/scp
# pair and its "rows,cols" shape to the shape file.
text_file_json = sys.argv[1]
out_ark = sys.argv[2]
out_scp = sys.argv[3]
out_shape = sys.argv[4]
device = int(sys.argv[5])
model_path = sys.argv[6]

model = AutoModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)

with open(text_file_json, "r") as f:
    js = f.readlines()

# Both outputs are context-managed so they are closed even if extraction fails midway.
with open(out_shape, "w") as f_shape, WriteHelper(
    "ark,scp:{},{}".format(out_ark, out_scp)
) as writer:
    with torch.no_grad():
        for line in js:
            utt_id, tokens = line.strip().split(" ", 1)
            # Drop the spaces, then re-insert one between every character.
            tokens = re.sub(" ", "", tokens.strip())
            tokens = " ".join(tokens)
            token_num = len(tokens.split(" "))

            outputs = np.array(extractor(tokens))
            # Drop the first and last positions (presumably the tokenizer's special
            # tokens - confirm for the model in use).
            embeds = outputs[0, 1:-1, :]

            token_num_embeds, dim = embeds.shape
            if token_num == token_num_embeds:
                writer(utt_id, embeds)
                f_shape.write("{} {},{}\n".format(utt_id, token_num_embeds, dim))
            else:
                # Character count and embedding count disagree: report the utterance
                # and skip it.
                print(
                    "{}, size has changed, {}, {}, {}".format(
                        utt_id, token_num, token_num_embeds, tokens
                    )
                )
examples/aishell/e_paraformer/utils/filter_scp.pl
New file
@@ -0,0 +1,87 @@
#!/usr/bin/env perl
# Copyright 2010-2012 Microsoft Corporation
#                     Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script takes a list of utterance-ids or any file whose first field
# of each line is an utterance-id, and filters an scp
# file (or any file whose "n-th" field is an utterance id), printing
# out only those lines whose "n-th" field is in id_list. The index of
# the "n-th" field is 1, by default, but can be changed by using
# the -f <n> switch
$exclude = 0;
$field = 1;
$shifted = 0;
# Consume leading options; the loop repeats for as long as a pass removed at least
# one option, so --exclude and -f may appear in either order.
do {
  $shifted=0;
  if ($ARGV[0] eq "--exclude") {
    $exclude = 1;
    shift @ARGV;
    $shifted=1;
  }
  if ($ARGV[0] eq "-f") {
    $field = $ARGV[1];
    shift @ARGV; shift @ARGV;
    $shifted=1
  }
} while ($shifted);
# After option parsing: one mandatory id_list and at most one input file.
if(@ARGV < 1 || @ARGV > 2) {
  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
      "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
      "Note: only the first field of each line in id_list matters.  With --exclude, prints\n" .
      "only the lines that were *not* in id_list.\n" .
      "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
      "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
      "-f option, add 1 to the argument.\n" .
      "See also: scripts/filter_scp.pl .\n";
}
$idlist = shift @ARGV;
open(F, "<$idlist") || die "Could not open id-list file $idlist";
# Record the first field of every id_list line in %seen.
while(<F>) {
  @A = split;
  @A>=1 || die "Invalid id-list file line $_";
  $seen{$A[0]} = 1;
}
# <> reads the remaining file argument, or standard input when none is left.
if ($field == 1) { # Treat this as special case, since it is common.
  while(<>) {
    $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field.";
    # $1 is what we filter on.
    if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) {
      print $_;
    }
  }
} else {
  # General case: split the line and look up the requested (1-based) field.
  while(<>) {
    @A = split;
    @A > 0 || die "Invalid scp file line $_";
    @A >= $field || die "Invalid scp file line $_";
    if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) {
      print $_;
    }
  }
}
# tests:
# the following should print "foo 1"
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl <(echo foo)
# the following should print "bar 2".
# ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl -f 2 <(echo 2)
examples/aishell/e_paraformer/utils/fix_data.sh
New file
@@ -0,0 +1,35 @@
#!/usr/bin/env bash

# Usage: fix_data.sh <data-dir>
# Restrict <data-dir>/wav.scp and <data-dir>/text to the utterance ids present in BOTH
# files, sorted by id with one line per id. Copies of the original files are kept in
# <data-dir>/.backup. Must be run from the recipe directory, because
# utils/filter_scp.pl is called by relative path.

echo "$0 $@"
data_dir=$1

if [ -z "${data_dir}" ]; then
  echo "Usage: $0 <data-dir>"
  exit 1;
fi

if [ ! -f "${data_dir}/wav.scp" ]; then
  echo "$0: wav.scp is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/text" ]; then
  echo "$0: text is not found"
  exit 1;
fi

mkdir -p "${data_dir}/.backup"

# De-duplicate each id list first: "uniq -d" below must only report ids that occur in
# both files, not ids that are merely repeated inside one of them.
awk '{print $1}' "${data_dir}/wav.scp" | sort -u > "${data_dir}/.backup/wav_id"
awk '{print $1}' "${data_dir}/text" | sort -u > "${data_dir}/.backup/text_id"

sort "${data_dir}/.backup/wav_id" "${data_dir}/.backup/text_id" | uniq -d > "${data_dir}/.backup/id"

cp "${data_dir}/wav.scp" "${data_dir}/.backup/wav.scp"
cp "${data_dir}/text" "${data_dir}/.backup/text"

mv "${data_dir}/wav.scp" "${data_dir}/wav.scp.bak"
mv "${data_dir}/text" "${data_dir}/text.bak"

utils/filter_scp.pl -f 1 "${data_dir}/.backup/id" "${data_dir}/wav.scp.bak" | sort -k1,1 -u > "${data_dir}/wav.scp"
utils/filter_scp.pl -f 1 "${data_dir}/.backup/id" "${data_dir}/text.bak" | sort -k1,1 -u > "${data_dir}/text"

rm "${data_dir}/wav.scp.bak"
rm "${data_dir}/text.bak"
examples/aishell/e_paraformer/utils/fix_data_feat.sh
New file
@@ -0,0 +1,52 @@
#!/usr/bin/env bash

# Usage: fix_data_feat.sh <data-dir>
# Restrict feats.scp, text, speech_shape and text_shape in <data-dir> to the utterance
# ids present in BOTH feats.scp and text, sorted by id with one line per id. Copies of
# the original files are kept in <data-dir>/.backup. Must be run from the recipe
# directory, because utils/filter_scp.pl is called by relative path.

echo "$0 $@"
data_dir=$1

if [ -z "${data_dir}" ]; then
  echo "Usage: $0 <data-dir>"
  exit 1;
fi

if [ ! -f "${data_dir}/feats.scp" ]; then
  echo "$0: feats.scp is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/text" ]; then
  echo "$0: text is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/speech_shape" ]; then
  echo "$0: feature lengths is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/text_shape" ]; then
  echo "$0: text lengths is not found"
  exit 1;
fi

mkdir -p "${data_dir}/.backup"

# De-duplicate each id list first: "uniq -d" below must only report ids that occur in
# both files, not ids that are merely repeated inside one of them.
awk '{print $1}' "${data_dir}/feats.scp" | sort -u > "${data_dir}/.backup/wav_id"
awk '{print $1}' "${data_dir}/text" | sort -u > "${data_dir}/.backup/text_id"

sort "${data_dir}/.backup/wav_id" "${data_dir}/.backup/text_id" | uniq -d > "${data_dir}/.backup/id"

cp "${data_dir}/feats.scp" "${data_dir}/.backup/feats.scp"
cp "${data_dir}/text" "${data_dir}/.backup/text"
cp "${data_dir}/speech_shape" "${data_dir}/.backup/speech_shape"
cp "${data_dir}/text_shape" "${data_dir}/.backup/text_shape"

mv "${data_dir}/feats.scp" "${data_dir}/feats.scp.bak"
mv "${data_dir}/text" "${data_dir}/text.bak"
mv "${data_dir}/speech_shape" "${data_dir}/speech_shape.bak"
mv "${data_dir}/text_shape" "${data_dir}/text_shape.bak"

utils/filter_scp.pl -f 1 "${data_dir}/.backup/id" "${data_dir}/feats.scp.bak" | sort -k1,1 -u > "${data_dir}/feats.scp"
utils/filter_scp.pl -f 1 "${data_dir}/.backup/id" "${data_dir}/text.bak" | sort -k1,1 -u > "${data_dir}/text"
utils/filter_scp.pl -f 1 "${data_dir}/.backup/id" "${data_dir}/speech_shape.bak" | sort -k1,1 -u > "${data_dir}/speech_shape"
utils/filter_scp.pl -f 1 "${data_dir}/.backup/id" "${data_dir}/text_shape.bak" | sort -k1,1 -u > "${data_dir}/text_shape"

rm "${data_dir}/feats.scp.bak"
rm "${data_dir}/text.bak"
rm "${data_dir}/speech_shape.bak"
rm "${data_dir}/text_shape.bak"
examples/aishell/e_paraformer/utils/parse_options.sh
New file
@@ -0,0 +1,97 @@
#!/usr/bin/env bash
# Copyright 2012  Johns Hopkins University (Author: Daniel Povey);
#                 Arnab Ghoshal, Karel Vesely
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Parse command-line options.
# To be sourced by another script (as in ". parse_options.sh").
# Option format is: --option-name arg
# and shell variable "option_name" gets set to value "arg."
# The exception is --help, which takes no arguments, but prints the
# $help_message variable (if defined).
###
### The --config file options have lower priority to command line
### options, so we need to import them first...
###
# Now import all the configs specified by command-line, in left-to-right order.
# "${!argpos}" is bash indirect expansion: it yields the positional parameter
# whose number is held in argpos.  The loop stops before the last argument
# (argpos < $#), so a "--config" found here always has an argument after it.
for ((argpos=1; argpos<$#; argpos++)); do
  if [ "${!argpos}" == "--config" ]; then
    argpos_plus1=$((argpos+1))
    config=${!argpos_plus1}
    [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1
    . $config  # source the config file.
  fi
done
###
### Now we process the command line options
###
# Options are consumed from the front of the argument list two at a time
# ("--name value"); the loop ends at the first argument that is empty or does
# not start with "--", leaving the positional arguments in place for the
# enclosing script.  A "--config <file>" pair is handled by the generic "--*)"
# branch as well: the variable "config" exists because the loop above set it.
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    # If the enclosing script is called with --help option, print the help
    # message and exit.  Scripts should put help messages in $help_message
    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
      else printf "$help_message\n" 1>&2 ; fi;
      exit 0 ;;
    # The "--name=value" spelling is rejected explicitly.
    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
      exit 1 ;;
    # If the first command-line argument begins with "--" (e.g. --foo-bar),
    # then work out the variable name as $name, which will equal "foo_bar".
    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
      # Next we test whether the variable in question is undefined-- if so it's
      # an invalid option and we die.  Note: $0 evaluates to the name of the
      # enclosing script.
      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
      # is undefined.  We then have to wrap this test inside "eval" because
      # foo_bar is itself inside a variable ($name).
      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
      # Remember the current (default) value so its type can be inspected.
      oldval="`eval echo \\$$name`";
      # Work out whether we seem to be expecting a Boolean argument.
      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi
      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval $name=\"$2\";
      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;
  *) break;
  esac
done
# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;
true; # so this script returns exit code 0.
examples/aishell/e_paraformer/utils/postprocess_text_zh.py
New file
@@ -0,0 +1,30 @@
import sys
import re
# Usage: postprocess_text_zh.py <in_file> <out_file>
#
# Every input line is "<id> <text>".  The text has decoder/BPE markers and
# spaces removed, is lower-cased, and is written back with one space between
# consecutive characters.  A line that carries only an id yields the id
# followed by a single-space text field.
in_f, out_f = sys.argv[1], sys.argv[2]

# Substrings deleted from the text, applied strictly in this order
# ("</s>" before "<s>", "@@" before "@").
_MARKERS = ("</s>", "<s>", "@@", "@", "<unk>", " ")

# The whole input is read before the output is opened.
with open(in_f, "r", encoding="utf-8") as src:
    records = src.readlines()

with open(out_f, "w", encoding="utf-8") as dst:
    for record in records:
        fields = record.strip().split(" ", 1)
        utt_id = fields[0]
        if len(fields) == 2:
            hyp = fields[1]
            for marker in _MARKERS:
                hyp = hyp.replace(marker, "")
            hyp = hyp.lower()
        else:
            hyp = " "
        # Joining a string on " " puts a space between its characters.
        dst.write("{} {}\n".format(utt_id, " ".join(hyp)))
examples/aishell/e_paraformer/utils/shuffle_list.pl
New file
@@ -0,0 +1,44 @@
#!/usr/bin/env perl
# Copyright 2013  Johns Hopkins University (author: Daniel Povey)
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# Seed the random number generator.  "--srand N" selects the seed; without it
# a fixed seed of 0 is used so that repeated runs give the same order.
if (@ARGV && $ARGV[0] eq "--srand") {
  my $seed = $ARGV[1];
  $seed =~ m/\d+/ || die "Bad argument to --srand option: \"$seed\"";
  srand($seed);
  splice(@ARGV, 0, 2);  # drop "--srand N"
} else {
  srand(0); # Gives inconsistent behavior if we don't seed.
}

# At most one input file may remain, and it must not look like an option.
if (@ARGV > 1 || (@ARGV && $ARGV[0] =~ m/^-.+/)) {
  print "Usage: shuffle_list.pl [--srand N] [input file]  > output\n";
  print "randomizes the order of lines of input.\n";
  exit(1);
}

# Tag every input line (from the named file, or stdin) with a random key, in
# input order, then emit the lines ordered by the string value of that key.
my @keyed = map { [ rand(), $_ ] } <>;
print map { $_->[1] } sort { $a->[0] cmp $b->[0] } @keyed;
examples/aishell/e_paraformer/utils/split_scp.pl
New file
@@ -0,0 +1,246 @@
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# See ../../COPYING for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABLITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This program splits up any kind of .scp or archive-type file.
# If there is no utt2spk option it will work on any text  file and
# will split it up with an approximately equal number of lines in
# each but.
# With the --utt2spk option it will work on anything that has the
# utterance-id as the first entry on each line; the utt2spk file is
# of the form "utterance speaker" (on each line).
# It splits it into equal size chunks as far as it can.  If you use the utt2spk
# option it will make sure these chunks coincide with speaker boundaries.  In
# this case, if there are more chunks than speakers (and in some other
# circumstances), some of the resulting chunks will be empty and it will print
# an error message and exit with nonzero status.
# You will normally call this like:
# split_scp.pl scp scp.1 scp.2 scp.3 ...
# or
# split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ...
# Note that you can use this script to split the utt2spk file itself,
# e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ...
# You can also call the scripts like:
# split_scp.pl -j 3 0 scp scp.0
# [note: with this option, it assumes zero-based indexing of the split parts,
# i.e. the second number must be 0 <= n < num-jobs.]
use warnings;
# Defaults for the command-line options.
$num_jobs = 0;       # set by "-j num-jobs job-id"; 0 means -j was not given
$job_id = 0;         # set by "-j num-jobs job-id"
$utt2spk_file = "";  # set by "--utt2spk=<file>"; empty means not given
$one_based = 0;      # set to 1 by "--one-based"
# Consume the leading options.  Each of the (at most three) passes tests for
# all three options, so they may appear in any order.
for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
    if ($ARGV[0] eq "-j") {
        shift @ARGV;
        $num_jobs = shift @ARGV;
        $job_id = shift @ARGV;
    }
    if ($ARGV[0] =~ /--utt2spk=(.+)/) {
        $utt2spk_file=$1;
        shift;
    }
    if ($ARGV[0] eq '--one-based') {
        $one_based = 1;
        shift @ARGV;
    }
}
# With -j, the job id (after removing the one-based offset) must lie in
# [0, num_jobs).
if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
                       $job_id - $one_based >= $num_jobs)) {
  die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
      ($one_based ? " --one-based" : "") . "'\n"
}
# From here on $job_id is always zero-based.
$one_based
    and $job_id--;
# Without -j: an input plus at least one output is required.
# With -j: an input plus at most one output is required.
if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
    die
"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
   or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
 ... where 0 <= job-id < num-jobs, or 1 <= job-id <- num-jobs if --one-based.\n";
}
$error = 0;  # becomes the exit status; set to 1 on non-fatal problems
$inscp = shift @ARGV;
# Build @OUTPUTS, the list of destinations, one per split part.
if ($num_jobs == 0) { # without -j option
    @OUTPUTS = @ARGV;
} else {
    # With -j only the part belonging to $job_id is kept: it goes to the
    # given output file, or to "-" (stdout) when none was given.  Every other
    # part is sent to /dev/null, so the splitting code below is shared by
    # both calling conventions.
    for ($j = 0; $j < $num_jobs; $j++) {
        if ($j == $job_id) {
            if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
            else { push @OUTPUTS, "-"; }
        } else {
            push @OUTPUTS, "/dev/null";
        }
    }
}
if ($utt2spk_file ne "") {  # We have the --utt2spk option...
    # Read the utterance -> speaker map.
    open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
    while(<$u_fh>) {
        @A = split;
        @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
        ($u,$s) = @A;
        $utt2spk{$u} = $s;
    }
    close $u_fh;
    # Read the input, grouping its lines by speaker.  @spkrs records the
    # speakers in order of first appearance.
    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
    @spkrs = ();
    while(<$i_fh>) {
        @A = split;
        if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
        $u = $A[0];
        $s = $utt2spk{$u};
        defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
        if(!defined $spk_count{$s}) {
            push @spkrs, $s;
            $spk_count{$s} = 0;
            $spk_data{$s} = [];  # ref to new empty array.
        }
        $spk_count{$s}++;
        push @{$spk_data{$s}}, $_;
    }
    # Now split as equally as possible ..
    # First allocate spks to files by allocating an approximately
    # equal number of speakers.
    $numspks = @spkrs;  # number of speakers.
    $numscps = @OUTPUTS; # number of output files.
    if ($numspks < $numscps) {
      die "$0: Refusing to split data because number of speakers $numspks " .
          "is less than the number of output .scp files $numscps\n";
    }
    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
        $scparray[$scpidx] = []; # [] is array reference.
    }
    # $scparray[i] holds the speakers assigned to output i (in order);
    # $scpcount[i] holds the number of utterances those speakers contribute.
    for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
        $scpidx = int(($spkidx*$numscps) / $numspks);
        $spk = $spkrs[$spkidx];
        push @{$scparray[$scpidx]}, $spk;
        $scpcount[$scpidx] += $spk_count{$spk};
    }
    # Now will try to reassign beginning + ending speakers
    # to different scp's and see if it gets more balanced.
    # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
    # We can show that if considering changing just 2 scp's, we minimize
    # this by minimizing the squared difference in sizes.  This is
    # equivalent to minimizing the absolute difference in sizes.  This
    # shows this method is bound to converge.
    $changed = 1;
    while($changed) {
        $changed = 0;
        for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
            # First try to reassign ending spk of this scp.
            if($scpidx < $numscps-1) {
                $sz = @{$scparray[$scpidx]};
                if($sz > 0) {
                    $spk = $scparray[$scpidx]->[$sz-1];
                    $count = $spk_count{$spk};
                    $nutt1 = $scpcount[$scpidx];
                    $nutt2 = $scpcount[$scpidx+1];
                    if( abs( ($nutt2+$count) - ($nutt1-$count))
                        < abs($nutt2 - $nutt1))  { # Would decrease
                        # size-diff by reassigning spk...
                        $scpcount[$scpidx+1] += $count;
                        $scpcount[$scpidx] -= $count;
                        pop @{$scparray[$scpidx]};
                        unshift @{$scparray[$scpidx+1]}, $spk;
                        $changed = 1;
                    }
                }
            }
            # Then try to move the first speaker of this scp to the end of
            # the previous one.
            if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
                $spk = $scparray[$scpidx]->[0];
                $count = $spk_count{$spk};
                $nutt1 = $scpcount[$scpidx-1];
                $nutt2 = $scpcount[$scpidx];
                if( abs( ($nutt2-$count) - ($nutt1+$count))
                    < abs($nutt2 - $nutt1))  { # Would decrease
                    # size-diff by reassigning spk...
                    $scpcount[$scpidx-1] += $count;
                    $scpcount[$scpidx] -= $count;
                    shift @{$scparray[$scpidx]};
                    push @{$scparray[$scpidx-1]}, $spk;
                    $changed = 1;
                }
            }
        }
    }
    # Now print out the files...
    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
        $scpfile = $OUTPUTS[$scpidx];
        # "-" means standard output.
        ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
                         : open($f_fh, '>&', \*STDOUT)) ||
            die "$0: Could not open scp file $scpfile for writing: $!\n";
        $count = 0;
        if(@{$scparray[$scpidx]} == 0) {
            # An empty part is reported and reflected in the exit status, but
            # the remaining parts are still written.
            print STDERR "$0: eError: split_scp.pl producing empty .scp file " .
                         "$scpfile (too many splits and too few speakers?)\n";
            $error = 1;
        } else {
            foreach $spk ( @{$scparray[$scpidx]} ) {
                print $f_fh @{$spk_data{$spk}};
                $count += $spk_count{$spk};
            }
            $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
        }
        close($f_fh);
    }
} else {
   # This block is the "normal" case where there is no --utt2spk
   # option and we just break into equal size chunks.
    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
    $numscps = @OUTPUTS;  # size of array.
    @F = ();
    while(<$i_fh>) {
        push @F, $_;
    }
    $numlines = @F;
    if($numlines == 0) {
        print STDERR "$0: error: empty input scp file $inscp\n";
        $error = 1;
    }
    $linesperscp = int( $numlines / $numscps); # the "whole part"..
    $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
    $remainder = $numlines - ($linesperscp * $numscps);
    ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
    # [just doing int() rounds down].
    $n = 0;
    # The first $remainder outputs receive one extra line each.
    for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
        $scpfile = $OUTPUTS[$scpidx];
        ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
                         : open($o_fh, '>&', \*STDOUT)) ||
            die "$0: Could not open scp file $scpfile for writing: $!\n";
        for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
            print $o_fh $F[$n++];
        }
        close($o_fh) || die "$0: Eror closing scp file $scpfile: $!\n";
    }
    $n == $numlines || die "$n != $numlines [code error]";
}
exit ($error);
examples/aishell/e_paraformer/utils/text2token.py
New file
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import re
import sys
import json
is_python2 = sys.version_info[0] == 2
def exist_or_not(i, match_pos):
    """Return the first span in *match_pos* that covers index *i*.

    Args:
        i: character index to look up.
        match_pos: iterable of ``[start, end]`` pairs; a pair covers *i* when
            ``start <= i < end``.

    Returns:
        ``(start, end)`` of the first covering pair, or ``(None, None)`` when
        no pair covers *i*.
    """
    covering = next((span for span in match_pos if span[0] <= i < span[1]), None)
    if covering is None:
        return None, None
    return covering[0], covering[1]
def get_parser():
    """Build the command-line parser for the text-to-token converter.

    Returns:
        argparse.ArgumentParser: parser exposing ``nchar``, ``skip_ncols``,
        ``space``, ``non_lang_syms``, the optional positional ``text``,
        ``trans_type`` and ``text_format``.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="convert raw text to tokenized text",
    )
    cli.add_argument(
        "--nchar",
        "-n",
        type=int,
        default=1,
        help="number of characters to split, i.e., \
                        aabb -> a a b b with -n 1 and aa bb with -n 2",
    )
    cli.add_argument("--skip-ncols", "-s", type=int, default=0, help="skip first n columns")
    cli.add_argument("--space", type=str, default="<space>", help="space symbol")
    cli.add_argument(
        "--non-lang-syms",
        "-l",
        type=str,
        default=None,
        help="list of non-linguistic symobles, e.g., <NOISE> etc.",
    )
    # Optional positional: when omitted the value is False.
    cli.add_argument("text", nargs="?", type=str, default=False, help="input text")
    cli.add_argument(
        "--trans_type",
        "-t",
        choices=["char", "phn"],
        default="char",
        type=str,
        help="""Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 -
                        If trans_type is char,
                        read from SI1279.WRD file -> "bricks are an alternative"
                        Else if trans_type is phn,
                        read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l
                        sil t er n ih sil t ih v sil" """,
    )
    cli.add_argument("--text_format", type=str, default="text", help="text, jsonl")
    return cli
def main():
    """Tokenize every input line and print the result to standard output.

    Input is read from the ``text`` argument when given, otherwise from
    standard input, and is decoded as UTF-8.  For each line the first
    ``--skip-ncols`` whitespace-separated columns are printed unchanged,
    followed by a space and the tokenized remainder:

    * ``--trans_type phn``: the remainder is split on single spaces and every
      occurrence of ``sil`` is replaced by the ``--space`` symbol.
    * ``--trans_type char``: the remainder is split into characters, keeping
      each non-linguistic symbol listed in ``--non-lang-syms`` as one unit,
      and the units are grouped ``--nchar`` at a time.  Spaces become the
      ``--space`` symbol.

    With ``--text_format jsonl`` each input line is a JSON object and its
    ``"target"`` value is tokenized instead of the raw line.

    Side effect: ``sys.stdout`` is rebound to a UTF-8 writer.
    """
    parser = get_parser()
    args = parser.parse_args()

    rs = []
    if args.non_lang_syms is not None:
        with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
            nls = [x.rstrip() for x in f.readlines()]
        # Skip blank entries: an empty symbol compiles to a pattern that
        # matches the empty string at every position, which is never a real
        # symbol and used to keep the match search from advancing.
        rs = [re.compile(re.escape(x)) for x in nls if x]

    if args.text:
        f = codecs.open(args.text, encoding="utf-8")
    else:
        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)

    sys.stdout = codecs.getwriter("utf-8")(sys.stdout if is_python2 else sys.stdout.buffer)
    n = args.nchar
    try:
        line = f.readline()
        while line:
            if args.text_format == "jsonl":
                data = json.loads(line.strip())
                line = data["target"]
            x = line.split()
            # The skipped columns are always followed by one space, so the
            # output line starts with a space when no column is skipped.
            print(" ".join(x[: args.skip_ncols]), end=" ")
            a = " ".join(x[args.skip_ncols :])

            # get all matched positions (non-overlapping, per symbol)
            match_pos = [[m.start(), m.end()] for r in rs for m in r.finditer(a)]

            if args.trans_type == "phn":
                a = a.split(" ")
            else:
                if match_pos:
                    chars = []
                    i = 0
                    while i < len(a):
                        start_pos, end_pos = exist_or_not(i, match_pos)
                        if start_pos is not None:
                            # Keep the whole non-linguistic symbol together.
                            chars.append(a[start_pos:end_pos])
                            i = end_pos
                        else:
                            chars.append(a[i])
                            i += 1
                    a = chars
                a = [a[j : j + n] for j in range(0, len(a), n)]

            a_flat = ["".join(z) for z in a]
            a_chars = [z.replace(" ", args.space) for z in a_flat]
            if args.trans_type == "phn":
                a_chars = [z.replace("sil", args.space) for z in a_chars]
            print(" ".join(a_chars))
            line = f.readline()
    finally:
        # Only close what this function opened; stdin is left alone.
        if args.text:
            f.close()


if __name__ == "__main__":
    main()
examples/aishell/e_paraformer/utils/text_tokenize.py
New file
@@ -0,0 +1,104 @@
import re
import argparse
def load_dict(seg_file):
    """Load a segmentation dictionary from a UTF-8 text file.

    Each non-blank line is ``<key> <token> <token> ...`` separated by
    whitespace.  Blank lines are ignored.

    Args:
        seg_file: path of the dictionary file.

    Returns:
        dict mapping each key to its tokens joined by single spaces (an empty
        string when the line has a key only).  A key that appears more than
        once keeps the value of its last line.
    """
    seg_dict = {}
    # The encoding is fixed so that the result does not depend on the locale.
    with open(seg_file, "r", encoding="utf-8") as infile:
        for line in infile:
            fields = line.split()
            if not fields:
                continue
            seg_dict[fields[0]] = " ".join(fields[1:])
    return seg_dict
def forward_segment(text, dic):
    """Segment *text* by greedy forward longest matching against *dic*.

    Starting at the left, the longest substring found in *dic* is taken as the
    next word; when no substring longer than one character is found, the
    single character at that position is taken.

    Args:
        text: string to segment.
        dic: container of known words (tested with ``in``).

    Returns:
        list of words whose concatenation equals *text*.
    """
    words = []
    start, total = 0, len(text)
    while start < total:
        # Probe candidates from longest to shortest (length >= 2); the first
        # hit is the longest entry.  Fall back to the single character.
        word = next(
            (
                text[start:stop]
                for stop in range(total, start + 1, -1)
                if text[start:stop] in dic
            ),
            text[start],
        )
        words.append(word)
        start += len(word)
    return words
def tokenize(txt, seg_dict):
    """Map a sequence of words to a space-separated token string.

    A word is kept only when its first character is a CJK ideograph in
    U+4E00..U+9FA5, an ASCII letter or an ASCII digit.  A kept word is
    replaced by its entry in *seg_dict*, or by ``<unk>`` when it has none.

    Args:
        txt: iterable of words.
        seg_dict: mapping from word to its token string.

    Returns:
        The replacements joined by single spaces, stripped of surrounding
        whitespace.
    """
    keep = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
    pieces = [
        seg_dict[word] if word in seg_dict else "<unk>"
        for word in txt
        if keep.match(word)
    ]
    return " ".join(pieces).strip()
def get_parser():
    """Build the command-line parser for the text tokenizer.

    All four options are required.

    Returns:
        argparse.ArgumentParser: parser exposing ``text_file``, ``seg_file``,
        ``txt_index`` (int) and ``output_dir``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="text tokenize",
    )
    # (long flag, short flag, value type, default, help text)
    option_table = (
        ("--text-file", "-t", str, False, "input text"),
        ("--seg-file", "-s", str, False, "seg file"),
        ("--txt-index", "-i", int, 1, "txt index"),
        ("--output-dir", "-o", str, False, "output dir"),
    )
    for long_flag, short_flag, value_type, fallback, description in option_table:
        parser.add_argument(
            long_flag,
            short_flag,
            type=value_type,
            default=fallback,
            required=True,
            help=description,
        )
    return parser
def main():
    """Tokenize one text shard and write the tokens and their counts.

    Every non-blank line of ``--text-file`` is ``<id> <text...>``.  The text
    columns are concatenated, lower-cased, segmented with the dictionary from
    ``--seg-file`` and mapped to tokens.  Two files are written in
    ``--output-dir``:

    * ``text.<txt-index>.txt``: ``<id> <tokens>`` per line;
    * ``len.<txt-index>``: ``<id> <number of tokens>`` per line.

    All files are read and written as UTF-8.
    """
    parser = get_parser()
    args = parser.parse_args()

    seg_dict = load_dict(args.seg_file)
    txt_path = "{}/text.{}.txt".format(args.output_dir, args.txt_index)
    shape_path = "{}/len.{}".format(args.output_dir, args.txt_index)

    # Context managers guarantee both outputs are flushed and closed, also
    # when a line fails to process.
    with open(args.text_file, "r", encoding="utf-8") as infile:
        with open(txt_path, "w", encoding="utf-8") as txt_writer:
            with open(shape_path, "w", encoding="utf-8") as shape_writer:
                for line in infile:
                    s = line.split()
                    if not s:
                        # A blank line carries no utterance id; skip it.
                        continue
                    text_id = s[0]
                    text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
                    text = tokenize(text_list, seg_dict)
                    lens = len(text.split())
                    txt_writer.write(text_id + " " + text + "\n")
                    shape_writer.write(text_id + " " + str(lens) + "\n")


if __name__ == "__main__":
    main()
examples/aishell/e_paraformer/utils/text_tokenize.sh
New file
@@ -0,0 +1,35 @@
#!/usr/bin/env bash
# Usage: text_tokenize.sh [--nj <nj>] [--cmd <cmd>] <text_dir> <seg_file> <logdir> <output_dir>
#
# Runs utils/text_tokenize.py on the shards <text_dir>/txt/text.<n>.txt for
# n = 1..nj, then concatenates the per-shard results into <output_dir>/text
# and <output_dir>/text_shape.  Relative paths (utils/...) are resolved from
# the current directory.

# Begin configuration section.
nj=32
cmd=utils/run.pl

echo "$0 $@"

. utils/parse_options.sh || exit 1;

if [ $# -lt 4 ]; then
  echo "Usage: $0 [--nj <nj>] [--cmd <cmd>] <text_dir> <seg_file> <logdir> <output_dir>" 1>&2
  exit 1;
fi

# tokenize configuration
text_dir=$1
seg_file=$2
logdir=$3
output_dir=$4

txt_dir="${output_dir}/txt"
mkdir -p "${txt_dir}"
mkdir -p "${logdir}"

# $cmd is left unquoted on purpose: it may carry arguments of its own
# (e.g. "queue.pl -sync y").  The job runner is expected to substitute the
# job number for JOB -- confirm against the runner in use.
$cmd JOB=1:$nj "${logdir}/text_tokenize.JOB.log" \
  python utils/text_tokenize.py -t "${text_dir}/txt/text.JOB.txt" \
      -s "${seg_file}" -i JOB -o "${txt_dir}" \
      || exit 1;

# concatenate the text files together.
for n in $(seq $nj); do
  cat "${txt_dir}/text.$n.txt" || exit 1
done > "${output_dir}/text" || exit 1

for n in $(seq $nj); do
  cat "${txt_dir}/len.$n" || exit 1
done > "${output_dir}/text_shape" || exit 1

echo "$0: Succeeded text tokenize"
examples/aishell/e_paraformer/utils/textnorm_zh.py
New file
@@ -0,0 +1,911 @@
#!/usr/bin/env python3
# coding=utf-8
# Authors:
#   2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git)
#   2019.9 Jiayu DU
#
# requirements:
#   - python 3.X
# notes: python 2.X WILL fail or produce misleading results
import sys, os, argparse, codecs, string, re
# ================================================================================ #
#                                    basic constant
# ================================================================================ #
# Digit characters for 0-9; the position of a character is its value.
CHINESE_DIGIS = "零一二三四五六七八九"
# "Big" (capital) digit forms, simplified and traditional, in the same order.
BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
# NOTE(review): the two SMALLER_BIG_* constants duplicate the
# SMALLER_CHINESE_NUMERING_UNITS_* values below -- confirm whether both sets
# are still needed.
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
# Unit characters for '亿' and larger; the numbering type decides the power of
# ten each position stands for.
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
# Unit characters '十, 百, 千, 万' (10^1 .. 10^4) and their second written form.
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"
# Alternative characters for the digits 0, 1 and 2; the list for 2 is
# [simplified, traditional].
ZERO_ALT = "〇"
ONE_ALT = "幺"
TWO_ALTS = ["两", "兩"]
# Sign and decimal-point characters, each as [simplified, traditional].
POSITIVE = ["正", "正"]
NEGATIVE = ["负", "負"]
POINT = ["点", "點"]
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']
# NOTE(review): presumably interjections to strip from transcripts -- the
# consumer is outside this section, confirm.
FILLER_CHARS = ["呃", "啊"]
# Regex alternation group of words that contain '儿'.
# NOTE(review): presumably the words in which '儿' must be kept when erhua
# suffixes are removed -- the consumer is outside this section, confirm.
ER_WHITELIST = (
    "(儿女|儿子|儿孙|女儿|儿媳|妻儿|"
    "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|"
    "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|"
    "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)"
)
# Chinese numbering system types
NUMBERING_TYPES = ["low", "mid", "high"]
# Regex alternation group of currency names.
CURRENCY_NAMES = (
    "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
    "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
)
# Regex alternation group of currency units; the '元' and '块' alternatives
# accept an optional magnitude prefix (note the empty last alternative inside
# their groups).
CURRENCY_UNITS = (
    "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
)
# Regex alternation group of common quantifiers (measure words).
COM_QUANTIFIERS = (
    "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
    "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
    "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
    "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
    "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
    "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)"
)
# Punctuation tables based on the Zhon project
# (https://github.com/tsroten/zhon.git).
# The characters are the FULL-WIDTH / CJK forms, not their ASCII look-alikes.
# If they are folded to ASCII (e.g. by Unicode compatibility normalization)
# the non-stop literal starts with an ASCII '"' followed by '#', which closes
# the string at once and turns the rest of the line into a comment, leaving
# the table empty.
# Sentence-ending punctuation: full-width '!' and '?', half-width and
# regular ideographic full stop.
CHINESE_PUNC_STOP = "！？｡。"
# Every other punctuation mark.
CHINESE_PUNC_NON_STOP = "＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏"
CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
# ================================================================================ #
#                                    basic class
# ================================================================================ #
class ChineseChar(object):
    """
    A Chinese character.

    Every character has a simplified and a traditional form,
    e.g. simplified = '负', traditional = '負'.
    Either form can be chosen when converting.
    """

    def __init__(self, simplified, traditional):
        self.simplified = simplified
        self.traditional = traditional

    def __str__(self):
        # Prefer the simplified form, then the traditional one.  __str__ must
        # return a string: fall back to "" (not None) when both forms are
        # missing, otherwise str()/repr()/print() raise TypeError.
        return self.simplified or self.traditional or ""

    def __repr__(self):
        return self.__str__()
class ChineseNumberUnit(ChineseChar):
    """
    A Chinese numeric unit (place-value) character.

    Besides the simplified and traditional forms, every character has an
    additional "big" (capital) form, e.g. '陆' and '陸'.
    """

    def __init__(self, power, simplified, traditional, big_s, big_t):
        super().__init__(simplified, traditional)
        self.power = power
        self.big_s = big_s
        self.big_t = big_t

    def __str__(self):
        return "10^{}".format(self.power)

    @classmethod
    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
        """
        Build the unit at position `index` of a unit table.

        `value` is a (simplified, traditional) pair.  For a small unit the
        power is index + 1 and both big forms are value[1].  Otherwise the
        power depends on `numbering_type` (low: index + 8, mid:
        (index + 2) * 4, high: 2 ** (index + 3)) and the big forms are
        value[0] and value[1].  Raises ValueError for an unknown
        `numbering_type` (only checked when `small_unit` is false).
        """
        if small_unit:
            exponent = index + 1
            big_s_slot = 1
        else:
            if numbering_type == NUMBERING_TYPES[0]:
                exponent = index + 8
            elif numbering_type == NUMBERING_TYPES[1]:
                exponent = (index + 2) * 4
            elif numbering_type == NUMBERING_TYPES[2]:
                exponent = pow(2, index + 3)
            else:
                raise ValueError(
                    "Counting type should be in {0} ({1} provided).".format(
                        NUMBERING_TYPES, numbering_type
                    )
                )
            big_s_slot = 0
        return ChineseNumberUnit(
            power=exponent,
            simplified=value[0],
            traditional=value[1],
            big_s=value[big_s_slot],
            big_t=value[1],
        )
class ChineseNumberDigit(ChineseChar):
    """
    A Chinese digit character.

    Carries its numeric value, the big (capital) forms and optional
    alternative forms, next to the simplified and traditional forms.
    """

    def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
        super().__init__(simplified, traditional)
        self.value = value
        self.big_s, self.big_t = big_s, big_t
        self.alt_s, self.alt_t = alt_s, alt_t

    def __str__(self):
        return str(self.value)

    @classmethod
    def create(cls, i, v):
        """Build the digit with value `i` from v = (simplified, traditional, big_s, big_t)."""
        simplified, traditional, big_s, big_t = v[0], v[1], v[2], v[3]
        return ChineseNumberDigit(i, simplified, traditional, big_s, big_t)
class ChineseMath(ChineseChar):
    """
    A Chinese character paired with a math symbol.

    `symbol` is the symbol string and `expression` an optional callable
    attached to it.  The big forms are the simplified and traditional forms
    themselves.
    """

    def __init__(self, simplified, traditional, symbol, expression=None):
        super().__init__(simplified, traditional)
        self.symbol = symbol
        self.expression = expression
        self.big_s, self.big_t = simplified, traditional


# Short aliases for the character classes.
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
class NumberSystem(object):
    """
    A Chinese number system.

    Plain attribute container; its attributes are assigned by the code that
    builds the system.
    """
class MathSymbol(object):
    """
    Math symbols (traditional/simplified) used by the Chinese number system, e.g.
    positive = ['正', '正']
    negative = ['负', '負']
    point = ['点', '點']
    """

    def __init__(self, positive, negative, point):
        self.positive = positive
        self.negative = negative
        self.point = point

    def __iter__(self):
        # Yields the attribute values in the order they were assigned.
        yield from self.__dict__.values()
# class OtherSymbol(object):
#     """
#     其他符号
#     """
#
#     def __init__(self, sil):
#         self.sil = sil
#
#     def __iter__(self):
#         for v in self.__dict__.values():
#             yield v
# ================================================================================ #
#                                    basic utils
# ================================================================================ #
def create_system(numbering_type=NUMBERING_TYPES[1]):
    """
    Create and return the number system for the given numbering type
    (default: mid).
    NUMBERING_TYPES = ['low', 'mid', 'high']: Chinese numbering system types
        low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
        mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
        high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.
    Returns the corresponding number system.
    """
    # chinese number units of '亿' and larger
    large_pairs = zip(
        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
    )
    larger_units = [
        CNU.create(position, pair, numbering_type, False)
        for position, pair in enumerate(large_pairs)
    ]

    # chinese number units of '十, 百, 千, 万'
    small_pairs = zip(
        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
    )
    smaller_units = [
        CNU.create(position, pair, small_unit=True)
        for position, pair in enumerate(small_pairs)
    ]

    # digits: (simplified, traditional, big simplified, big traditional)
    digit_forms = zip(
        CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL
    )
    digits = [CND.create(number, forms) for number, forms in enumerate(digit_forms)]
    # 0, 1 and 2 have alternative written forms.
    zero, one, two = digits[0], digits[1], digits[2]
    zero.alt_s, zero.alt_t = ZERO_ALT, ZERO_ALT
    one.alt_s, one.alt_t = ONE_ALT, ONE_ALT
    two.alt_s, two.alt_t = TWO_ALTS[0], TWO_ALTS[1]

    # symbols
    positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
    point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))
    # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y)))

    system = NumberSystem()
    system.units = smaller_units + larger_units
    system.digits = digits
    system.math = MathSymbol(positive_cn, negative_cn, point_cn)
    # system.symbols = OtherSymbol(sil_cn)
    return system
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
    """Convert a Chinese numeral string into a string of Arabic digits.

    Args:
        chinese_string: Chinese numeral text, optionally containing one
            decimal-point character of the number system.
        numbering_type: Value forwarded to ``create_system`` to build the
            number system (defaults to ``NUMBERING_TYPES[1]``).

    Returns:
        str: The integer value as text, or ``"<integer>.<decimals>"`` when a
        decimal part is present.
    """
    def get_symbol(char, system):
        """Return the unit, digit or math symbol of `system` written as `char`.

        Falls off the end (implicitly returning None) when nothing matches.
        """
        for u in system.units:
            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
                return u
        for d in system.digits:
            if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
                return d
        for m in system.math:
            if char in [m.traditional, m.simplified]:
                return m
    def string2symbols(chinese_string, system):
        """Split the text at the decimal point and map each character to a symbol.

        Returns a pair of lists: (integer-part symbols, decimal-part symbols).
        """
        int_string, dec_string = chinese_string, ""
        # Try the simplified point character first, then the traditional one.
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                # Two-way unpack: raises ValueError if the point occurs more than once.
                int_string, dec_string = chinese_string.split(p)
                break
        return [get_symbol(c, system) for c in int_string], [
            get_symbol(c, system) for c in dec_string
        ]
    def correct_symbols(integer_symbols, system):
        """Make implied digits and units explicit before the value is computed.

        Examples given by the original author:
            一百八 -> 一百八十
            一亿一千三百万 -> 一亿 一千万 三百万
        """
        # A leading unit of power 1 gets an explicit digit "one" in front of it
        # (presumably 十五 -> 一十五; confirm against the unit table).
        if integer_symbols and isinstance(integer_symbols[0], CNU):
            if integer_symbols[0].power == 1:
                integer_symbols = [system.digits[1]] + integer_symbols
        # A trailing digit that directly follows a unit is given the next lower unit.
        if len(integer_symbols) > 1:
            if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
                integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))
        result = []
        # Number of consecutive units seen since the last digit.
        unit_count = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                result.append(s)
                unit_count = 0
            elif isinstance(s, CNU):
                current_unit = CNU(s.power, None, None, None, None)
                unit_count += 1
            # NOTE(review): this check sits outside the isinstance branches, so it
            # also runs for symbols that are neither CND nor CNU (e.g. None returned
            # by get_symbol for an unknown character) with the previous unit_count.
            if unit_count == 1:
                result.append(current_unit)
            elif unit_count > 1:
                # A further unit in the same run multiplies every smaller unit
                # already emitted: their powers are increased by its power.
                for i in range(len(result)):
                    if (
                        isinstance(result[-i - 1], CNU)
                        and result[-i - 1].power < current_unit.power
                    ):
                        result[-i - 1] = CNU(
                            result[-i - 1].power + current_unit.power, None, None, None, None
                        )
        return result
    def compute_value(integer_symbols):
        """
        Compute the integer value of a corrected symbol sequence.
        When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
        e.g. '两千万' = 2000 * 10000 not 2000 + 10000
        """
        # `value` holds one partial term per digit/unit group; the last slot is open.
        value = [0]
        # Largest unit power seen so far.
        last_power = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                value[-1] = s.value
            elif isinstance(s, CNU):
                value[-1] *= pow(10, s.power)
                if s.power > last_power:
                    # A new largest unit also scales every term collected before it.
                    value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
                    last_power = s.power
                # Close the current term and open a fresh one.
                value.append(0)
        return sum(value)
    system = create_system(numbering_type)
    int_part, dec_part = string2symbols(chinese_string, system)
    int_part = correct_symbols(int_part, system)
    int_str = str(compute_value(int_part))
    # Decimal digits are read out one by one, without any unit handling.
    dec_str = "".join([str(d.value) for d in dec_part])
    if dec_part:
        return "{0}.{1}".format(int_str, dec_str)
    else:
        return int_str
def num2chn(
    number_string,
    numbering_type=NUMBERING_TYPES[1],
    big=False,
    traditional=False,
    alt_zero=False,
    alt_one=False,
    alt_two=True,
    use_zeros=True,
    use_units=True,
):
    """Convert a string of Arabic digits into its Chinese reading.

    Args:
        number_string: Digits with at most one ".".
        numbering_type: Value forwarded to ``create_system``.
        big: Use the ``big_s`` / ``big_t`` symbol forms.
        traditional: Use traditional instead of simplified forms.
        alt_zero: Replace the zero character with ``system.digits[0].alt_s``.
        alt_one: Replace the one character with ``system.digits[1].alt_s``.
        alt_two: Use the alternative form of 2 where the neighbouring
            symbols allow it (see the loop below).
        use_zeros: NOTE(review): never read in this function body; the nested
            ``get_value`` has its own ``use_zeros`` parameter that is always
            left at its default of True.
        use_units: When False, or when the integer part is a single
            character, digits are read out one by one without units.

    Returns:
        str: The Chinese text for the number.

    Raises:
        ValueError: If ``number_string`` contains more than one ".".
    """
    def get_value(value_string, use_zeros=True):
        """Recursively turn a digit string into a list of digit/unit symbols."""
        striped_string = value_string.lstrip("0")
        # record nothing if all zeros
        if not striped_string:
            return []
        # record one digits
        elif len(striped_string) == 1:
            # Leading zeros were stripped: emit one zero symbol before the digit.
            if use_zeros and len(value_string) != len(striped_string):
                return [system.digits[0], system.digits[int(striped_string)]]
            else:
                return [system.digits[int(striped_string)]]
        # recursively record multiple digits
        else:
            # Scanning units from last to first, take the first whose power is
            # smaller than the number of significant digits.
            result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
            result_string = value_string[: -result_unit.power]
            return (
                get_value(result_string)
                + [result_unit]
                + get_value(striped_string[-result_unit.power :])
            )
    system = create_system(numbering_type)
    int_dec = number_string.split(".")
    if len(int_dec) == 1:
        int_string = int_dec[0]
        dec_string = ""
    elif len(int_dec) == 2:
        int_string = int_dec[0]
        dec_string = int_dec[1]
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string)
        )
    if use_units and len(int_string) > 1:
        result_symbols = get_value(int_string)
    else:
        # Digit-by-digit reading of the integer part.
        result_symbols = [system.digits[int(c)] for c in int_string]
    # The decimal part is always read digit by digit.
    dec_symbols = [system.digits[int(c)] for c in dec_string]
    if dec_string:
        result_symbols += [system.math.point] + dec_symbols
    if alt_two:
        # Stand-in digit 2 whose simplified/traditional forms are the alternative
        # forms; its big forms are unchanged.
        liang = CND(
            2,
            system.digits[2].alt_s,
            system.digits[2].alt_t,
            system.digits[2].big_s,
            system.digits[2].big_t,
        )
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                # Substitute only when the 2 is followed by a unit and is either the
                # first symbol or preceded by a unit, and neither unit has power 1.
                if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
                    if next_symbol.power != 1 and (
                        (previous_symbol is None) or (previous_symbol.power != 1)
                    ):
                        result_symbols[i] = liang
    # if big is True, '两' will not be used and `alt_two` has no impact on output
    if big:
        attr_name = "big_"
        if traditional:
            attr_name += "t"
        else:
            attr_name += "s"
    else:
        if traditional:
            attr_name = "traditional"
        else:
            attr_name = "simplified"
    # Render every symbol using the selected written form.
    result = "".join([getattr(s, attr_name) for s in result_symbols])
    # if not use_zeros:
    #     result = result.strip(getattr(system.digits[0], attr_name))
    if alt_zero:
        result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)
    if alt_one:
        result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)
    # A result that starts with the point character gets a leading zero digit.
    for i, p in enumerate(POINT):
        if result.startswith(p):
            return CHINESE_DIGIS[0] + result
    # ^10, 11, .., 19
    # Drop a leading "one" when the second character is the first smaller unit.
    if (
        len(result) >= 2
        and result[1]
        in [
            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
        ]
        and result[0]
        in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
    ):
        result = result[1:]
    return result
# ================================================================================ #
#                          different types of rewriters
# ================================================================================ #
class Cardinal:
    """Cardinal number held as a digit string and/or as Chinese text."""

    def __init__(self, cardinal=None, chntext=None):
        self.cardinal, self.chntext = cardinal, chntext

    def chntext2cardinal(self):
        """Return the digit-string form of ``self.chntext`` (via ``chn2num``)."""
        chinese = self.chntext
        return chn2num(chinese)

    def cardinal2chntext(self):
        """Return the Chinese reading of ``self.cardinal`` (via ``num2chn``)."""
        digits = self.cardinal
        return num2chn(digits)
class Digit:
    """Digit sequence that is read out one digit at a time."""

    def __init__(self, digit=None, chntext=None):
        self.digit, self.chntext = digit, chntext

    # def chntext2digit(self):
    #     return chn2num(self.chntext)

    def digit2chntext(self):
        """Return ``self.digit`` spoken digit by digit (no units, no alternative 2)."""
        options = {"alt_two": False, "use_units": False}
        return num2chn(self.digit, **options)
class TelePhone:
    """Telephone number and its spoken Chinese form."""

    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
        self.telephone, self.raw_chntext, self.chntext = telephone, raw_chntext, chntext

    # def chntext2telephone(self):
    #     sil_parts = self.raw_chntext.split('<SIL>')
    #     self.telephone = '-'.join([
    #         str(chn2num(p)) for p in sil_parts
    #     ])
    #     return self.telephone

    def telephone2chntext(self, fixed=False):
        """Read the number out digit by digit and return the Chinese text.

        Landline numbers (``fixed=True``) are split on "-" and joined with
        "<SIL>" in ``raw_chntext``; other numbers have "+" stripped from both
        ends, are split on whitespace and joined with "<SP>". ``chntext`` is
        ``raw_chntext`` with that marker removed.
        """
        if fixed:
            marker, pieces = "<SIL>", self.telephone.split("-")
        else:
            marker, pieces = "<SP>", self.telephone.strip("+").split()
        spoken = [num2chn(piece, alt_two=False, use_units=False) for piece in pieces]
        self.raw_chntext = marker.join(spoken)
        self.chntext = self.raw_chntext.replace(marker, "")
        return self.chntext
class Fraction:
    """Fraction held as "numerator/denominator" and/or as Chinese text."""

    def __init__(self, fraction=None, chntext=None):
        self.fraction, self.chntext = fraction, chntext

    def chntext2fraction(self):
        """Return "numerator/denominator" for a Chinese "<denominator>分之<numerator>"."""
        pieces = self.chntext.split("分之")
        denominator, numerator = pieces
        return "/".join((chn2num(numerator), chn2num(denominator)))

    def fraction2chntext(self):
        """Return the Chinese reading of ``self.fraction`` (denominator first)."""
        pieces = self.fraction.split("/")
        numerator, denominator = pieces
        return "分之".join((num2chn(denominator), num2chn(numerator)))
class Date:
    """Date written with digits plus 年/月/日/号 markers, and its Chinese reading."""

    def __init__(self, date=None, chntext=None):
        self.date = date
        self.chntext = chntext

    # def chntext2date(self):
    #     chntext = self.chntext
    #     try:
    #         year, other = chntext.strip().split('年', maxsplit=1)
    #         year = Digit(chntext=year).digit2chntext() + '年'
    #     except ValueError:
    #         other = chntext
    #         year = ''
    #     if other:
    #         try:
    #             month, day = other.strip().split('月', maxsplit=1)
    #             month = Cardinal(chntext=month).chntext2cardinal() + '月'
    #         except ValueError:
    #             day = chntext
    #             month = ''
    #         if day:
    #             day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
    #     else:
    #         month = ''
    #         day = ''
    #     date = year + month + day
    #     self.date = date
    #     return self.date

    def date2chntext(self):
        """Return the Chinese reading of ``self.date`` and store it in ``self.chntext``.

        The year is read digit by digit (``Digit``); month and day are read as
        cardinals (``Cardinal``). The last character of the day part (the
        day marker) is kept unchanged. Any of the three parts may be absent.
        """
        date = self.date
        try:
            year, other = date.strip().split("年", 1)
            year = Digit(digit=year).digit2chntext() + "年"
        except ValueError:
            # No year marker: the whole string is month/day material.
            other = date
            year = ""
        if other:
            try:
                month, day = other.strip().split("月", 1)
                month = Cardinal(cardinal=month).cardinal2chntext() + "月"
            except ValueError:
                # No month marker: the day is what follows the year. Using the
                # full `date` here would push the year text into Cardinal.
                day = other
                month = ""
            if day:
                # day[-1] is the trailing marker; only the digits are converted.
                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
        else:
            month = ""
            day = ""
        chntext = year + month + day
        self.chntext = chntext
        return self.chntext
class Money:
    """Amount of money written with digits, and its Chinese reading."""

    def __init__(self, money=None, chntext=None):
        self.money = money
        self.chntext = chntext

    # def chntext2money(self):
    #     return self.money

    def money2chntext(self):
        """Replace every number in ``self.money`` by its Chinese cardinal reading.

        The result is stored in ``self.chntext`` and returned. Each number is
        substituted at the position where it was matched; a plain
        ``str.replace`` of the matched digits would also rewrite the same
        digits inside a later, longer number (e.g. the "5" of "15" in "5块15").
        """

        def _spoken(match):
            # Group 1 is the whole number, including an optional decimal part.
            return Cardinal(cardinal=match.group(1)).cardinal2chntext()

        self.chntext = re.sub(r"(\d+(\.\d+)?)", _spoken, self.money)
        return self.chntext
class Percentage:
    """Percentage written as "<number>%" and/or as Chinese "百分之<number>"."""

    def __init__(self, percentage=None, chntext=None):
        self.percentage = percentage
        self.chntext = chntext

    def chntext2percentage(self):
        """Return "<number>%" for the Chinese text in ``self.chntext``.

        Only the leading "百分之" is removed. ``str.strip("百分之")`` would treat
        the argument as a set of characters and also eat a trailing 百/分/之
        that belongs to the number (e.g. "百分之一百" would become "一").
        """
        prefix = "百分之"
        text = self.chntext.strip()
        if text.startswith(prefix):
            text = text[len(prefix) :]
        return chn2num(text) + "%"

    def percentage2chntext(self):
        """Return "百分之" followed by the Chinese reading of the number."""
        return "百分之" + num2chn(self.percentage.strip().strip("%"))
def remove_erhua(text, er_whitelist):
    """Drop the erhua suffix "儿" except inside whitelisted words.

    Example: 他女儿在那边儿 -> 他女儿在那边 (with 女儿 whitelisted).

    Args:
        text: Input string.
        er_whitelist: Regular expression matching words whose "儿" must stay.

    Returns:
        str: ``text`` with every non-whitelisted "儿" removed.
    """
    er_pattern = re.compile(er_whitelist)
    kept = []
    while True:
        er_at = text.find("儿")
        if er_at < 0:
            break
        protected = er_pattern.search(text)
        # A whitelist hit that starts at or before this "儿" is copied through
        # untouched. A zero-width hit at position 0 (e.g. an empty whitelist)
        # would consume nothing and loop forever, so it counts as "no hit".
        if protected is not None and protected.start() <= er_at and protected.end() > 0:
            cut = protected.end()
            kept.append(text[:cut])
        else:
            kept.append(text[:er_at])
            cut = er_at + 1
        text = text[cut:]
    return "".join(kept) + text
# ================================================================================ #
#                            NSW Normalizer
# ================================================================================ #
class NSWNormalizer:
    """Rewrite non-standard words (dates, money, phone numbers, ...) as Chinese text."""

    def __init__(self, raw_text):
        # "^" and "$" act as sentinels so that patterns requiring a non-digit
        # neighbour can also match at the very start or end of the text.
        self.raw_text = "^" + raw_text + "$"
        self.norm_text = ""

    def _particular(self):
        """Restore "2" in letter-二-letter tokens (e.g. O二O -> O2O) in ``norm_text``."""
        text = self.norm_text
        pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
        for matcher in pattern.findall(text):
            text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
        self.norm_text = text
        return self.norm_text

    def normalize(self):
        """Run all rewriting passes over ``raw_text`` and return the result.

        The passes run in a fixed order (dates, money, mobile numbers,
        landline numbers, fractions, percentages, number + quantifier, digit
        strings, plain numbers); each pass replaces the first occurrence of
        every match it found. The result is also stored in ``norm_text``.
        """
        text = self.raw_text
        # Dates
        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
        for matcher in pattern.findall(text):
            text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)
        # Money
        pattern = re.compile(
            r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)"
        )
        for matcher in pattern.findall(text):
            text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)
        # Landline / mobile phone numbers
        # Mobile
        # http://www.jihaoba.com/news/show/13680
        # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
        # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
        # China Telecom: 133, 153, 189, 180, 181, 177
        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
        for matcher in pattern.findall(text):
            text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1)
        # Landline
        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
        for matcher in pattern.findall(text):
            text = text.replace(
                matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1
            )
        # Fractions
        pattern = re.compile(r"(\d+/\d+)")
        for matcher in pattern.findall(text):
            text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)
        # Percentages: fold the full-width percent sign (U+FF05) into "%" first.
        # Written as an escape so the two characters cannot be confused; the
        # previous literal replaced "%" with itself and had no effect.
        text = text.replace("\uff05", "%")
        pattern = re.compile(r"(\d+(\.\d+)?%)")
        for matcher in pattern.findall(text):
            text = text.replace(
                matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1
            )
        # Number followed by a quantifier
        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
        for matcher in pattern.findall(text):
            text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
        # Digit strings (serial numbers), read digit by digit
        pattern = re.compile(r"(\d{4,32})")
        for matcher in pattern.findall(text):
            text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)
        # Remaining plain numbers
        pattern = re.compile(r"(\d+(\.\d+)?)")
        for matcher in pattern.findall(text):
            text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)
        self.norm_text = text
        self._particular()
        # Remove the sentinels added in __init__.
        return self.norm_text.lstrip("^").rstrip("$")
def nsw_test_case(raw_text):
    """Print `raw_text` and its normalised form, followed by an empty line."""
    print("I:" + raw_text)
    normalized = NSWNormalizer(raw_text).normalize()
    print("O:" + normalized)
    print()
def nsw_test():
    """Run `nsw_test_case` over a fixed list of sample sentences."""
    samples = (
        "固话:0595-23865596或23880880。",
        "固话:0595-23865596或23880880。",
        "手机:+86 19859213959或15659451527。",
        "分数:32477/76391。",
        "百分数:80.03%。",
        "编号:31520181154418。",
        "纯数:2983.07克或12345.60米。",
        "日期:1999年2月20日或09年3月15号。",
        "金钱:12块5,34.5元,20.1万",
        "特殊:O2O或B2C。",
        "3456万吨",
        "2938个",
        "938",
        "今天吃了115个小笼包231个馒头",
        "有62%的概率",
    )
    for sample in samples:
        nsw_test_case(sample)
if __name__ == "__main__":
    # nsw_test()

    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` calls ``bool()`` on the raw string, so any
        non-empty value (including "False" and "0") becomes True.
        """
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    p = argparse.ArgumentParser()
    p.add_argument("ifile", help="input filename, assume utf-8 encoding")
    p.add_argument("ofile", help="output filename")
    p.add_argument("--to_upper", action="store_true", help="convert to upper case")
    p.add_argument("--to_lower", action="store_true", help="convert to lower case")
    p.add_argument(
        "--has_key", action="store_true", help="input text has Kaldi's key as first field."
    )
    p.add_argument(
        "--remove_fillers",
        type=_str2bool,
        default=True,
        help='remove filler chars such as "呃, 啊"',
    )
    p.add_argument(
        "--remove_erhua", type=_str2bool, default=True, help='remove erhua chars such as "这儿"'
    )
    p.add_argument(
        "--log_interval", type=int, default=10000, help="log interval in number of processed lines"
    )
    args = p.parse_args()

    # Reject contradictory options before the output file is created.
    if args.to_upper and args.to_lower:
        sys.stderr.write("text norm: to_upper OR to_lower?")
        sys.exit(1)

    # Every CN and EN punctuation character is mapped to a space; the table
    # does not depend on the input line, so it is built once.
    old_chars = CHINESE_PUNC_LIST + string.punctuation
    punc_table = str.maketrans(old_chars, " " * len(old_chars))

    n = 0
    # `with` closes both files even if a line fails to normalise.
    with codecs.open(args.ifile, "r", "utf8") as ifile, codecs.open(
        args.ofile, "w+", "utf8"
    ) as ofile:
        for line in ifile:
            key = ""
            if args.has_key:
                cols = line.split(maxsplit=1)
                # A blank line yields no columns: keep an empty key and text.
                if cols:
                    key = cols[0]
                text = cols[1].strip() if len(cols) == 2 else ""
            else:
                text = line.strip()
            # cases
            if args.to_upper:
                text = text.upper()
            if args.to_lower:
                text = text.lower()
            # Filler chars removal
            if args.remove_fillers:
                for ch in FILLER_CHARS:
                    text = text.replace(ch, "")
            if args.remove_erhua:
                text = remove_erhua(text, ER_WHITELIST)
            # NSW(Non-Standard-Word) normalization
            text = NSWNormalizer(text).normalize()
            # Punctuations removal
            text = text.translate(punc_table)
            if args.has_key:
                ofile.write(key + "\t" + text + "\n")
            else:
                ofile.write(text + "\n")
            n += 1
            if n % args.log_interval == 0:
                sys.stderr.write("text norm: {} lines done.\n".format(n))
    sys.stderr.write("text norm: {} lines done in total.\n".format(n))
funasr/models/e_paraformer/__init__.py
funasr/models/e_paraformer/decoder.py
New file
@@ -0,0 +1,1193 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
from typing import List, Tuple
from funasr.register import tables
from funasr.models.scama import utils as myutils
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.decoder import DecoderLayer
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.sanm.attention import (
    MultiHeadedAttentionSANMDecoder,
    MultiHeadedAttentionCrossAtt,
)
class DecoderLayerSANM(torch.nn.Module):
    """Single decoder layer: feed-forward, then optional self-attention, then
    optional source attention over the encoder memory.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance, called as
            ``self_attn(tgt, tgt_mask[, cache])`` and expected to return a pair.
            May be None, in which case the self-attention step is skipped.
        src_attn (torch.nn.Module): Source-attention module instance, called as
            ``src_attn(x, memory, memory_mask, ...)``. May be None, in which
            case ``forward``/``forward_one_step``/``forward_chunk`` skip it.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before each sub-block.
        concat_after (bool): When True, two linear layers of shape
            (2 * size -> size) are created; NOTE(review): none of the methods
            of this class uses them.
    """
    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a DecoderLayerSANM object."""
        super(DecoderLayerSANM, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        # norm2 / norm3 exist only when the matching attention module is given.
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = torch.nn.Linear(size + size, size)
            self.concat_linear2 = torch.nn.Linear(size + size, size)
        # When reserve_attn is True, forward() appends each source-attention
        # matrix to attn_mat; nothing in this class clears the list.
        self.reserve_attn = False
        self.attn_mat = []
    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features.
        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache: Not read by this method; returned unchanged.
        Returns:
            tuple: ``(x, tgt_mask, memory, memory_mask, cache)`` where ``x`` is
            the layer output and the other four items are the inputs passed
            through unchanged.
        """
        # tgt = self.dropout(tgt)
        # The residual is the raw layer input, so the skip connection added
        # after self-attention bypasses both the feed-forward and the attention.
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        # Without a self-attention module the feed-forward output is used as is.
        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            x, _ = self.self_attn(tgt, tgt_mask)
            x = residual + self.dropout(x)
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            if self.reserve_attn:
                # Keep the attention matrix for later inspection.
                x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
                self.attn_mat.append(attn_mat)
            else:
                x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False)
            x = residual + self.dropout(x_src_attn)
            # x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache
    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Return the source-attention matrix for the given inputs.

        Unlike ``forward``, this always applies the layer norms (it does not
        consult ``normalize_before``), applies no dropout, and calls
        ``src_attn`` unconditionally, so ``src_attn`` must not be None.
        """
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x
        residual = x
        x = self.norm3(x)
        x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
        return attn_mat
    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features, threading a cache through self-attention.
        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache: Cache passed to ``self_attn`` (ignored and reset to None
                while the module is in training mode).
        Returns:
            tuple: ``(x, tgt_mask, memory, memory_mask, cache)`` where ``x`` is
            the layer output and ``cache`` is the value returned by
            ``self_attn`` (or the input cache when there is no self-attention).
        """
        # tgt = self.dropout(tgt)
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            # No incremental state is carried over during training.
            if self.training:
                cache = None
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + self.dropout(x)
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache
    def forward_chunk(
        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
    ):
        """Compute decoded features for one chunk (no masks are used).
        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            fsmn_cache: Cache passed to and returned by ``self_attn``.
            opt_cache: Cache passed to and returned by ``src_attn.forward_chunk``.
            chunk_size: Forwarded to ``src_attn.forward_chunk``.
            look_back: Forwarded to ``src_attn.forward_chunk``.
        Returns:
            tuple: ``(x, memory, fsmn_cache, opt_cache)`` with the layer output,
            the unchanged memory and the two updated caches.
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            # The mask argument is None here.
            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
            x = residual + self.dropout(x)
        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
            # Unlike forward(), no dropout is applied to the source-attention output.
            x = residual + x
        return x, memory, fsmn_cache, opt_cache
@tables.register("decoder_classes", "ParaformerSANMDecoder")
class ParaformerSANMDecoder(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713

    The layers are organised in three stacks of ``DecoderLayerSANM``:

    * ``decoders``  - ``att_layer_num`` layers with self- and cross-attention,
    * ``decoders2`` - ``num_blocks - att_layer_num`` layers with self-attention
      only (``None`` when that difference is not positive),
    * ``decoders3`` - a single feed-forward-only layer.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        wo_input_layer: bool = False,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = 0,
        lora_list: List[str] = None,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        chunk_multiply_factor: tuple = (1,),
        tf2torch_tensor_name_prefix_torch: str = "decoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
    ):
        """Build the input layer, the three decoder stacks and the output layer.

        Args:
            vocab_size: Size of the output vocabulary.
            encoder_output_size: Encoder feature size; also used as the
                attention dimension of every decoder layer.
            attention_heads: Number of heads of the cross-attention modules.
            linear_units: Hidden size of the position-wise feed-forward modules.
            num_blocks: Total number of layers in ``decoders`` + ``decoders2``.
            dropout_rate: Dropout rate of the layers and feed-forward modules.
            positional_dropout_rate: Dropout rate handed to ``pos_enc_class``.
            self_attention_dropout_rate: Dropout rate of the self-attention modules.
            src_attention_dropout_rate: Dropout rate of the cross-attention modules.
            input_layer: ``"embed"`` or ``"linear"``; anything else raises.
            use_output_layer: Whether to create the final linear projection.
            wo_input_layer: When True no input layer is built (``self.embed`` is None).
            pos_enc_class: Positional-encoding class (used by the "linear" input layer).
            normalize_before: Whether layers normalise their input and a final
                ``after_norm`` is applied.
            concat_after: Forwarded to every ``DecoderLayerSANM``.
            att_layer_num: Number of layers that carry cross-attention.
            kernel_size: Kernel size of the SANM self-attention modules.
            sanm_shfit: Shift of the SANM modules in ``decoders``; ``None``
                selects ``(kernel_size - 1) // 2``. ``decoders2`` always uses 0.
            lora_list, lora_rank, lora_alpha, lora_dropout: Forwarded to the
                cross-attention modules.
            chunk_multiply_factor: Stored on the instance unchanged.
            tf2torch_tensor_name_prefix_torch: Stored on the instance unchanged.
            tf2torch_tensor_name_prefix_tf: Stored on the instance unchanged.

        Raises:
            ValueError: If ``input_layer`` is neither "embed" nor "linear"
                (only checked when ``wo_input_layer`` is False).
        """
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size

        if wo_input_layer:
            self.embed = None
        else:
            if input_layer == "embed":
                # The "embed" input layer carries no positional encoding here.
                self.embed = torch.nn.Sequential(
                    torch.nn.Embedding(vocab_size, attention_dim),
                )
            elif input_layer == "linear":
                self.embed = torch.nn.Sequential(
                    torch.nn.Linear(vocab_size, attention_dim),
                    torch.nn.LayerNorm(attention_dim),
                    torch.nn.Dropout(dropout_rate),
                    torch.nn.ReLU(),
                    pos_enc_class(attention_dim, positional_dropout_rate),
                )
            else:
                raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2

        # Stack 1: self-attention + cross-attention + feed-forward.
        self.decoders = repeat(
            att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads,
                    attention_dim,
                    src_attention_dropout_rate,
                    lora_list,
                    lora_rank,
                    lora_alpha,
                    lora_dropout,
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        # Stack 2 (optional): self-attention + feed-forward, no cross-attention.
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )

        # Stack 3: one feed-forward-only layer.
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.chunk_multiply_factor = chunk_multiply_factor

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        chunk_mask: torch.Tensor = None,
        return_hidden: bool = False,
        return_both: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad: decoder input, fed to the layer stacks as is (``self.embed``
                is not applied here).
            ys_in_lens: (batch)
            chunk_mask: optional mask multiplied into the memory mask.
            return_hidden: when True (or when there is no output layer) the
                normalised hidden states are returned instead of the scores.
            return_both: only consulted when the scores-only branch is not
                taken; then returns scores and hidden states together.

        Returns:
            ``(x, olens)`` with the output-layer scores when an output layer
            exists and ``return_hidden`` is False; otherwise
            ``(x, hidden, olens)`` if ``return_both`` else ``(hidden, olens)``.
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
            memory_mask = memory_mask * chunk_mask
            if tgt_mask.size(1) != memory_mask.size(1):
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)

        # Fix: `hidden` used to be assigned only when normalize_before was True,
        # which left it unbound (UnboundLocalError) otherwise.
        hidden = self.after_norm(x) if self.normalize_before else x

        olens = tgt_mask.sum(1)
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(hidden)
            return x, olens
        if return_both:
            x = self.output_layer(hidden)
            return x, hidden, olens
        return hidden, olens

    def score(self, ys, state, x):
        """Score the prefix ``ys`` via ``forward_one_step`` for a single utterance.

        Adds a batch dimension to ``ys`` and ``x``, runs one decoding step with
        ``state`` as cache and returns ``(logp, new_state)`` with the batch
        dimension removed from ``logp``.
        """
        ys_mask = myutils.sequence_mask(
            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
        )[:, :, None]
        logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
        return logp.squeeze(0), state

    def forward_asf2(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Return ``get_attn_mat`` of the second layer of ``decoders``.

        The input is first passed through the first layer; the second layer is
        then asked for its attention matrix (presumably the cross-attention
        weights - confirm against ``DecoderLayerSANM.get_attn_mat``).
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
        # Fix: this used to read `self.model.decoders[1]`, but this class has no
        # `model` attribute (that spelling belongs to the export wrappers).
        attn_mat = self.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat

    def forward_asf6(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Return ``get_attn_mat`` of the sixth layer of ``decoders``.

        The input is passed through layers 0-4 and layer 5 is then asked for
        its attention matrix, so ``decoders`` must hold at least six layers.
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        for layer_idx in range(5):
            tgt, tgt_mask, memory, memory_mask, _ = self.decoders[layer_idx](
                tgt, tgt_mask, memory, memory_mask
            )
        attn_mat = self.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat

    def forward_chunk(
        self,
        memory: torch.Tensor,
        tgt: torch.Tensor,
        cache: dict = None,
    ) -> torch.Tensor:
        """Decode one chunk in streaming mode.

        Args:
            memory: encoder output for the current chunk.
            tgt: decoder input for the current chunk, fed to the layers as is.
            cache: streaming state; despite the ``None`` default it is required.
                Keys read: ``"decode_fsmn"`` and ``"opt"`` (each ``None`` or a
                per-layer list), ``"chunk_size"`` and
                ``"decoder_chunk_look_back"``. ``"decode_fsmn"`` is always
                written back; ``"opt"`` only when ``"decoder_chunk_look_back"``
                is positive or ``-1``.

        Returns:
            The decoder output, after ``after_norm`` (if ``normalize_before``)
            and the output layer (if present).
        """
        x = tgt

        fsmn_cache = cache["decode_fsmn"]
        if fsmn_cache is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            fsmn_cache = [None] * cache_layer_num

        opt_cache = cache["opt"]
        if opt_cache is None:
            opt_cache = [None] * len(self.decoders)

        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
                x,
                memory,
                fsmn_cache=fsmn_cache[i],
                opt_cache=opt_cache[i],
                chunk_size=cache["chunk_size"],
                look_back=cache["decoder_chunk_look_back"],
            )

        # Fix: the guard used to be `num_blocks - att_layer_num > 1`, which
        # skipped `decoders2` when it held exactly one layer even though
        # `forward` runs it whenever it exists.
        if self.decoders2 is not None:
            for i, decoder in enumerate(self.decoders2):
                j = i + self.att_layer_num
                x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
                    x, memory, fsmn_cache=fsmn_cache[j]
                )

        for decoder in self.decoders3:
            x, memory, _, _ = decoder.forward_chunk(x, memory)

        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        cache["decode_fsmn"] = fsmn_cache
        if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
            cache["opt"] = opt_cache
        return x

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask,  (batch, maxlen_out)
                      dtype=torch.uint8 in PyTorch 1.2-
                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32  (batch, maxlen_in, feat)
            cache: per-layer caches for ``decoders`` followed by ``decoders2``;
                ``None`` starts from empty caches.

        Returns:
            y, cache: output for the last position and the new per-layer
            caches (``decoders`` followed by ``decoders2``). ``y`` holds
            log-probabilities when an output layer exists.
        """
        x = self.embed(tgt)
        if cache is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            cache = [None] * cache_layer_num

        new_cache = []
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                x, tgt_mask, memory, None, cache=cache[i]
            )
            new_cache.append(c_ret)

        # Fix: same off-by-one guard as in `forward_chunk` (`> 1` skipped a
        # single-layer `decoders2` and returned a cache list that was too short).
        if self.decoders2 is not None:
            for i, decoder in enumerate(self.decoders2):
                j = i + self.att_layer_num
                x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
                    x, tgt_mask, memory, None, cache=cache[j]
                )
                new_cache.append(c_ret)

        for decoder in self.decoders3:
            x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
                x, tgt_mask, memory, None, cache=None
            )

        # Only the last position is scored.
        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache
class DecoderLayerSANMExport(torch.nn.Module):
    """Export-time counterpart of a SANM decoder layer.

    Reuses the sub-modules of an existing layer. Compared with the wrapped
    layer's chunked forward, the norms are applied unconditionally and no
    dropout is used.
    """

    def __init__(self, model):
        """Borrow the sub-modules of ``model``; ``norm2``/``norm3`` default to None."""
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = getattr(model, "norm2", None)
        self.norm3 = getattr(model, "norm3", None)
        self.size = model.size

    def _feed_forward_and_self_attn(self, tgt, tgt_mask, cache):
        """Apply norm1 + feed-forward and, if present, the self-attention residual branch."""
        skip = tgt
        hidden = self.feed_forward(self.norm1(tgt))
        if self.self_attn is None:
            return hidden, cache
        attn_out, cache = self.self_attn(self.norm2(hidden), tgt_mask, cache=cache)
        return skip + attn_out, cache

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Run the layer.

        Returns:
            tuple: ``(x, tgt_mask, memory, memory_mask, cache)`` where ``cache``
            is whatever the self-attention module returned (or the input cache
            when there is no self-attention).
        """
        x, cache = self._feed_forward_and_self_attn(tgt, tgt_mask, cache)
        if self.src_attn is not None:
            x = x + self.src_attn(self.norm3(x), memory, memory_mask)
        return x, tgt_mask, memory, memory_mask, cache

    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Return the attention matrix produced by ``src_attn`` (called with ``ret_attn=True``)."""
        x, _ = self._feed_forward_and_self_attn(tgt, tgt_mask, cache)
        _, attn_mat = self.src_attn(self.norm3(x), memory, memory_mask, ret_attn=True)
        return attn_mat
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
class ParaformerSANMDecoderExport(torch.nn.Module):
    """Export wrapper around a SANM Paraformer decoder.

    Construction rewrites the wrapped model in place: attention modules are
    swapped for their ``*Export`` variants and every layer of the three stacks
    is wrapped in ``DecoderLayerSANMExport``.
    """

    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
        """Wrap ``model`` and convert its layers to their export variants.

        Args:
            model: Decoder exposing ``decoders``, ``decoders2``, ``decoders3``,
                ``output_layer`` and ``after_norm``.
            max_seq_len: Passed to ``sequence_mask`` when building the mask maker.
            model_name: Stored as ``self.model_name``.
            onnx: Accepted but not used here.
            **kwargs: Accepted but not used here.
        """
        super().__init__()
        from funasr.utils.torch_function import sequence_mask
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport

        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        for idx, layer in enumerate(self.model.decoders):
            if isinstance(layer.self_attn, MultiHeadedAttentionSANMDecoder):
                layer.self_attn = MultiHeadedAttentionSANMDecoderExport(layer.self_attn)
            if isinstance(layer.src_attn, MultiHeadedAttentionCrossAtt):
                layer.src_attn = MultiHeadedAttentionCrossAttExport(layer.src_attn)
            self.model.decoders[idx] = DecoderLayerSANMExport(layer)

        if self.model.decoders2 is not None:
            for idx, layer in enumerate(self.model.decoders2):
                if isinstance(layer.self_attn, MultiHeadedAttentionSANMDecoder):
                    layer.self_attn = MultiHeadedAttentionSANMDecoderExport(layer.self_attn)
                self.model.decoders2[idx] = DecoderLayerSANMExport(layer)

        for idx, layer in enumerate(self.model.decoders3):
            self.model.decoders3[idx] = DecoderLayerSANMExport(layer)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
        """Derive two mask views from a 0/1 mask.

        Returns:
            tuple: ``mask`` with a trailing singleton axis, and an additive
            mask (``(1 - mask) * -10000.0``) with singleton axes inserted
            after the batch axis. Only rank-2 and rank-3 masks are handled.
        """
        rank = len(mask.shape)
        if rank == 2:
            inverted = 1 - mask[:, None, None, :]
        elif rank == 3:
            inverted = 1 - mask[:, None, :]
        return mask[:, :, None], inverted * -10000.0

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        return_hidden: bool = False,
        return_both: bool = False,
    ):
        """Run the three stacks and the final norm.

        Returns:
            ``(scores, ys_in_lens)`` when an output layer exists and
            ``return_hidden`` is False; otherwise
            ``(scores, hidden, ys_in_lens)`` if ``return_both`` else
            ``(hidden, ys_in_lens)``.
        """
        # Multiplicative mask for the target side.
        tgt_mask, _ = self.prepare_mask(self.make_pad_mask(ys_in_lens))
        # Additive mask for the memory side.
        _, memory_mask = self.prepare_mask(self.make_pad_mask(hlens))

        stacks = self.model
        x, memory = ys_in_pad, hs_pad
        x, tgt_mask, memory, memory_mask, _ = stacks.decoders(x, tgt_mask, memory, memory_mask)
        if stacks.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = stacks.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = stacks.decoders3(x, tgt_mask, memory, memory_mask)

        hidden = self.after_norm(x)
        if self.output_layer is not None and return_hidden is False:
            return self.output_layer(hidden), ys_in_lens
        if return_both:
            return self.output_layer(hidden), hidden, ys_in_lens
        return hidden, ys_in_lens

    def forward_asf2(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Return the attention matrix of the second layer of ``decoders``."""
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)

        first, second = self.model.decoders[0], self.model.decoders[1]
        tgt, tgt_mask, memory, memory_mask, _ = first(tgt, tgt_mask, memory, memory_mask)
        return second.get_attn_mat(tgt, tgt_mask, memory, memory_mask)

    def forward_asf6(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Return the attention matrix of the sixth layer of ``decoders``.

        Layers 0-4 are applied in order before layer 5 is queried.
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)

        for layer_idx in range(5):
            tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[layer_idx](
                tgt, tgt_mask, memory, memory_mask
            )
        return self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)

    # The string literal below holds disabled export helpers; it is a no-op
    # expression statement and is kept verbatim.
    """
    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)
    def is_optimizable(self):
        return True
    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(cache_num)]
    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
               + ['out_cache_%d' % i for i in range(cache_num)]
    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret
    """
@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
    """Streaming export wrapper around a SANM Paraformer decoder.

    Like the offline wrapper it converts the wrapped model's layers in place,
    but ``forward`` additionally takes one input cache per layer of
    ``decoders`` and ``decoders2`` and returns the updated caches.
    """

    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
        """Wrap ``model`` and convert its layers to their export variants.

        Args:
            model: Decoder exposing ``decoders``, ``decoders2``, ``decoders3``,
                ``output_layer`` and ``after_norm``.
            max_seq_len: Passed to ``sequence_mask`` when building the mask maker.
            model_name: Stored as ``self.model_name``.
            onnx: Accepted but not used here.
            **kwargs: Accepted but not used here.
        """
        super().__init__()
        from funasr.utils.torch_function import sequence_mask
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport

        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        for idx, layer in enumerate(self.model.decoders):
            if isinstance(layer.self_attn, MultiHeadedAttentionSANMDecoder):
                layer.self_attn = MultiHeadedAttentionSANMDecoderExport(layer.self_attn)
            if isinstance(layer.src_attn, MultiHeadedAttentionCrossAtt):
                layer.src_attn = MultiHeadedAttentionCrossAttExport(layer.src_attn)
            self.model.decoders[idx] = DecoderLayerSANMExport(layer)

        if self.model.decoders2 is not None:
            for idx, layer in enumerate(self.model.decoders2):
                if isinstance(layer.self_attn, MultiHeadedAttentionSANMDecoder):
                    layer.self_attn = MultiHeadedAttentionSANMDecoderExport(layer.self_attn)
                self.model.decoders2[idx] = DecoderLayerSANMExport(layer)

        for idx, layer in enumerate(self.model.decoders3):
            self.model.decoders3[idx] = DecoderLayerSANMExport(layer)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def _num_caches(self):
        """Count the layers that carry a cache: ``decoders`` plus ``decoders2`` if present."""
        total = len(self.model.decoders)
        extra = getattr(self.model, "decoders2", None)
        if extra is not None:
            total += len(extra)
        return total

    def prepare_mask(self, mask):
        """Derive two mask views from a 0/1 mask.

        Returns:
            tuple: ``mask`` with a trailing singleton axis, and an additive
            mask (``(1 - mask) * -10000.0``) with singleton axes inserted
            after the batch axis. Only rank-2 and rank-3 masks are handled.
        """
        rank = len(mask.shape)
        if rank == 2:
            inverted = 1 - mask[:, None, None, :]
        elif rank == 3:
            inverted = 1 - mask[:, None, :]
        return mask[:, :, None], inverted * -10000.0

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        *args,
    ):
        """Run the decoder with explicit per-layer caches.

        Args:
            hs_pad: Encoder output.
            hlens: Encoder output lengths.
            ys_in_pad: Decoder input.
            ys_in_lens: Decoder input lengths.
            *args: One input cache per layer, ``decoders`` first and then
                ``decoders2`` (when present).

        Returns:
            tuple: ``(x, out_caches)`` - output-layer scores and the list of
            updated caches in the same layer order as ``*args``.
        """
        # Multiplicative mask for the target side.
        tgt_mask, _ = self.prepare_mask(self.make_pad_mask(ys_in_lens))
        # Additive mask for the memory side.
        _, memory_mask = self.prepare_mask(self.make_pad_mask(hlens))

        cached_layers = list(self.model.decoders)
        if self.model.decoders2 is not None:
            cached_layers += list(self.model.decoders2)

        x, memory = ys_in_pad, hs_pad
        out_caches = []
        for idx, decoder in enumerate(cached_layers):
            x, tgt_mask, memory, memory_mask, updated = decoder(
                x, tgt_mask, memory, memory_mask, cache=args[idx]
            )
            out_caches.append(updated)

        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
        x = self.output_layer(self.after_norm(x))
        return x, out_caches

    def get_dummy_inputs(self, enc_size):
        """Build example inputs (batch of 2) followed by one zero cache per cached layer."""
        enc = torch.randn(2, 100, enc_size).type(torch.float32)
        enc_len = torch.tensor([30, 100], dtype=torch.int32)
        acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
        acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
        cache = [
            torch.zeros(
                (2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
                dtype=torch.float32,
            )
            for _ in range(self._num_caches())
        ]
        return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)

    def get_input_names(self):
        """Return the input names: four fixed names, then one ``in_cache_<i>`` per cached layer."""
        names = ["enc", "enc_len", "acoustic_embeds", "acoustic_embeds_len"]
        names.extend("in_cache_%d" % i for i in range(self._num_caches()))
        return names

    def get_output_names(self):
        """Return the output names: two fixed names, then one ``out_cache_<i>`` per cached layer."""
        names = ["logits", "sample_ids"]
        names.extend("out_cache_%d" % i for i in range(self._num_caches()))
        return names

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping for the inputs and the caches."""
        axes = {
            "enc": {0: "batch_size", 1: "enc_length"},
            "acoustic_embeds": {0: "batch_size", 1: "token_length"},
            "enc_len": {0: "batch_size"},
            "acoustic_embeds_len": {0: "batch_size"},
        }
        cache_num = self._num_caches()
        # All input caches first, then all output caches; only the batch axis is dynamic.
        for template in ("in_cache_%d", "out_cache_%d"):
            for d in range(cache_num):
                axes[template % d] = {0: "batch_size"}
        return axes
@tables.register("decoder_classes", "ParaformerSANDecoder")
class ParaformerSANDecoder(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713

    Variant built from ``num_blocks`` ``DecoderLayer`` instances, each using
    ``MultiHeadedAttention`` for both self- and cross-attention.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        embeds_id: int = -1,
    ):
        """Build the decoder stack.

        Args:
            vocab_size: Size of the output vocabulary.
            encoder_output_size: Encoder feature size; used as the attention dimension.
            attention_heads: Number of heads of both attention modules.
            linear_units: Hidden size of the position-wise feed-forward modules.
            num_blocks: Number of decoder layers.
            dropout_rate: Dropout rate of the layers and feed-forward modules.
            positional_dropout_rate: Forwarded to the base class.
            self_attention_dropout_rate: Dropout rate of the self-attention modules.
            src_attention_dropout_rate: Dropout rate of the cross-attention modules.
            input_layer: Forwarded to the base class.
            use_output_layer: Forwarded to the base class.
            pos_enc_class: Forwarded to the base class.
            normalize_before: Forwarded to the base class and to every layer.
            concat_after: Forwarded to every layer.
            embeds_id: Index of the layer whose output ``forward`` additionally
                returns; with the default ``-1`` no layer index matches.
        """
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size

        def build_layer(lnum):
            return DecoderLayer(
                attention_dim,
                MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            )

        self.decoders = repeat(num_blocks, build_layer)
        self.embeds_id = embeds_id
        self.attention_dim = attention_dim

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad: decoder input, fed to the layers as is (``self.embed``
                is not applied here).
            ys_in_lens: (batch)

        Returns:
            ``(x, olens)``, or ``(x, olens, embeds_outputs)`` when the output
            of layer ``embeds_id`` was captured. ``x`` has passed through
            ``after_norm`` (if ``normalize_before``) and the output layer (if
            present).
        """
        tgt = ys_in_pad
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)

        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
        # Padding for Longformer
        missing = memory.shape[1] - memory_mask.shape[-1]
        if missing != 0:
            memory_mask = torch.nn.functional.pad(memory_mask, (0, missing), "constant", False)

        x = tgt
        embeds_outputs = None
        for layer_id, decoder in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, memory_mask)
            if layer_id == self.embeds_id:
                embeds_outputs = x

        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        if embeds_outputs is None:
            return x, olens
        return x, olens, embeds_outputs
@tables.register("decoder_classes", "ParaformerDecoderSANExport")
class ParaformerDecoderSANExport(torch.nn.Module):
    """Export-time wrapper around a decoder ``model``.

    On construction every entry of ``model.decoders`` is replaced in place by
    a ``DecoderLayerExport`` wrapper (a ``MultiHeadedAttention`` source
    attention is first converted to ``MultiHeadedAttentionExport``). The
    wrapper reuses ``model.output_layer`` and ``model.after_norm``.

    NOTE(review): ``get_dummy_inputs``, ``get_input_names``,
    ``get_output_names`` and ``get_dynamic_axes`` read ``model.decoders2`` and
    ``self_attn.kernel_size``; confirm the wrapped decoder provides them
    before relying on these helpers.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name="decoder",
        onnx: bool = True,
    ):
        """Wrap ``model`` for export.

        Args:
            model: decoder exposing ``decoders``, ``output_layer`` and
                ``after_norm``. Its ``decoders`` entries are modified in place.
            max_seq_len: forwarded to ``sequence_mask``.
            model_name: stored as ``self.model_name``.
            onnx: accepted for interface parity; not used in this class.
        """
        super().__init__()
        from funasr.utils.torch_function import sequence_mask
        from funasr.models.transformer.decoder import DecoderLayerExport
        from funasr.models.transformer.attention import MultiHeadedAttentionExport

        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.src_attn, MultiHeadedAttention):
                d.src_attn = MultiHeadedAttentionExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerExport(d)
        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
        """Build the two mask layouts consumed by the export layers.

        Args:
            mask: rank-2 or rank-3 mask tensor.
        Returns:
            ``(mask_3d_btd, mask_4d_bhlt)`` where ``mask_3d_btd`` is ``mask``
            with a trailing singleton axis and ``mask_4d_bhlt`` is
            ``(1 - mask)`` with inserted singleton axes, scaled by -10000.0.
        Raises:
            ValueError: if ``mask`` is neither rank 2 nor rank 3.
        """
        rank = len(mask.shape)
        if rank == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif rank == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        else:
            raise ValueError(f"mask must have rank 2 or 3, got rank {rank}")
        mask_3d_btd = mask[:, :, None]
        # Additive form: entries where mask == 0 become a large negative bias.
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Run the wrapped decoder layers, final norm and output layer.

        Args:
            hs_pad: encoder memory.
            hlens: lengths used to build the memory mask.
            ys_in_pad: decoder input, passed to the layers as-is.
            ys_in_lens: lengths used to build the target mask.
        Returns:
            ``(x, ys_in_lens)`` where ``x`` is the output-layer result.
        """
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)

        x = tgt
        x, tgt_mask, memory, memory_mask = self.model.decoders(x, tgt_mask, memory, memory_mask)
        x = self.after_norm(x)
        x = self.output_layer(x)
        return x, ys_in_lens

    def get_dummy_inputs(self, enc_size):
        """Return ``(tgt, memory, pre_acoustic_embeds, cache)`` dummy tensors."""
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros(
                (1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size)
            )
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        """Always report the module as optimizable."""
        return True

    def get_input_names(self):
        """Return the input names, including one ``cache_<i>`` per layer."""
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ["tgt", "memory", "pre_acoustic_embeds"] + ["cache_%d" % i for i in range(cache_num)]

    def get_output_names(self):
        """Return the output names, including one ``out_cache_<i>`` per layer."""
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        """Return the dynamic-axis mapping for the named inputs and caches."""
        ret = {
            "tgt": {0: "tgt_batch", 1: "tgt_length"},
            "memory": {0: "memory_batch", 1: "memory_length"},
            "pre_acoustic_embeds": {
                0: "acoustic_embeds_batch",
                1: "acoustic_embeds_length",
            },
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update(
            {
                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
                for d in range(cache_num)
            }
        )
        return ret
funasr/models/e_paraformer/export_meta.py
New file
@@ -0,0 +1,86 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import types
import torch
from funasr.register import tables
def export_rebuild_model(model, **kwargs):
    """Rewire ``model`` in place so it can be exported, and return it.

    The encoder, predictor and decoder are replaced by the classes registered
    under ``<name>Export`` in ``tables``; ``forward`` and the ``export_*``
    helpers of this module are bound onto the model as methods.

    Args:
        model: the model to rebuild (mutated in place).
        **kwargs: must contain ``encoder``, ``predictor``, ``decoder`` (the
            registered class names) and ``max_seq_len``. ``device`` is stored
            on the model; ``type`` selects ONNX wrappers when it equals
            ``"onnx"`` (the default).
    Returns:
        The same ``model`` object.
    Raises:
        KeyError: if one of the required keys is missing from ``kwargs``.
    """
    model.device = kwargs.get("device")
    is_onnx = kwargs.get("type", "onnx") == "onnx"

    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
    predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
    model.predictor = predictor_class(model.predictor, onnx=is_onnx)
    decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
    model.decoder = decoder_class(model.decoder, onnx=is_onnx)

    from funasr.utils.torch_function import sequence_mask

    model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)

    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
    # ``export_name`` is a plain string attribute; the bound-method assignment
    # that used to precede this line was overwritten immediately and is gone.
    model.export_name = 'model'
    return model
def export_forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
):
    """Forward pass bound onto the model for export.

    Chains encoder, predictor and decoder, then applies a log-softmax.

    Args:
        speech: input features, passed to the encoder as ``speech``.
        speech_lengths: passed to the encoder as ``speech_lengths``.
    Returns:
        ``(decoder_out, pre_token_length)``: log-softmax of the decoder output
        over the last axis, and the predictor's token count floored and cast
        to int32.
    """
    enc, enc_len = self.encoder(speech=speech, speech_lengths=speech_lengths)

    enc_mask = self.make_pad_mask(enc_len)[:, None, :]
    predictor_outs = self.predictor(enc, enc_mask)
    pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs
    pre_token_length = pre_token_length.floor().type(torch.int32)

    decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
    return torch.log_softmax(decoder_out, dim=-1), pre_token_length
def export_dummy_inputs(self):
    """Return a ``(speech, speech_lengths)`` pair of example export inputs.

    ``speech`` is a random (2, 30, 560) tensor; ``speech_lengths`` is the
    int32 tensor ``[6, 30]``.
    """
    batch, frames, feat_dim = 2, 30, 560
    dummy_speech = torch.randn(batch, frames, feat_dim)
    dummy_lengths = torch.tensor([6, 30], dtype=torch.int32)
    return (dummy_speech, dummy_lengths)
def export_input_names(self):
    """Return the names given to the exported model's inputs, in order."""
    names = ["speech", "speech_lengths"]
    return names
def export_output_names(self):
    """Return the names given to the exported model's outputs, in order."""
    names = ["logits", "token_num"]
    return names
def export_dynamic_axes(self):
    """Return the axis-index -> symbolic-name mapping per tensor name."""
    axes = {}
    axes["speech"] = {0: "batch_size", 1: "feats_length"}
    axes["speech_lengths"] = {0: "batch_size"}
    axes["logits"] = {0: "batch_size", 1: "logits_length"}
    return axes
def export_name(self):
    """Return the file name used for the exported model."""
    filename = "model.onnx"
    return filename
funasr/models/e_paraformer/model.py
New file
@@ -0,0 +1,670 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# Copyright 2024 Kun Zou (chinazoukun@gmail.com). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import time
import copy
import torch
import logging
from torch.cuda.amp import autocast
from typing import Union, Dict, List, Tuple, Optional
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import to_device
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos, add_sos_and_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
@tables.register("model_classes", "EParaformer")
class EParaformer(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    Author: Kun Zou, chinazoukun@gmail.com
    E-Paraformer: A Faster and Better Parallel Transformer for Non-autoregressive End-to-End Mandarin Speech Recognition
    https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf
    """
    def __init__(
        self,
        specaug: Optional[str] = None,
        specaug_conf: Optional[Dict] = None,
        normalize: str = None,
        normalize_conf: Optional[Dict] = None,
        encoder: str = None,
        encoder_conf: Optional[Dict] = None,
        decoder: str = None,
        decoder_conf: Optional[Dict] = None,
        ctc: str = None,
        ctc_conf: Optional[Dict] = None,
        predictor: str = None,
        predictor_conf: Optional[Dict] = None,
        ctc_weight: float = 0.5,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        # report_cer: bool = True,
        # report_wer: bool = True,
        # sym_space: str = "<space>",
        # sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        # predictor=None,
        predictor_weight: float = 0.0,
        predictor_bias: int = 2,
        sampling_ratio: float = 0.2,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        use_1st_decoder_loss: bool = True,
        **kwargs,
    ):
        super().__init__()
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        if decoder is not None:
            decoder_class = tables.decoder_classes.get(decoder)
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **decoder_conf,
            )
        if ctc_weight > 0.0:
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
        if predictor is not None:
            predictor_class = tables.predictor_classes.get(predictor)
            predictor = predictor_class(**predictor_conf)
        # note that eos is the same as sos (equivalent ID)
        self.blank_id = blank_id
        self.sos = sos if sos is not None else vocab_size - 1
        self.eos = eos if eos is not None else vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        # self.token_list = token_list.copy()
        #
        # self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        # self.preencoder = preencoder
        # self.postencoder = postencoder
        self.encoder = encoder
        #
        # if not hasattr(self.encoder, "interctc_use_conditioning"):
        #     self.encoder.interctc_use_conditioning = False
        # if self.encoder.interctc_use_conditioning:
        #     self.encoder.conditioning_layer = torch.nn.Linear(
        #         vocab_size, self.encoder.output_size()
        #     )
        #
        # self.error_calculator = None
        #
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if use_1st_decoder_loss:
            self.criterion_att_1st = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        #
        # self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
        self.predictor = predictor
        self.predictor_weight = predictor_weight
        self.predictor_bias = predictor_bias
        self.sampling_ratio = sampling_ratio
        self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)
        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None
        self.use_1st_decoder_loss = use_1st_decoder_loss
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
        self.error_calculator = None
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        loss_ctc, cer_ctc = None, None
        loss_pre = None
        stats = dict()
        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
        else:
            loss = (
                self.ctc_weight * loss_ctc
                + (1 - self.ctc_weight) * loss_att
                + loss_pre * self.predictor_weight
            )
        if pre_loss_att is not None:
            loss += pre_loss_att
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["pre_loss_att"] = pre_loss_att.detach() if pre_loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
        stats["loss"] = torch.clone(loss.detach())
        stats["batch_size"] = batch_size
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = (text_lengths + self.predictor_bias).sum()
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                ind: int
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder
        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        return encoder_out, encoder_out_lens
    def calc_predictor(self, encoder_out, encoder_out_lens):
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
    def cal_decoder_with_predictor(
        self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
    ):
        decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        encoder_out_mask = (
            ~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]
        ).to(encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        if self.predictor_bias == 2:
            _, ys_pad = add_sos_and_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
            encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
        )
        # 0. sampler
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            if self.use_1st_decoder_loss:
                sematic_embeds, decoder_out_1st = self.sampler_with_grad(
                    encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
                )
            else:
                sematic_embeds, decoder_out_1st = self.sampler(
                    encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
                )
        else:
            sematic_embeds = pre_acoustic_embeds
        # 1. Forward decoder
        decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
        if self.use_1st_decoder_loss:
            pre_loss_att = self.criterion_att_1st(decoder_out_1st, ys_pad)
        loss_att = self.criterion_att(decoder_out, ys_pad)
        acc_att = th_accuracy(
            decoder_out_1st.view(-1, self.vocab_size),
            ys_pad,
            ignore_label=self.ignore_id,
        )
        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out_1st.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        """Mix target embeddings into the acoustic embeddings (no-grad first pass).

        A first decoder pass on ``pre_acoustic_embeds`` runs under
        ``torch.no_grad()``. For each sequence, the number of non-pad positions
        it got wrong, scaled by ``self.sampling_ratio``, gives how many
        positions are replaced by the embedding of the reference token.

        Returns:
            ``(sematic_embeds * tgt_mask, decoder_out * tgt_mask)``: the mixed
            decoder input and the first-pass decoder output.
        """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        # Entries where tgt_mask is False become 0 before the embedding lookup
        # (presumably to turn ignore_id padding into a valid index).
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            # Shared embedding: rows of the output layer act as token embeddings.
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        with torch.no_grad():
            decoder_outs = self.decoder(
                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
            )
            decoder_out, _ = decoder_outs[0], decoder_outs[1]
            pred_tokens = decoder_out.argmax(-1)
            nonpad_positions = ys_pad.ne(self.ignore_id)
            seq_lens = (nonpad_positions).sum(1)
            # Per-sequence count of correctly predicted non-pad tokens.
            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
            # input_mask: True keeps the acoustic embedding at that position.
            input_mask = torch.ones_like(nonpad_positions)
            bsz, seq_len = ys_pad.size()
            for li in range(bsz):
                # Number of positions to replace = wrong tokens * sampling_ratio.
                target_num = (
                    ((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio
                ).long()
                if target_num > 0:
                    # Clear target_num random indices drawn from [0, seq_lens[li]);
                    # assumes the non-pad positions form a prefix -- TODO confirm.
                    input_mask[li].scatter_(
                        dim=0,
                        index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
                        value=0,
                    )
            input_mask = input_mask.eq(1)
            input_mask = input_mask.masked_fill(~nonpad_positions, False)
            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
        # Acoustic embedding where the mask is True, target embedding elsewhere.
        sematic_embeds = pre_acoustic_embeds.masked_fill(
            ~input_mask_expand_dim, 0
        ) + ys_pad_embed.masked_fill(input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
    def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
        """Mix target embeddings into the acoustic embeddings (first pass with grad).

        Same procedure as ``sampler``, except that the first decoder pass is
        not wrapped in ``torch.no_grad()``, so the returned first-pass output
        can be used in a loss.

        Returns:
            ``(sematic_embeds * tgt_mask, decoder_out * tgt_mask)``: the mixed
            decoder input and the first-pass decoder output.
        """
        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
            ys_pad.device
        )
        # Entries where tgt_mask is False become 0 before the embedding lookup
        # (presumably to turn ignore_id padding into a valid index).
        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
        if self.share_embedding:
            # Shared embedding: rows of the output layer act as token embeddings.
            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
        else:
            ys_pad_embed = self.decoder.embed(ys_pad_masked)
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        pred_tokens = decoder_out.argmax(-1)
        nonpad_positions = ys_pad.ne(self.ignore_id)
        seq_lens = (nonpad_positions).sum(1)
        # Per-sequence count of correctly predicted non-pad tokens.
        same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
        # input_mask: True keeps the acoustic embedding at that position.
        input_mask = torch.ones_like(nonpad_positions)
        bsz, seq_len = ys_pad.size()
        for li in range(bsz):
            # Number of positions to replace = wrong tokens * sampling_ratio.
            target_num = (
                ((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio
            ).long()
            if target_num > 0:
                # Clear target_num random indices drawn from [0, seq_lens[li]);
                # assumes the non-pad positions form a prefix -- TODO confirm.
                input_mask[li].scatter_(
                    dim=0,
                    index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
                    value=0,
                )
        input_mask = input_mask.eq(1)
        input_mask = input_mask.masked_fill(~nonpad_positions, False)
        input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
        # Acoustic embedding where the mask is True, target embedding elsewhere.
        sematic_embeds = pre_acoustic_embeds.masked_fill(
            ~input_mask_expand_dim, 0
        ) + ys_pad_embed.masked_fill(input_mask_expand_dim, 0)
        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
    def init_beam_search(
        self,
        **kwargs,
    ):
        from funasr.models.paraformer.search import BeamSearchPara
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus
        # 1. Build ASR model
        scorers = {}
        if self.ctc != None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            length_bonus=LengthBonus(len(token_list)),
        )
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight"),
            ctc=kwargs.get("decoding_ctc_weight", 0.0),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearchPara(
            beam_size=kwargs.get("beam_size", 2),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )
        # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        # for scorer in scorers.values():
        #     if isinstance(scorer, torch.nn.Module):
        #         scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval()
        self.beam_search = beam_search
    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        pred_timestamp = kwargs.get("pred_timestamp", False)
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
        if (
            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
        ):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is not None:
                speech_lengths = speech_lengths.squeeze(-1)
            else:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_text_image_video(
                data_in,
                fs=frontend.fs,
                audio_fs=kwargs.get("fs", 16000),
                data_type=kwargs.get("data_type", "sound"),
                tokenizer=tokenizer,
            )
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(
                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
            )
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = (
                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            )
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        if kwargs.get("fp16", False):
            speech = speech.half()
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
            predictor_outs[0],
            predictor_outs[1],
            predictor_outs[2],
            predictor_outs[3],
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
        )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        results = []
        b, n, d = decoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        if len(key) < b:
            key = key * b
        for i in range(b):
            x = encoder_out[i, : encoder_out_lens[i], :]
            am_scores = decoder_out[i, : pre_token_length[i], :]
            if self.beam_search is not None:
                nbest_hyps = self.beam_search(
                    x=x,
                    am_scores=am_scores,
                    maxlenratio=kwargs.get("maxlenratio", 0.0),
                    minlenratio=kwargs.get("minlenratio", 0.0),
                )
                nbest_hyps = nbest_hyps[: self.nbest]
            else:
                yseq = am_scores.argmax(dim=-1)
                score = am_scores.max(dim=-1)[0]
                score = torch.sum(score, dim=-1)
                # pad with mask tokens to ensure compatibility with sos/eos tokens
                yseq = torch.tensor([self.sos] + yseq.tolist() + [self.eos], device=yseq.device)
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx+1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank symbol id, which is assumed to be 0
                token_int = list(
                    filter(
                        lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
                    )
                )
                if tokenizer is not None:
                    # Change integer-ids to tokens
                    token = tokenizer.ids2tokens(token_int)
                    text_postprocessed = tokenizer.tokens2text(token)
                    if pred_timestamp:
                        timestamp_str, timestamp = ts_prediction_lfr6_standard(
                            pre_peak_index[i],
                            alphas[i],
                            copy.copy(token),
                            vad_offset=kwargs.get("begin_time", 0),
                            upsample_rate=1,
                        )
                        if not hasattr(tokenizer, "bpemodel"):
                            text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
                        result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed,}
                    else:
                        if not hasattr(tokenizer, "bpemodel"):
                            text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                        result_i = {"key": key[i], "text": text_postprocessed}
                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)
                        # ibest_writer["text"][key[i]] = text
                        ibest_writer["text"][key[i]] = text_postprocessed
                else:
                    result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)
        return results, meta_data
    def export(self, **kwargs):
        from .export_meta import export_rebuild_model
        if "max_seq_len" not in kwargs:
            kwargs["max_seq_len"] = 512
        models = export_rebuild_model(model=self, **kwargs)
        return models
funasr/models/e_paraformer/pif_predictor.py
New file
@@ -0,0 +1,107 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# Copyright 2024 Kun Zou (chinazoukun@gmail.com). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import logging
import numpy as np
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from torch.cuda.amp import autocast
@tables.register("predictor_classes", "PifPredictor")
class PifPredictor(torch.nn.Module):
    """
    Author: Kun Zou, chinazoukun@gmail.com
    E-Paraformer: A Faster and Better Parallel Transformer for Non-autoregressive End-to-End Mandarin Speech Recognition
    https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf

    Predicts one weight per encoder frame (``alphas``), accumulates the weights
    along time, and pools the encoder frames into one acoustic embedding per
    output token with a multi-head softmax over frames.
    """
    def __init__(
        self,
        idim,
        l_order,
        r_order,
        threshold=1.0,
        dropout=0.1,
        smooth_factor=1.0,
        noise_threshold=0,
        sigma=0.5,
        bias=0.0,
        sigma_heads=4,
    ):
        """Create the predictor layers.

        Args:
            idim: Feature size of the encoder output; used as the channel count
                of the depthwise convolution and the input size of the linear
                layer.
            l_order: Number of zero frames padded on the left; together with
                ``r_order`` it defines the convolution kernel size
                ``l_order + r_order + 1``.
            r_order: Number of zero frames padded on the right.
            threshold: Stored on the module; not read by ``forward``.
            dropout: Dropout probability applied after the convolution.
            smooth_factor: Multiplier applied to the sigmoid output.
            noise_threshold: Value subtracted from the scaled sigmoid output
                before the ReLU.
            sigma: Initial value of the learnable per-head sharpness factor.
            bias: Initial value of the learnable per-head score offset.
            sigma_heads: Number of heads; ``forward`` splits the feature
                dimension of ``hidden`` evenly across them.
        """
        super().__init__()
        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        # Force a floating dtype: an integer config value (e.g. ``bias: 0``)
        # would otherwise create an integer tensor, which cannot be a Parameter.
        param_dtype = torch.get_default_dtype()
        self.sigma = torch.nn.Parameter(torch.tensor([sigma] * sigma_heads, dtype=param_dtype))
        self.bias = torch.nn.Parameter(torch.tensor([bias] * sigma_heads, dtype=param_dtype))
        self.sigma_heads = sigma_heads
    def forward(
        self,
        hidden,
        target_label=None,
        mask=None,
        ignore_id=-1,
        mask_chunk_predictor=None,
        target_label_length=None,
    ):
        """Predict frame weights and pool ``hidden`` into token-level embeddings.

        Args:
            hidden: Encoder output of shape (batch, time, idim); the last
                dimension must be divisible by ``sigma_heads``.
            target_label: Optional padded label ids; entries equal to
                ``ignore_id`` are treated as padding. Only used when
                ``target_label_length`` is not given.
            mask: Optional frame mask, presumably (batch, 1, time) -- it is
                transposed on its last two axes and multiplied with the
                (batch, time, 1) weights. When ``None`` no frame is masked.
            ignore_id: Padding id inside ``target_label``.
            mask_chunk_predictor: Optional extra multiplicative mask applied to
                the (batch, time, 1) weights.
            target_label_length: Optional per-utterance token counts; takes
                precedence over ``target_label``.

        Returns:
            Tuple ``(acoustic_embeds, token_num, alphas, cif_peak)``:
            pooled embeddings (batch, tokens, idim); the per-utterance sum of
            the weights *before* rescaling; the rescaled weights (batch, time);
            and ``None`` (no peak information is produced here).
        """
        with autocast(False):
            context = hidden.transpose(1, 2)
            queries = self.pad(context)
            memory = self.cif_conv1d(queries)
            output = memory + context
            output = self.dropout(output)
            output = output.transpose(1, 2)
            output = torch.relu(output)
            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)
            # ``mask`` is optional: only reshape it when it was provided.
            if mask is not None:
                mask = mask.squeeze(-1)
            # Always bind ``target_mask`` so the check after pooling cannot hit
            # an unbound local, whichever branch is taken below.
            target_mask = None
            if target_label_length is not None:
                target_length = target_label_length
                # Build the padding mask from the lengths so padded token slots
                # are zeroed exactly as in the ``target_label`` branch.
                token_positions = torch.arange(
                    int(target_length.max()), device=target_length.device
                )
                target_mask = (token_positions[None, :] < target_length[:, None]).float()
            elif target_label is not None:
                target_mask = (target_label != ignore_id).float()
                target_length = target_mask.sum(-1)
            else:
                target_length = None
            # Sum of the raw weights; returned to the caller un-rescaled.
            token_num = alphas.sum(-1)
            if target_length is not None:
                # Rescale so that each utterance's weights sum to its target length.
                alphas *= (target_length / token_num)[:, None]
                max_token_num = torch.max(target_length)
            else:
                # No reference length: snap the predicted sum to the nearest integer.
                token_num_int = token_num.round()
                alphas *= (token_num_int / token_num)[:, None]
                max_token_num = torch.max(token_num_int)
            alignment = torch.cumsum(alphas, dim=-1)
            # Token ``k`` is centred at cumulative weight ``k + 0.5``.
            fire_positions = (torch.arange(max_token_num) + 0.5).type_as(alphas).unsqueeze(0)
            # (batch, heads, tokens, time): negative scaled squared distance
            # between each token centre and each frame's cumulative weight.
            scores = - ((fire_positions[:, None, :, None] - alignment[:, None, None, :]) * self.sigma[None, :, None, None]) **2 + self.bias[None, :, None, None]
            if mask is not None:
                scores = scores.masked_fill(~(mask[:, None, None, :].to(torch.bool)), float("-inf"))
            weights = torch.softmax(scores, dim=-1)
            # Split features into heads, pool per head, then merge the heads back.
            n_hidden = hidden.view(hidden.size(0), -1, self.sigma_heads, hidden.size(-1) // self.sigma_heads).transpose(1, 2)
            acoustic_embeds = torch.matmul(weights, n_hidden).transpose(1,2).contiguous().view(hidden.size(0), -1, hidden.size(-1))
            if target_mask is not None:
                acoustic_embeds *= target_mask[:, :, None]
            cif_peak = None
        return acoustic_embeds, token_num, alphas, cif_peak
funasr/models/e_paraformer/search.py
New file
@@ -0,0 +1,451 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import logging
from itertools import chain
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from funasr.metrics.common import end_detect
from funasr.models.transformer.scorers.scorer_interface import (
    PartialScorerInterface,
    ScorerInterface,
)
class Hypothesis(NamedTuple):
    """Immutable record describing one beam-search hypothesis."""
    # Token ids of the hypothesis.
    yseq: torch.Tensor
    # Total (weighted) score of the hypothesis.
    score: Union[float, torch.Tensor] = 0
    # Accumulated score per scorer, keyed by scorer name.
    scores: Dict[str, Union[float, torch.Tensor]] = dict()
    # Scorer state per scorer, keyed by scorer name.
    states: Dict[str, Any] = dict()
    def asdict(self) -> dict:
        """Return the fields as a dict of plain Python values.

        ``yseq`` becomes a list, ``score`` and every entry of ``scores``
        become ``float``; ``states`` is passed through unchanged.
        """
        plain = self._asdict()
        plain["yseq"] = self.yseq.tolist()
        plain["score"] = float(self.score)
        plain["scores"] = {name: float(value) for name, value in self.scores.items()}
        return plain
class BeamSearchPara(torch.nn.Module):
    """Beam search implementation."""
    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos: int,
        eos: int,
        token_list: List[str] = None,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
    ):
        """Initialize beam search.
        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorers
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The number of vocabulary
            sos (int): Start of sequence id
            eos (int): End of sequence id
            token_list (list[str]): List of tokens for debug log
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`
        """
        super().__init__()
        # set scorers
        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()
        # this module dict is required for recursive cast
        # `self.to(device, dtype)` in `recog.py`
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            w = weights.get(k, 0)
            if w == 0 or v is None:
                continue
            assert isinstance(
                v, ScorerInterface
            ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v
        # set configurations
        self.sos = sos
        self.eos = eos
        self.token_list = token_list
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
        self.pre_beam_score_key = pre_beam_score_key
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )
    def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
        """Get an initial hypothesis data.
        Args:
            x (torch.Tensor): The encoder output feature
        Returns:
            Hypothesis: The initial hypothesis.
        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0
        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
            )
        ]
    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Append new token to prefix tokens.
        Args:
            xs (torch.Tensor): The prefix token
            x (int): The new token to append
        Returns:
            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
        """
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))
    def score_full(
        self, hyp: Hypothesis, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.
        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature
        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`
        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
        return scores, states
    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.
        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature
        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`
        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
        return scores, states
    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.
        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
            Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk
        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`
        """
        # no pre beam performed
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(self.beam_size)[1]
            return top_ids, top_ids
        # mask pruned in pre-beam not to select in topk
        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(self.beam_size)[1]
        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
        return top_ids, local_ids
    @staticmethod
    def merge_scores(
        prev_scores: Dict[str, float],
        next_full_scores: Dict[str, torch.Tensor],
        full_idx: int,
        next_part_scores: Dict[str, torch.Tensor],
        part_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Merge scores for new hypothesis.
        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`
        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are scalar tensors by the scorers.
        """
        new_scores = dict()
        for k, v in next_full_scores.items():
            new_scores[k] = prev_scores[k] + v[full_idx]
        for k, v in next_part_scores.items():
            new_scores[k] = prev_scores[k] + v[part_idx]
        return new_scores
    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.
        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`
        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.
        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, d in self.part_scorers.items():
            new_states[k] = d.select_state(part_states[k], part_idx)
        return new_states
    def search(
        self, running_hyps: List[Hypothesis], x: torch.Tensor, am_score: torch.Tensor
    ) -> List[Hypothesis]:
        """Extend every running hypothesis by one token and keep the best ones.

        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)
            am_score (torch.Tensor): Score for the current output position;
                it is added to a `(self.n_vocab,)` vector, so it must be
                broadcastable to that shape.

        Returns:
            List[Hypotheses]: Best sorted hypotheses (at most `self.beam_size`)
        """
        best_hyps = []
        # Default candidate set: the whole vocabulary. It is replaced below
        # when a pre-beam is active, and the replacement carries over to the
        # following iterations until it is overwritten again.
        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
        for hyp in running_hyps:
            # scoring
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            # start from the score supplied for this position
            weighted_scores += am_score
            scores, states = self.score_full(hyp, x)
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
            if self.do_pre_beam:
                # choose which scores rank the pre-beam candidates
                pre_beam_scores = (
                    weighted_scores
                    if self.pre_beam_score_key == "full"
                    else scores[self.pre_beam_score_key]
                )
                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
            part_scores, part_states = self.score_partial(hyp, part_ids, x)
            for k in self.part_scorers:
                # partial scores only exist for the candidate ids
                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
            # add previous hyp score
            weighted_scores += hyp.score
            # update hyps
            # NOTE: `self.beam` may overwrite entries of `weighted_scores`
            # outside `part_ids` with -inf; the selected `j` are unaffected.
            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
                # will be (2 x beam at most)
                best_hyps.append(
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )
            # sort and prune 2 x beam -> beam
            # (done inside the loop, i.e. after each running hypothesis)
            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
                : min(len(best_hyps), self.beam_size)
            ]
        return best_hyps
    def forward(
        self,
        x: torch.Tensor,
        am_scores: torch.Tensor,
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
    ) -> List[Hypothesis]:
        """Perform beam search.
        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
                If maxlenratio<0.0, its absolute value is interpreted
                as a constant max output length.
            minlenratio (float): Input length ratio to obtain min output length.
        Returns:
            list[Hypothesis]: N-best decoding results
        """
        # set length bounds
        maxlen = am_scores.shape[0]
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        # main loop of prefix search
        running_hyps = self.init_hyp(x)
        ended_hyps = []
        for i in range(maxlen):
            logging.debug("position " + str(i))
            best = self.search(running_hyps, x, am_scores[i])
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                logging.info(f"end detected at {i}")
                break
            if len(running_hyps) == 0:
                logging.info("no hypothesis. Finish decoding.")
                break
            else:
                logging.debug(f"remained hypotheses: {len(running_hyps)}")
        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition " "again with smaller minlenratio."
            )
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )
        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: " + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]]) + "\n"
            )
        return nbest_hyps
    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: List[Hypothesis],
        ended_hyps: List[Hypothesis],
    ) -> List[Hypothesis]:
        """Perform post-processing of beam search iterations.
        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (int): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
        Returns:
            List[Hypothesis]: The new running hypotheses.
        """
        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([self.token_list[x.item()] for x in running_hyps[0].yseq[1:]])
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
            ]
        # add ended hypotheses to a final list, and removed them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in running_hyps:
            if hyp.yseq[-1] == self.eos:
                # e.g., Word LM needs to add final <eos> score
                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                    s = d.final_score(hyp.states[k])
                    hyp.scores[k] += s
                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        return remained_hyps