From a4bd736b038a64fb14c3849e4a2bd26deb02517b Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 18 四月 2023 14:44:59 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR; add RNN-T recipe
---
funasr/modules/embedding.py | 77
funasr/modules/nets_utils.py | 195 ++
funasr/bin/asr_train_transducer.py | 46
funasr/models/encoder/conformer_encoder.py | 634 ++++++
egs/aishell/rnnt/path.sh | 5
funasr/models/joint_net/joint_network.py | 61
funasr/bin/asr_inference_rnnt.py | 1185 +++++-------
funasr/tasks/asr.py | 391 ++++
funasr/modules/subsampling.py | 202 ++
egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml | 5
egs/aishell/rnnt/local/aishell_data_prep.sh | 66
egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml | 80
egs/aishell/rnnt/run.sh | 247 ++
egs/aishell/rnnt/utils | 1
funasr/modules/attention.py | 220 ++
egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml | 8
funasr/modules/e2e_asr_common.py | 150 +
funasr/modules/beam_search/beam_search_transducer.py | 704 +++++++
egs/aishell/rnnt/README.md | 18
funasr/models/decoder/rnnt_decoder.py | 258 ++
funasr/modules/repeat.py | 91
funasr/bin/asr_inference_launch.py | 43
funasr/models/e2e_asr_transducer.py | 1013 ++++++++++
23 files changed, 4998 insertions(+), 702 deletions(-)
diff --git a/egs/aishell/rnnt/README.md b/egs/aishell/rnnt/README.md
new file mode 100644
index 0000000..45f1f3f
--- /dev/null
+++ b/egs/aishell/rnnt/README.md
@@ -0,0 +1,18 @@
+
+# Streaming RNN-T Result
+
+## Training Config
+- 8 gpu(Tesla V100)
+- Feature info: using 80 dims fbank, global cmvn, speed perturb(0.9, 1.0, 1.1), specaugment
+- Train config: conf/train_conformer_rnnt_unified.yaml
+- Chunk config: chunk size 16, full left chunk
+- LM config: LM was not used
+- Model size: 90M
+
+## Results (CER)
+- Decode config: conf/decode_rnnt_conformer_streaming.yaml
+
+| testset | CER(%) |
+|:-----------:|:-------:|
+| dev | 5.53 |
+| test | 6.24 |
diff --git a/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml
new file mode 100644
index 0000000..26e43c6
--- /dev/null
+++ b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming.yaml
@@ -0,0 +1,8 @@
+# The conformer transducer decoding configuration from @jeon30c
+beam_size: 10
+simu_streaming: false
+streaming: true
+chunk_size: 16
+left_context: 16
+right_context: 0
+
diff --git a/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml
new file mode 100644
index 0000000..dc3eff2
--- /dev/null
+++ b/egs/aishell/rnnt/conf/decode_rnnt_conformer_streaming_simu.yaml
@@ -0,0 +1,5 @@
+# The conformer transducer decoding configuration from @jeon30c
+beam_size: 10
+simu_streaming: true
+streaming: false
+chunk_size: 16
diff --git a/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
new file mode 100644
index 0000000..8a1c40c
--- /dev/null
+++ b/egs/aishell/rnnt/conf/train_conformer_rnnt_unified.yaml
@@ -0,0 +1,80 @@
+encoder: chunk_conformer
+encoder_conf:
+ activation_type: swish
+ positional_dropout_rate: 0.5
+ time_reduction_factor: 2
+ unified_model_training: true
+ default_chunk_size: 16
+ jitter_range: 4
+ left_chunk_size: 0
+ embed_vgg_like: false
+ subsampling_factor: 4
+ linear_units: 2048
+ output_size: 512
+ attention_heads: 8
+ dropout_rate: 0.5
+ positional_dropout_rate: 0.5
+ attention_dropout_rate: 0.5
+ cnn_module_kernel: 15
+ num_blocks: 12
+
+# decoder related
+rnnt_decoder: rnnt
+rnnt_decoder_conf:
+ embed_size: 512
+ hidden_size: 512
+ embed_dropout_rate: 0.5
+ dropout_rate: 0.5
+
+joint_network_conf:
+ joint_space_size: 512
+
+# Auxiliary CTC
+model_conf:
+ auxiliary_ctc_weight: 0.0
+
+# minibatch related
+use_amp: true
+batch_type: unsorted
+batch_size: 16
+num_workers: 16
+
+# optimization related
+accum_grad: 1
+grad_clip: 5
+max_epoch: 200
+val_scheduler_criterion:
+ - valid
+ - loss
+best_model_criterion:
+- - valid
+ - cer_transducer_chunk
+ - min
+keep_nbest_models: 10
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 25000
+
+normalize: None
+
+specaug: specaug
+specaug_conf:
+ apply_time_warp: true
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 40
+ num_freq_mask: 2
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 50
+ num_time_mask: 5
+
+log_interval: 50
diff --git a/egs/aishell/rnnt/local/aishell_data_prep.sh b/egs/aishell/rnnt/local/aishell_data_prep.sh
new file mode 100755
index 0000000..83f489b
--- /dev/null
+++ b/egs/aishell/rnnt/local/aishell_data_prep.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+
+# Copyright 2017 Xingyu Na
+# Apache 2.0
+
+#. ./path.sh || exit 1;
+
+if [ $# != 3 ]; then
+ echo "Usage: $0 <audio-path> <text-path> <output-path>"
+ echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
+ exit 1;
+fi
+
+aishell_audio_dir=$1
+aishell_text=$2/aishell_transcript_v0.8.txt
+output_dir=$3
+
+train_dir=$output_dir/data/local/train
+dev_dir=$output_dir/data/local/dev
+test_dir=$output_dir/data/local/test
+tmp_dir=$output_dir/data/local/tmp
+
+mkdir -p $train_dir
+mkdir -p $dev_dir
+mkdir -p $test_dir
+mkdir -p $tmp_dir
+
+# data directory check
+if [ ! -d $aishell_audio_dir ] || [ ! -f $aishell_text ]; then
+ echo "Error: $0 requires two directory arguments"
+ exit 1;
+fi
+
+# find wav audio file for train, dev and test resp.
+find $aishell_audio_dir -iname "*.wav" > $tmp_dir/wav.flist
+n=`cat $tmp_dir/wav.flist | wc -l`
+[ $n -ne 141925 ] && \
+ echo Warning: expected 141925 data data files, found $n
+
+grep -i "wav/train" $tmp_dir/wav.flist > $train_dir/wav.flist || exit 1;
+grep -i "wav/dev" $tmp_dir/wav.flist > $dev_dir/wav.flist || exit 1;
+grep -i "wav/test" $tmp_dir/wav.flist > $test_dir/wav.flist || exit 1;
+
+rm -r $tmp_dir
+
+# Transcriptions preparation
+for dir in $train_dir $dev_dir $test_dir; do
+ echo Preparing $dir transcriptions
+ sed -e 's/\.wav//' $dir/wav.flist | awk -F '/' '{print $NF}' > $dir/utt.list
+ paste -d' ' $dir/utt.list $dir/wav.flist > $dir/wav.scp_all
+ utils/filter_scp.pl -f 1 $dir/utt.list $aishell_text > $dir/transcripts.txt
+ awk '{print $1}' $dir/transcripts.txt > $dir/utt.list
+ utils/filter_scp.pl -f 1 $dir/utt.list $dir/wav.scp_all | sort -u > $dir/wav.scp
+ sort -u $dir/transcripts.txt > $dir/text
+done
+
+mkdir -p $output_dir/data/train $output_dir/data/dev $output_dir/data/test
+
+for f in wav.scp text; do
+ cp $train_dir/$f $output_dir/data/train/$f || exit 1;
+ cp $dev_dir/$f $output_dir/data/dev/$f || exit 1;
+ cp $test_dir/$f $output_dir/data/test/$f || exit 1;
+done
+
+echo "$0: AISHELL data preparation succeeded"
+exit 0;
diff --git a/egs/aishell/rnnt/path.sh b/egs/aishell/rnnt/path.sh
new file mode 100644
index 0000000..7972642
--- /dev/null
+++ b/egs/aishell/rnnt/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/egs/aishell/rnnt/run.sh b/egs/aishell/rnnt/run.sh
new file mode 100755
index 0000000..bcd4a8b
--- /dev/null
+++ b/egs/aishell/rnnt/run.sh
@@ -0,0 +1,247 @@
+#!/usr/bin/env bash
+
+. ./path.sh || exit 1;
+
+# machines configuration
+CUDA_VISIBLE_DEVICES="0,1,2,3"
+gpu_num=4
+count=1
+gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
+# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
+njob=5
+train_cmd=utils/run.pl
+infer_cmd=utils/run.pl
+
+# general configuration
+feats_dir= #feature output dictionary
+exp_dir=
+lang=zh
+dumpdir=dump/fbank
+feats_type=fbank
+token_type=char
+scp=feats.scp
+type=kaldi_ark
+stage=0
+stop_stage=4
+
+# feature configuration
+feats_dim=80
+sample_frequency=16000
+nj=32
+speed_perturb="0.9,1.0,1.1"
+
+# data
+data_aishell=
+
+# exp tag
+tag="exp1"
+
+. utils/parse_options.sh || exit 1;
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+train_set=train
+valid_set=dev
+test_sets="dev test"
+
+asr_config=conf/train_conformer_rnnt_unified.yaml
+model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"
+
+inference_config=conf/decode_rnnt_conformer_streaming.yaml
+inference_asr_model=valid.cer_transducer_chunk.ave_5best.pth
+
+# you can set gpu num for decoding here
+gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
+ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')
+
+if ${gpu_inference}; then
+ inference_nj=$[${ngpu}*${njob}]
+ _ngpu=1
+else
+ inference_nj=$njob
+ _ngpu=0
+fi
+
+if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+ echo "stage 0: Data preparation"
+ # Data preparation
+ local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
+ for x in train dev test; do
+ cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
+ paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
+ > ${feats_dir}/data/${x}/text
+ utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
+ mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text
+ done
+fi
+
+feat_train_dir=${feats_dir}/${dumpdir}/train; mkdir -p ${feat_train_dir}
+feat_dev_dir=${feats_dir}/${dumpdir}/dev; mkdir -p ${feat_dev_dir}
+feat_test_dir=${feats_dir}/${dumpdir}/test; mkdir -p ${feat_test_dir}
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Feature Generation"
+ # compute fbank features
+ fbankdir=${feats_dir}/fbank
+ utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
+ ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
+ utils/fix_data_feat.sh ${fbankdir}/train
+ utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
+ ${feats_dir}/data/dev ${exp_dir}/exp/make_fbank/dev ${fbankdir}/dev
+ utils/fix_data_feat.sh ${fbankdir}/dev
+ utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
+ ${feats_dir}/data/test ${exp_dir}/exp/make_fbank/test ${fbankdir}/test
+ utils/fix_data_feat.sh ${fbankdir}/test
+
+ # compute global cmvn
+ utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
+ ${fbankdir}/train ${exp_dir}/exp/make_fbank/train
+
+ # apply cmvn
+ utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/train ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/train ${feat_train_dir}
+ utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/dev ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/dev ${feat_dev_dir}
+ utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
+ ${fbankdir}/test ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/test ${feat_test_dir}
+
+ cp ${fbankdir}/train/text ${fbankdir}/train/speech_shape ${fbankdir}/train/text_shape ${feat_train_dir}
+ cp ${fbankdir}/dev/text ${fbankdir}/dev/speech_shape ${fbankdir}/dev/text_shape ${feat_dev_dir}
+ cp ${fbankdir}/test/text ${fbankdir}/test/speech_shape ${fbankdir}/test/text_shape ${feat_test_dir}
+
+ utils/fix_data_feat.sh ${feat_train_dir}
+ utils/fix_data_feat.sh ${feat_dev_dir}
+ utils/fix_data_feat.sh ${feat_test_dir}
+
+ #generate ark list
+ utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_train_dir} ${fbankdir}/train ${feat_train_dir}
+ utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feat_dev_dir} ${fbankdir}/dev ${feat_dev_dir}
+fi
+
+token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
+echo "dictionary: ${token_list}"
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: Dictionary Preparation"
+ mkdir -p ${feats_dir}/data/${lang}_token_list/char/
+
+ echo "make a dictionary"
+ echo "<blank>" > ${token_list}
+ echo "<s>" >> ${token_list}
+ echo "</s>" >> ${token_list}
+ utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
+ | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
+ num_token=$(cat ${token_list} | wc -l)
+ echo "<unk>" >> ${token_list}
+ vocab_size=$(cat ${token_list} | wc -l)
+ awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
+ awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
+ mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
+ mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
+ cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
+ cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
+fi
+
+# Training Stage
+world_size=$gpu_num # run on one machine
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: Training"
+ mkdir -p ${exp_dir}/exp/${model_dir}
+ mkdir -p ${exp_dir}/exp/${model_dir}/log
+ INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
+ if [ -f $INIT_FILE ];then
+ rm -f $INIT_FILE
+ fi
+ init_method=file://$(readlink -f $INIT_FILE)
+ echo "$0: init method is $init_method"
+ for ((i = 0; i < $gpu_num; ++i)); do
+ {
+ rank=$i
+ local_rank=$i
+ gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
+ asr_train_transducer.py \
+ --gpu_id $gpu_id \
+ --use_preprocessor true \
+ --token_type char \
+ --token_list $token_list \
+ --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
+ --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
+ --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
+ --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
+ --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
+ --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
+ --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
+ --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
+ --resume true \
+ --output_dir ${exp_dir}/exp/${model_dir} \
+ --config $asr_config \
+ --input_size $feats_dim \
+ --ngpu $gpu_num \
+ --num_worker_count $count \
+ --multiprocessing_distributed true \
+ --dist_init_method $init_method \
+ --dist_world_size $world_size \
+ --dist_rank $rank \
+ --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
+ } &
+ done
+ wait
+fi
+
+# Testing Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: Inference"
+ for dset in ${test_sets}; do
+ asr_exp=${exp_dir}/exp/${model_dir}
+ inference_tag="$(basename "${inference_config}" .yaml)"
+ _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
+ _logdir="${_dir}/logdir"
+ if [ -d ${_dir} ]; then
+ echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
+ exit 0
+ fi
+ mkdir -p "${_logdir}"
+ _data="${feats_dir}/${dumpdir}/${dset}"
+ key_file=${_data}/${scp}
+ num_scp_file="$(<${key_file} wc -l)"
+ _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
+ split_scps=
+ for n in $(seq "${_nj}"); do
+ split_scps+=" ${_logdir}/keys.${n}.scp"
+ done
+ # shellcheck disable=SC2086
+ utils/split_scp.pl "${key_file}" ${split_scps}
+ _opts=
+ if [ -n "${inference_config}" ]; then
+ _opts+="--config ${inference_config} "
+ fi
+ ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
+ python -m funasr.bin.asr_inference_launch \
+ --batch_size 1 \
+ --ngpu "${_ngpu}" \
+ --njob ${njob} \
+ --gpuid_list ${gpuid_list} \
+ --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
+ --key_file "${_logdir}"/keys.JOB.scp \
+ --asr_train_config "${asr_exp}"/config.yaml \
+ --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
+ --output_dir "${_logdir}"/output.JOB \
+ --mode rnnt \
+ ${_opts}
+
+ for f in token token_int score text; do
+ if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
+ for i in $(seq "${_nj}"); do
+ cat "${_logdir}/output.${i}/1best_recog/${f}"
+ done | sort -k1 >"${_dir}/${f}"
+ fi
+ done
+ python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
+ python utils/proce_text.py ${_data}/text ${_data}/text.proc
+ python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
+ tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
+ cat ${_dir}/text.cer.txt
+ done
+fi
diff --git a/egs/aishell/rnnt/utils b/egs/aishell/rnnt/utils
new file mode 120000
index 0000000..4072eac
--- /dev/null
+++ b/egs/aishell/rnnt/utils
@@ -0,0 +1 @@
+../transformer/utils
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 7add960..2b6716e 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -134,6 +134,11 @@
help="Pretrained model tag. If specify this option, *_train_config and "
"*_file will be overwritten",
)
+ group.add_argument(
+ "--beam_search_config",
+ default={},
+ help="The keyword arguments for transducer beam search.",
+ )
group = parser.add_argument_group("Beam-search related")
group.add_argument(
@@ -171,6 +176,41 @@
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
group.add_argument("--streaming", type=str2bool, default=False)
+ group.add_argument("--simu_streaming", type=str2bool, default=False)
+ group.add_argument("--chunk_size", type=int, default=16)
+ group.add_argument("--left_context", type=int, default=16)
+ group.add_argument("--right_context", type=int, default=0)
+ group.add_argument(
+ "--display_partial_hypotheses",
+ type=bool,
+ default=False,
+ help="Whether to display partial hypotheses during chunk-by-chunk inference.",
+ )
+
+ group = parser.add_argument_group("Dynamic quantization related")
+ group.add_argument(
+ "--quantize_asr_model",
+ type=bool,
+ default=False,
+ help="Apply dynamic quantization to ASR model.",
+ )
+ group.add_argument(
+ "--quantize_modules",
+ nargs="*",
+ default=None,
+ help="""Module names to apply dynamic quantization on.
+ The module names are provided as a list, where each name is separated
+ by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
+ Each specified name should be an attribute of 'torch.nn', e.g.:
+ torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
+ )
+ group.add_argument(
+ "--quantize_dtype",
+ type=str,
+ default="qint8",
+ choices=["float16", "qint8"],
+ help="Dtype for dynamic quantization.",
+ )
group = parser.add_argument_group("Text converter related")
group.add_argument(
@@ -268,6 +308,9 @@
elif mode == "mfcca":
from funasr.bin.asr_inference_mfcca import inference_modelscope
return inference_modelscope(**kwargs)
+ elif mode == "rnnt":
+ from funasr.bin.asr_inference_rnnt import inference
+ return inference(**kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index 2189a71..bff8702 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -1,396 +1,149 @@
#!/usr/bin/env python3
+
+""" Inference class definition for Transducer models."""
+
+from __future__ import annotations
+
import argparse
import logging
+import math
import sys
-import time
-import copy
-import os
-import codecs
-import tempfile
-import requests
from pathlib import Path
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-from typing import Dict
-from typing import Any
-from typing import List
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
-from typeguard import check_argument_types
+from packaging.version import parse as V
+from typeguard import check_argument_types, check_return_type
+from funasr.modules.beam_search.beam_search_transducer import (
+ BeamSearchTransducer,
+ Hypothesis,
+)
+from funasr.modules.nets_utils import TooShortUttError
from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch
-from funasr.modules.beam_search.beam_search import Hypothesis
-from funasr.modules.scorers.ctc import CTCPrefixScorer
-from funasr.modules.scorers.length_bonus import LengthBonus
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.asr import ASRTaskParaformer as ASRTask
+from funasr.tasks.asr import ASRTransducerTask
from funasr.tasks.lm import LMTask
from funasr.text.build_tokenizer import build_tokenizer
from funasr.text.token_id_converter import TokenIDConverter
from funasr.torch_utils.device_funcs import to_device
from funasr.torch_utils.set_all_random_seed import set_all_random_seed
from funasr.utils import config_argparse
+from funasr.utils.types import str2bool, str2triple_str, str_or_none
from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
-from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
-
class Speech2Text:
- """Speech2Text class
-
- Examples:
- >>> import soundfile
- >>> speech2text = Speech2Text("asr_config.yml", "asr.pb")
- >>> audio, rate = soundfile.read("speech.wav")
- >>> speech2text(audio)
- [(text, token, token_int, hypothesis object), ...]
-
+ """Speech2Text class for Transducer models.
+ Args:
+ asr_train_config: ASR model training config path.
+ asr_model_file: ASR model path.
+ beam_search_config: Beam search config path.
+ lm_train_config: Language Model training config path.
+ lm_file: Language Model config path.
+ token_type: Type of token units.
+ bpemodel: BPE model path.
+ device: Device to use for inference.
+ beam_size: Size of beam during search.
+ dtype: Data type.
+ lm_weight: Language model weight.
+ quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+ quantize_modules: List of module names to apply dynamic quantization on.
+ quantize_dtype: Dynamic quantization data type.
+ nbest: Number of final hypothesis.
+ streaming: Whether to perform chunk-by-chunk inference.
+ chunk_size: Number of frames in chunk AFTER subsampling.
+ left_context: Number of frames in left context AFTER subsampling.
+ right_context: Number of frames in right context AFTER subsampling.
+ display_partial_hypotheses: Whether to display partial hypotheses.
"""
def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- frontend_conf: dict = None,
- hotword_list_or_file: str = None,
- **kwargs,
- ):
- assert check_argument_types()
+ self,
+ asr_train_config: Union[Path, str] = None,
+ asr_model_file: Union[Path, str] = None,
+ cmvn_file: Union[Path, str] = None,
+ beam_search_config: Dict[str, Any] = None,
+ lm_train_config: Union[Path, str] = None,
+ lm_file: Union[Path, str] = None,
+ token_type: str = None,
+ bpemodel: str = None,
+ device: str = "cpu",
+ beam_size: int = 5,
+ dtype: str = "float32",
+ lm_weight: float = 1.0,
+ quantize_asr_model: bool = False,
+ quantize_modules: List[str] = None,
+ quantize_dtype: str = "qint8",
+ nbest: int = 1,
+ streaming: bool = False,
+ simu_streaming: bool = False,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ display_partial_hypotheses: bool = False,
+ ) -> None:
+ """Construct a Speech2Text object."""
+ super().__init__()
- # 1. Build ASR model
- scorers = {}
- asr_model, asr_train_args = ASRTask.build_model_from_file(
+ assert check_argument_types()
+ asr_model, asr_train_args = ASRTransducerTask.build_model_from_file(
asr_train_config, asr_model_file, cmvn_file, device
)
+
frontend = None
if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
+ if quantize_asr_model:
+ if quantize_modules is not None:
+ if not all([q in ["LSTM", "Linear"] for q in quantize_modules]):
+ raise ValueError(
+ "Only 'Linear' and 'LSTM' modules are currently supported"
+ " by PyTorch and in --quantize_modules"
+ )
- if asr_model.ctc != None:
- ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
- scorers.update(
- ctc=ctc
- )
- token_list = asr_model.token_list
- scorers.update(
- length_bonus=LengthBonus(len(token_list)),
- )
+ q_config = set([getattr(torch.nn, q) for q in quantize_modules])
+ else:
+ q_config = {torch.nn.Linear}
- # 2. Build Language model
+ if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")):
+ raise ValueError(
+ "float16 dtype for dynamic quantization is not supported with torch"
+ " version < 1.5.0. Switching to qint8 dtype instead."
+ )
+ q_dtype = getattr(torch, quantize_dtype)
+
+ asr_model = torch.quantization.quantize_dynamic(
+ asr_model, q_config, dtype=q_dtype
+ ).eval()
+ else:
+ asr_model.to(dtype=getattr(torch, dtype)).eval()
+
if lm_train_config is not None:
lm, lm_train_args = LMTask.build_model_from_file(
lm_train_config, lm_file, device
)
- scorers["lm"] = lm.lm
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
+ lm_scorer = lm.lm
+ else:
+ lm_scorer = None
# 4. Build BeamSearch object
- # transducer is not supported now
- beam_search_transducer = None
+ if beam_search_config is None:
+ beam_search_config = {}
- weights = dict(
- decoder=1.0 - ctc_weight,
- ctc=ctc_weight,
- lm=lm_weight,
- ngram=ngram_weight,
- length_bonus=penalty,
+ beam_search = BeamSearchTransducer(
+ asr_model.decoder,
+ asr_model.joint_network,
+ beam_size,
+ lm=lm_scorer,
+ lm_weight=lm_weight,
+ nbest=nbest,
+ **beam_search_config,
)
- beam_search = BeamSearch(
- beam_size=beam_size,
- weights=weights,
- scorers=scorers,
- sos=asr_model.sos,
- eos=asr_model.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if ctc_weight == 1.0 else "full",
- )
-
- beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
- for scorer in scorers.values():
- if isinstance(scorer, torch.nn.Module):
- scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
- if token_type is None:
- token_type = asr_train_args.token_type
- if bpemodel is None:
- bpemodel = asr_train_args.bpemodel
-
- if token_type is None:
- tokenizer = None
- elif token_type == "bpe":
- if bpemodel is not None:
- tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel)
- else:
- tokenizer = None
- else:
- tokenizer = build_tokenizer(token_type=token_type)
- converter = TokenIDConverter(token_list=token_list)
- logging.info(f"Text tokenizer: {tokenizer}")
-
- self.asr_model = asr_model
- self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
-
- # 6. [Optional] Build hotword list from str, local file or url
- self.hotword_list = None
- self.hotword_list = self.generate_hotwords_list(hotword_list_or_file)
-
- is_use_lm = lm_weight != 0.0 and lm_file is not None
- if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
- beam_search = None
- self.beam_search = beam_search
- logging.info(f"Beam_search: {self.beam_search}")
- self.beam_search_transducer = beam_search_transducer
- self.maxlenratio = maxlenratio
- self.minlenratio = minlenratio
- self.device = device
- self.dtype = dtype
- self.nbest = nbest
- self.frontend = frontend
- self.encoder_downsampling_factor = 1
- if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d":
- self.encoder_downsampling_factor = 4
-
- @torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ):
- """Inference
-
- Args:
- speech: Input speech data
- Returns:
- text, token, token_int, hyp
-
- """
- assert check_argument_types()
-
- # Input as audio signal
- if isinstance(speech, np.ndarray):
- speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- # b. Forward Encoder
- enc, enc_len = self.asr_model.encode(**batch)
- if isinstance(enc, tuple):
- enc = enc[0]
- # assert len(enc) == 1, len(enc)
- enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
-
- predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
- pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
- predictor_outs[2], predictor_outs[3]
- pre_token_length = pre_token_length.round().long()
- if torch.max(pre_token_length) < 1:
- return []
- if not isinstance(self.asr_model, ContextualParaformer):
- if self.hotword_list:
- logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
- else:
- decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
-
- results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- x = enc[i, :enc_len[i], :]
- am_scores = decoder_out[i, :pre_token_length[i], :]
- if self.beam_search is not None:
- nbest_hyps = self.beam_search(
- x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
- else:
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
-
- results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
-
- # assert check_return_type(results)
- return results
-
- def generate_hotwords_list(self, hotword_list_or_file):
- # for None
- if hotword_list_or_file is None:
- hotword_list = None
- # for local txt inputs
- elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords from local txt...")
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for url, download and generate txt
- elif hotword_list_or_file.startswith('http'):
- logging.info("Attempting to parse hotwords from url...")
- work_dir = tempfile.TemporaryDirectory().name
- if not os.path.exists(work_dir):
- os.makedirs(work_dir)
- text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file))
- local_file = requests.get(hotword_list_or_file)
- open(text_file_path, "wb").write(local_file.content)
- hotword_list_or_file = text_file_path
- hotword_list = []
- hotword_str_list = []
- with codecs.open(hotword_list_or_file, 'r') as fin:
- for line in fin.readlines():
- hw = line.strip()
- hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Initialized hotword list from file: {}, hotword list: {}."
- .format(hotword_list_or_file, hotword_str_list))
- # for text str input
- elif not hotword_list_or_file.endswith('.txt'):
- logging.info("Attempting to parse hotwords as str...")
- hotword_list = []
- hotword_str_list = []
- for hw in hotword_list_or_file.strip().split():
- hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
- hotword_list.append([self.asr_model.sos])
- hotword_str_list.append('<s>')
- logging.info("Hotword list: {}.".format(hotword_str_list))
- else:
- hotword_list = None
- return hotword_list
-
-class Speech2TextExport:
- """Speech2TextExport class
-
- """
-
- def __init__(
- self,
- asr_train_config: Union[Path, str] = None,
- asr_model_file: Union[Path, str] = None,
- cmvn_file: Union[Path, str] = None,
- lm_train_config: Union[Path, str] = None,
- lm_file: Union[Path, str] = None,
- token_type: str = None,
- bpemodel: str = None,
- device: str = "cpu",
- maxlenratio: float = 0.0,
- minlenratio: float = 0.0,
- dtype: str = "float32",
- beam_size: int = 20,
- ctc_weight: float = 0.5,
- lm_weight: float = 1.0,
- ngram_weight: float = 0.9,
- penalty: float = 0.0,
- nbest: int = 1,
- frontend_conf: dict = None,
- hotword_list_or_file: str = None,
- **kwargs,
- ):
-
- # 1. Build ASR model
- asr_model, asr_train_args = ASRTask.build_model_from_file(
- asr_train_config, asr_model_file, cmvn_file, device
- )
- frontend = None
- if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
- frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
-
- logging.info("asr_model: {}".format(asr_model))
- logging.info("asr_train_args: {}".format(asr_train_args))
- asr_model.to(dtype=getattr(torch, dtype)).eval()
token_list = asr_model.token_list
-
-
- logging.info(f"Decoding device={device}, dtype={dtype}")
-
- # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
if token_type is None:
token_type = asr_train_args.token_type
if bpemodel is None:
@@ -407,197 +160,277 @@
tokenizer = build_tokenizer(token_type=token_type)
converter = TokenIDConverter(token_list=token_list)
logging.info(f"Text tokenizer: {tokenizer}")
-
- # self.asr_model = asr_model
+
+ self.asr_model = asr_model
self.asr_train_args = asr_train_args
- self.converter = converter
- self.tokenizer = tokenizer
-
self.device = device
self.dtype = dtype
self.nbest = nbest
- self.frontend = frontend
- model = Paraformer_export(asr_model, onnx=False)
- self.asr_model = model
+ self.converter = converter
+ self.tokenizer = tokenizer
+
+ self.beam_search = beam_search
+ self.streaming = streaming
+ self.simu_streaming = simu_streaming
+ self.chunk_size = max(chunk_size, 0)
+ self.left_context = max(left_context, 0)
+ self.right_context = max(right_context, 0)
+
+ if not streaming or chunk_size == 0:
+ self.streaming = False
+ self.asr_model.encoder.dynamic_chunk_training = False
+ if not simu_streaming or chunk_size == 0:
+ self.simu_streaming = False
+ self.asr_model.encoder.dynamic_chunk_training = False
+
+ self.frontend = frontend
+ self.window_size = self.chunk_size + self.right_context
+
+ self._ctx = self.asr_model.encoder.get_encoder_input_size(
+ self.window_size
+ )
+
+ #self.last_chunk_length = (
+ # self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+ #) * self.hop_length
+
+ self.last_chunk_length = (
+ self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+ )
+ self.reset_inference_cache()
+
+ def reset_inference_cache(self) -> None:
+ """Reset Speech2Text parameters."""
+ self.frontend_cache = None
+
+ self.asr_model.encoder.reset_streaming_cache(
+ self.left_context, device=self.device
+ )
+ self.beam_search.reset_inference_cache()
+
+ self.num_processed_frames = torch.tensor([[0]], device=self.device)
+
@torch.no_grad()
- def __call__(
- self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
- ):
- """Inference
-
+ def streaming_decode(
+ self,
+ speech: Union[torch.Tensor, np.ndarray],
+ is_final: bool = True,
+ ) -> List[Hypothesis]:
+ """Speech2Text streaming call.
Args:
- speech: Input speech data
+ speech: Chunk of speech data. (S)
+ is_final: Whether speech corresponds to the final chunk of data.
Returns:
- text, token, token_int, hyp
+ nbest_hypothesis: N-best hypothesis.
+ """
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+ if is_final:
+ if self.streaming and speech.size(0) < self.last_chunk_length:
+ pad = torch.zeros(
+ self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype
+ )
+ speech = torch.cat([speech, pad], dim=0) #feats, feats_length = self.apply_frontend(speech, is_final=is_final)
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ if self.asr_model.normalize is not None:
+ feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+ enc_out = self.asr_model.encoder.chunk_forward(
+ feats,
+ feats_lengths,
+ self.num_processed_frames,
+ chunk_size=self.chunk_size,
+ left_context=self.left_context,
+ right_context=self.right_context,
+ )
+ nbest_hyps = self.beam_search(enc_out[0], is_final=is_final)
+
+ self.num_processed_frames += self.chunk_size
+
+ if is_final:
+ self.reset_inference_cache()
+
+ return nbest_hyps
+
+ @torch.no_grad()
+ def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]:
+ """Speech2Text call.
+ Args:
+ speech: Speech data. (S)
+ Returns:
+ nbest_hypothesis: N-best hypothesis.
"""
assert check_argument_types()
- # Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
-
- if self.frontend is not None:
- feats, feats_len = self.frontend.forward(speech, speech_lengths)
- feats = to_device(feats, device=self.device)
- feats_len = feats_len.int()
- self.asr_model.frontend = None
- else:
- feats = speech
- feats_len = speech_lengths
-
- enc_len_batch_total = feats_len.sum()
- lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
- batch = {"speech": feats, "speech_lengths": feats_len}
-
- # a. To device
- batch = to_device(batch, device=self.device)
-
- decoder_outs = self.asr_model(**batch)
- decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ if self.asr_model.normalize is not None:
+ feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths)
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+ enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context, self.right_context)
+ nbest_hyps = self.beam_search(enc_out[0])
+
+ return nbest_hyps
+
+ @torch.no_grad()
+ def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]:
+ """Speech2Text call.
+ Args:
+ speech: Speech data. (S)
+ Returns:
+ nbest_hypothesis: N-best hypothesis.
+ """
+ assert check_argument_types()
+
+ if isinstance(speech, np.ndarray):
+ speech = torch.tensor(speech)
+
+ feats = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+ feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1))
+
+ feats = to_device(feats, device=self.device)
+ feats_lengths = to_device(feats_lengths, device=self.device)
+
+ enc_out, _ = self.asr_model.encoder(feats, feats_lengths)
+
+ nbest_hyps = self.beam_search(enc_out[0])
+
+ return nbest_hyps
+
+ def hypotheses_to_results(self, nbest_hyps: List[Hypothesis]) -> List[Any]:
+ """Build partial or final results from the hypotheses.
+ Args:
+ nbest_hyps: N-best hypothesis.
+ Returns:
+ results: Results containing different representation for the hypothesis.
+ """
results = []
- b, n, d = decoder_out.size()
- for i in range(b):
- am_scores = decoder_out[i, :ys_pad_lens[i], :]
- yseq = am_scores.argmax(dim=-1)
- score = am_scores.max(dim=-1)[0]
- score = torch.sum(score, dim=-1)
- # pad with mask tokens to ensure compatibility with sos/eos tokens
- yseq = torch.tensor(
- yseq.tolist(), device=yseq.device
- )
- nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
+ for hyp in nbest_hyps:
+ token_int = list(filter(lambda x: x != 0, hyp.yseq))
- for hyp in nbest_hyps:
- assert isinstance(hyp, (Hypothesis)), type(hyp)
+ token = self.converter.ids2tokens(token_int)
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
+ if self.tokenizer is not None:
+ text = self.tokenizer.tokens2text(token)
+ else:
+ text = None
+ results.append((text, token, token_int, hyp))
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
-
- # Change integer-ids to tokens
- token = self.converter.ids2tokens(token_int)
-
- if self.tokenizer is not None:
- text = self.tokenizer.tokens2text(token)
- else:
- text = None
-
- results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
+ assert check_return_type(results)
return results
+ @staticmethod
+ def from_pretrained(
+ model_tag: Optional[str] = None,
+ **kwargs: Optional[Any],
+ ) -> Speech2Text:
+ """Build Speech2Text instance from the pretrained model.
+ Args:
+ model_tag: Model tag of the pretrained models.
+        Returns:
+ : Speech2Text instance.
+ """
+ if model_tag is not None:
+ try:
+ from espnet_model_zoo.downloader import ModelDownloader
+
+ except ImportError:
+ logging.error(
+ "`espnet_model_zoo` is not installed. "
+ "Please install via `pip install -U espnet_model_zoo`."
+ )
+ raise
+ d = ModelDownloader()
+ kwargs.update(**d.download_and_unpack(model_tag))
+
+ return Speech2Text(**kwargs)
+
def inference(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- streaming: bool = False,
- output_dir: Optional[str] = None,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
-
- **kwargs,
-):
- inference_pipeline = inference_modelscope(
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
- batch_size=batch_size,
- beam_size=beam_size,
- ngpu=ngpu,
- ctc_weight=ctc_weight,
- lm_weight=lm_weight,
- penalty=penalty,
- log_level=log_level,
- asr_train_config=asr_train_config,
- asr_model_file=asr_model_file,
- cmvn_file=cmvn_file,
- raw_inputs=raw_inputs,
- lm_train_config=lm_train_config,
- lm_file=lm_file,
- token_type=token_type,
- key_file=key_file,
- word_lm_train_config=word_lm_train_config,
- bpemodel=bpemodel,
- allow_variable_data_keys=allow_variable_data_keys,
- streaming=streaming,
- output_dir=output_dir,
- dtype=dtype,
- seed=seed,
- ngram_weight=ngram_weight,
- nbest=nbest,
- num_workers=num_workers,
-
- **kwargs,
- )
- return inference_pipeline(data_path_and_name_and_type, raw_inputs)
-
-
-def inference_modelscope(
- maxlenratio: float,
- minlenratio: float,
- batch_size: int,
- beam_size: int,
- ngpu: int,
- ctc_weight: float,
- lm_weight: float,
- penalty: float,
- log_level: Union[int, str],
- # data_path_and_name_and_type,
- asr_train_config: Optional[str],
- asr_model_file: Optional[str],
- cmvn_file: Optional[str] = None,
- lm_train_config: Optional[str] = None,
- lm_file: Optional[str] = None,
- token_type: Optional[str] = None,
- key_file: Optional[str] = None,
- word_lm_train_config: Optional[str] = None,
- bpemodel: Optional[str] = None,
- allow_variable_data_keys: bool = False,
- dtype: str = "float32",
- seed: int = 0,
- ngram_weight: float = 0.9,
- nbest: int = 1,
- num_workers: int = 1,
- output_dir: Optional[str] = None,
- param_dict: dict = None,
- **kwargs,
-):
+ output_dir: str,
+ batch_size: int,
+ dtype: str,
+ beam_size: int,
+ ngpu: int,
+ seed: int,
+ lm_weight: float,
+ nbest: int,
+ num_workers: int,
+ log_level: Union[int, str],
+ data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str],
+ beam_search_config: Optional[dict],
+ lm_train_config: Optional[str],
+ lm_file: Optional[str],
+ model_tag: Optional[str],
+ token_type: Optional[str],
+ bpemodel: Optional[str],
+ key_file: Optional[str],
+ allow_variable_data_keys: bool,
+ quantize_asr_model: Optional[bool],
+ quantize_modules: Optional[List[str]],
+ quantize_dtype: Optional[str],
+ streaming: Optional[bool],
+ simu_streaming: Optional[bool],
+ chunk_size: Optional[int],
+ left_context: Optional[int],
+ right_context: Optional[int],
+ display_partial_hypotheses: bool,
+ **kwargs,
+) -> None:
+ """Transducer model inference.
+ Args:
+ output_dir: Output directory path.
+ batch_size: Batch decoding size.
+ dtype: Data type.
+ beam_size: Beam size.
+ ngpu: Number of GPUs.
+ seed: Random number generator seed.
+ lm_weight: Weight of language model.
+ nbest: Number of final hypothesis.
+ num_workers: Number of workers.
+ log_level: Level of verbose for logs.
+ data_path_and_name_and_type:
+ asr_train_config: ASR model training config path.
+ asr_model_file: ASR model path.
+ beam_search_config: Beam search config path.
+ lm_train_config: Language Model training config path.
+ lm_file: Language Model path.
+ model_tag: Model tag.
+ token_type: Type of token units.
+ bpemodel: BPE model path.
+ key_file: File key.
+ allow_variable_data_keys: Whether to allow variable data keys.
+ quantize_asr_model: Whether to apply dynamic quantization to ASR model.
+ quantize_modules: List of module names to apply dynamic quantization on.
+ quantize_dtype: Dynamic quantization data type.
+ streaming: Whether to perform chunk-by-chunk inference.
+ chunk_size: Number of frames in chunk AFTER subsampling.
+ left_context: Number of frames in left context AFTER subsampling.
+ right_context: Number of frames in right context AFTER subsampling.
+ display_partial_hypotheses: Whether to display partial hypotheses.
+ """
assert check_argument_types()
- if word_lm_train_config is not None:
- raise NotImplementedError("Word LM is not implemented")
+ if batch_size > 1:
+ raise NotImplementedError("batch decoding is not implemented")
if ngpu > 1:
raise NotImplementedError("only single GPU decoding is supported")
@@ -605,20 +438,11 @@
level=log_level,
format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)
-
- export_mode = False
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
- export_mode = param_dict.get("export_mode", False)
- else:
- hotword_list_or_file = None
- if ngpu >= 1 and torch.cuda.is_available():
+ if ngpu >= 1:
device = "cuda"
else:
device = "cpu"
- batch_size = 1
-
# 1. Set random-seed
set_all_random_seed(seed)
@@ -627,143 +451,105 @@
asr_train_config=asr_train_config,
asr_model_file=asr_model_file,
cmvn_file=cmvn_file,
+ beam_search_config=beam_search_config,
lm_train_config=lm_train_config,
lm_file=lm_file,
token_type=token_type,
bpemodel=bpemodel,
device=device,
- maxlenratio=maxlenratio,
- minlenratio=minlenratio,
dtype=dtype,
beam_size=beam_size,
- ctc_weight=ctc_weight,
lm_weight=lm_weight,
- ngram_weight=ngram_weight,
- penalty=penalty,
nbest=nbest,
- hotword_list_or_file=hotword_list_or_file,
+ quantize_asr_model=quantize_asr_model,
+ quantize_modules=quantize_modules,
+ quantize_dtype=quantize_dtype,
+ streaming=streaming,
+ simu_streaming=simu_streaming,
+ chunk_size=chunk_size,
+ left_context=left_context,
+ right_context=right_context,
)
- if export_mode:
- speech2text = Speech2TextExport(**speech2text_kwargs)
- else:
- speech2text = Speech2Text(**speech2text_kwargs)
+ speech2text = Speech2Text.from_pretrained(
+ model_tag=model_tag,
+ **speech2text_kwargs,
+ )
- def _forward(
- data_path_and_name_and_type,
- raw_inputs: Union[np.ndarray, torch.Tensor] = None,
- output_dir_v2: Optional[str] = None,
- fs: dict = None,
- param_dict: dict = None,
- **kwargs,
- ):
+ # 3. Build data-iterator
+ loader = ASRTransducerTask.build_streaming_iterator(
+ data_path_and_name_and_type,
+ dtype=dtype,
+ batch_size=batch_size,
+ key_file=key_file,
+ num_workers=num_workers,
+ preprocess_fn=ASRTransducerTask.build_preprocess_fn(
+ speech2text.asr_train_args, False
+ ),
+ collate_fn=ASRTransducerTask.build_collate_fn(
+ speech2text.asr_train_args, False
+ ),
+ allow_variable_data_keys=allow_variable_data_keys,
+ inference=True,
+ )
- hotword_list_or_file = None
- if param_dict is not None:
- hotword_list_or_file = param_dict.get('hotword')
- if 'hotword' in kwargs:
- hotword_list_or_file = kwargs['hotword']
- if hotword_list_or_file is not None or 'hotword' in kwargs:
- speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
- cache = None
- if 'cache' in param_dict:
- cache = param_dict['cache']
- # 3. Build data-iterator
- if data_path_and_name_and_type is None and raw_inputs is not None:
- if isinstance(raw_inputs, torch.Tensor):
- raw_inputs = raw_inputs.numpy()
- data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
- loader = ASRTask.build_streaming_iterator(
- data_path_and_name_and_type,
- dtype=dtype,
- fs=fs,
- batch_size=batch_size,
- key_file=key_file,
- num_workers=num_workers,
- preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
- collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
- allow_variable_data_keys=allow_variable_data_keys,
- inference=True,
- )
-
- forward_time_total = 0.0
- length_total = 0.0
- finish_count = 0
- file_count = 1
- # 7 .Start for-loop
- # FIXME(kamo): The output format should be discussed about
- asr_result_list = []
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path is not None:
- writer = DatadirWriter(output_path)
- else:
- writer = None
-
+ # 4 .Start for-loop
+ with DatadirWriter(output_dir) as writer:
for keys, batch in loader:
assert isinstance(batch, dict), type(batch)
assert all(isinstance(s, str) for s in keys), keys
+
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
+ batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+ assert len(batch.keys()) == 1
- logging.info("decoding, utt_id: {}".format(keys))
- # N-best list of (text, token, token_int, hyp_object)
+ try:
+ if speech2text.streaming:
+ speech = batch["speech"]
- time_beg = time.time()
- results = speech2text(cache=cache, **batch)
- if len(results) < 1:
- hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
- results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
- time_end = time.time()
- forward_time = time_end - time_beg
- lfr_factor = results[0][-1]
- length = results[0][-2]
- forward_time_total += forward_time
- length_total += length
- rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor))
- logging.info(rtf_cur)
+ _steps = len(speech) // speech2text._ctx
+ _end = 0
+ for i in range(_steps):
+ _end = (i + 1) * speech2text._ctx
- for batch_id in range(_bs):
- result = [results[batch_id][:-2]]
+ speech2text.streaming_decode(
+ speech[i * speech2text._ctx : _end], is_final=False
+ )
- key = keys[batch_id]
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result):
- # Create a directory: outdir/{n}best_recog
- if writer is not None:
- ibest_writer = writer[f"{n}best_recog"]
+ final_hyps = speech2text.streaming_decode(
+ speech[_end : len(speech)], is_final=True
+ )
+ elif speech2text.simu_streaming:
+ final_hyps = speech2text.simu_streaming_decode(**batch)
+ else:
+ final_hyps = speech2text(**batch)
- # Write the result to each file
- ibest_writer["token"][key] = " ".join(token)
- # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
- ibest_writer["score"][key] = str(hyp.score)
- ibest_writer["rtf"][key] = rtf_cur
+ results = speech2text.hypotheses_to_results(final_hyps)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
- if text is not None:
- text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token)
- item = {'key': key, 'value': text_postprocessed}
- asr_result_list.append(item)
- finish_count += 1
- # asr_utils.print_progress(finish_count / file_count)
- if writer is not None:
- ibest_writer["text"][key] = " ".join(word_lists)
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ ibest_writer = writer[f"{n}best_recog"]
- logging.info("decoding, utt: {}, predictions: {}".format(key, text))
- rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
- logging.info(rtf_avg)
- if writer is not None:
- ibest_writer["rtf"]["rtf_avf"] = rtf_avg
- return asr_result_list
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["score"][key] = str(hyp.score)
- return _forward
+ if text is not None:
+ ibest_writer["text"][key] = text
def get_parser():
+ """Get Transducer model inference parser."""
+
parser = config_argparse.ArgumentParser(
- description="ASR Decoding",
+ description="ASR Transducer Decoding",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
- # Note(kamo): Use '_' instead of '-' as separator.
- # '-' is confusing if written in yaml.
parser.add_argument(
"--log_level",
type=lambda x: x.upper(),
@@ -792,17 +578,12 @@
default=1,
help="The number of workers used for DataLoader",
)
- parser.add_argument(
- "--hotword",
- type=str_or_none,
- default=None,
- help="hotword file path or hotwords seperated by space"
- )
+
group = parser.add_argument_group("Input data related")
group.add_argument(
"--data_path_and_name_and_type",
type=str2triple_str,
- required=False,
+ required=True,
action="append",
)
group.add_argument("--key_file", type=str_or_none)
@@ -835,25 +616,10 @@
help="LM parameter file",
)
group.add_argument(
- "--word_lm_train_config",
- type=str,
- help="Word LM training configuration",
- )
- group.add_argument(
- "--word_lm_file",
- type=str,
- help="Word LM parameter file",
- )
- group.add_argument(
- "--ngram_file",
- type=str,
- help="N-gram parameter file",
- )
- group.add_argument(
"--model_tag",
type=str,
help="Pretrained model tag. If specify this option, *_train_config and "
- "*_file will be overwritten",
+ "*_file will be overwritten",
)
group = parser.add_argument_group("Beam-search related")
@@ -864,42 +630,13 @@
help="The batch size for inference",
)
group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses")
- group.add_argument("--beam_size", type=int, default=20, help="Beam size")
- group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
- group.add_argument(
- "--maxlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain max output length. "
- "If maxlenratio=0.0 (default), it uses a end-detect "
- "function "
- "to automatically find maximum hypothesis lengths."
- "If maxlenratio<0.0, its absolute value is interpreted"
- "as a constant max output length",
- )
- group.add_argument(
- "--minlenratio",
- type=float,
- default=0.0,
- help="Input length ratio to obtain min output length",
- )
- group.add_argument(
- "--ctc_weight",
- type=float,
- default=0.5,
- help="CTC weight in joint decoding",
- )
+ group.add_argument("--beam_size", type=int, default=5, help="Beam size")
group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
- group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
- group.add_argument("--streaming", type=str2bool, default=False)
-
group.add_argument(
- "--frontend_conf",
- default=None,
- help="",
+ "--beam_search_config",
+ default={},
+ help="The keyword arguments for transducer beam search.",
)
- group.add_argument("--raw_inputs", type=list, default=None)
- # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
group = parser.add_argument_group("Text converter related")
group.add_argument(
@@ -908,14 +645,77 @@
default=None,
choices=["char", "bpe", None],
help="The token type for ASR model. "
- "If not given, refers from the training args",
+ "If not given, refers from the training args",
)
group.add_argument(
"--bpemodel",
type=str_or_none,
default=None,
help="The model path of sentencepiece. "
- "If not given, refers from the training args",
+ "If not given, refers from the training args",
+ )
+
+ group = parser.add_argument_group("Dynamic quantization related")
+ parser.add_argument(
+ "--quantize_asr_model",
+ type=bool,
+ default=False,
+ help="Apply dynamic quantization to ASR model.",
+ )
+ parser.add_argument(
+ "--quantize_modules",
+ nargs="*",
+ default=None,
+ help="""Module names to apply dynamic quantization on.
+ The module names are provided as a list, where each name is separated
+ by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
+ Each specified name should be an attribute of 'torch.nn', e.g.:
+ torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
+ )
+ parser.add_argument(
+ "--quantize_dtype",
+ type=str,
+ default="qint8",
+ choices=["float16", "qint8"],
+ help="Dtype for dynamic quantization.",
+ )
+
+ group = parser.add_argument_group("Streaming related")
+ parser.add_argument(
+ "--streaming",
+ type=bool,
+ default=False,
+ help="Whether to perform chunk-by-chunk inference.",
+ )
+ parser.add_argument(
+ "--simu_streaming",
+ type=bool,
+ default=False,
+ help="Whether to simulate chunk-by-chunk inference.",
+ )
+ parser.add_argument(
+ "--chunk_size",
+ type=int,
+ default=16,
+ help="Number of frames in chunk AFTER subsampling.",
+ )
+ parser.add_argument(
+ "--left_context",
+ type=int,
+ default=32,
+ help="Number of frames in left context of the chunk AFTER subsampling.",
+ )
+ parser.add_argument(
+ "--right_context",
+ type=int,
+ default=0,
+ help="Number of frames in right context of the chunk AFTER subsampling.",
+ )
+ parser.add_argument(
+ "--display_partial_hypotheses",
+ type=bool,
+ default=False,
+ help="Whether to display partial hypotheses during chunk-by-chunk inference.",
)
return parser
@@ -923,24 +723,15 @@
def main(cmd=None):
print(get_commandline_args(), file=sys.stderr)
+
parser = get_parser()
args = parser.parse_args(cmd)
- param_dict = {'hotword': args.hotword}
kwargs = vars(args)
+
kwargs.pop("config", None)
- kwargs['param_dict'] = param_dict
inference(**kwargs)
if __name__ == "__main__":
main()
- # from modelscope.pipelines import pipeline
- # from modelscope.utils.constant import Tasks
- #
- # inference_16k_pipline = pipeline(
- # task=Tasks.auto_speech_recognition,
- # model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
- #
- # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
- # print(rec_result)
diff --git a/funasr/bin/asr_train_transducer.py b/funasr/bin/asr_train_transducer.py
new file mode 100755
index 0000000..fe418db
--- /dev/null
+++ b/funasr/bin/asr_train_transducer.py
@@ -0,0 +1,46 @@
+#!/usr/bin/env python3
+
+import os
+
+from funasr.tasks.asr import ASRTransducerTask
+
+
+# for ASR Training
+def parse_args():
+ parser = ASRTransducerTask.get_parser()
+ parser.add_argument(
+ "--gpu_id",
+ type=int,
+ default=0,
+ help="local gpu id.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main(args=None, cmd=None):
+ # for ASR Training
+ ASRTransducerTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ # setup local gpu_id
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+ # DDP settings
+ if args.ngpu > 1:
+ args.distributed = True
+ else:
+ args.distributed = False
+ assert args.num_worker_count == 1
+
+ # re-compute batch size: when dataset type is small
+ if args.dataset_type == "small":
+ if args.batch_size is not None:
+ args.batch_size = args.batch_size * args.ngpu
+ if args.batch_bins is not None:
+ args.batch_bins = args.batch_bins * args.ngpu
+
+ main(args=args)
diff --git a/funasr/models/decoder/rnnt_decoder.py b/funasr/models/decoder/rnnt_decoder.py
new file mode 100644
index 0000000..5401ab2
--- /dev/null
+++ b/funasr/models/decoder/rnnt_decoder.py
@@ -0,0 +1,258 @@
+"""RNN decoder definition for Transducer models."""
+
+from typing import List, Optional, Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.modules.beam_search.beam_search_transducer import Hypothesis
+from funasr.models.specaug.specaug import SpecAug
+
+class RNNTDecoder(torch.nn.Module):
+ """RNN decoder module.
+
+ Args:
+ vocab_size: Vocabulary size.
+ embed_size: Embedding size.
+ hidden_size: Hidden size..
+ rnn_type: Decoder layers type.
+ num_layers: Number of decoder layers.
+ dropout_rate: Dropout rate for decoder layers.
+ embed_dropout_rate: Dropout rate for embedding layer.
+ embed_pad: Embedding padding symbol ID.
+
+ """
+
    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        hidden_size: int = 256,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
    ) -> None:
        """Construct a RNNDecoder object."""
        super().__init__()

        # Runtime validation of argument types (typeguard).
        assert check_argument_types()

        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        # Label embedding; rows for embed_pad (blank) are zeroed by padding_idx.
        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU

        # One single-layer RNN per decoder layer, so rnn_forward() can slice
        # the stacked hidden state per layer.
        self.rnn = torch.nn.ModuleList(
            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
        )

        for _ in range(1, num_layers):
            self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]

        self.dropout_rnn = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
        )

        self.dlayers = num_layers
        self.dtype = rnn_type

        self.output_size = hidden_size
        self.vocab_size = vocab_size

        # Captured once at construction; call set_device() if the module is
        # moved to another device afterwards.
        self.device = next(self.parameters()).device
        # Memoization for score(): maps label-history string -> (out, state).
        self.score_cache = {}
+
+ def forward(
+ self,
+ labels: torch.Tensor,
+ label_lens: torch.Tensor,
+ states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
+ ) -> torch.Tensor:
+ """Encode source label sequences.
+
+ Args:
+ labels: Label ID sequences. (B, L)
+ states: Decoder hidden states.
+ ((N, B, D_dec), (N, B, D_dec) or None) or None
+
+ Returns:
+ dec_out: Decoder output sequences. (B, U, D_dec)
+
+ """
+ if states is None:
+ states = self.init_state(labels.size(0))
+
+ dec_embed = self.dropout_embed(self.embed(labels))
+ dec_out, states = self.rnn_forward(dec_embed, states)
+ return dec_out
+
    def rnn_forward(
        self,
        x: torch.Tensor,
        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Encode source label sequences.

        Args:
            x: RNN input sequences. (B, D_emb)
            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        Returns:
            x: RNN output sequences. (B, D_dec)
            (h_next, c_next): Decoder hidden states.
                (N, B, D_dec), (N, B, D_dec) or None)

        """
        h_prev, c_prev = state
        # Fresh zero tensors that the per-layer next states are written into.
        # For GRU, c_next is None and is never indexed below.
        h_next, c_next = self.init_state(x.size(0))

        for layer in range(self.dlayers):
            if self.dtype == "lstm":
                # Each self.rnn[layer] is a single-layer module, so its state
                # is the [layer:layer+1] slice of the stacked (N, B, D) state;
                # the tuple-assignment writes the new state into that slice.
                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
                    layer
                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
            else:
                x, h_next[layer : layer + 1] = self.rnn[layer](
                    x, hx=h_prev[layer : layer + 1]
                )

            x = self.dropout_rnn[layer](x)

        return x, (h_next, c_next)
+
+ def score(
+ self,
+ label: torch.Tensor,
+ label_sequence: List[int],
+ dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+ """One-step forward hypothesis.
+
+ Args:
+ label: Previous label. (1, 1)
+ label_sequence: Current label sequence.
+ dec_state: Previous decoder hidden states.
+ ((N, 1, D_dec), (N, 1, D_dec) or None)
+
+ Returns:
+ dec_out: Decoder output sequence. (1, D_dec)
+ dec_state: Decoder hidden states.
+ ((N, 1, D_dec), (N, 1, D_dec) or None)
+
+ """
+ str_labels = "_".join(map(str, label_sequence))
+
+ if str_labels in self.score_cache:
+ dec_out, dec_state = self.score_cache[str_labels]
+ else:
+ dec_embed = self.embed(label)
+ dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)
+
+ self.score_cache[str_labels] = (dec_out, dec_state)
+
+ return dec_out[0], dec_state
+
+ def batch_score(
+ self,
+ hyps: List[Hypothesis],
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+ """One-step forward hypotheses.
+
+ Args:
+ hyps: Hypotheses.
+
+ Returns:
+ dec_out: Decoder output sequences. (B, D_dec)
+ states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+
+ """
+ labels = torch.LongTensor([[h.yseq[-1]] for h in hyps], device=self.device)
+ dec_embed = self.embed(labels)
+
+ states = self.create_batch_states([h.dec_state for h in hyps])
+ dec_out, states = self.rnn_forward(dec_embed, states)
+
+ return dec_out.squeeze(1), states
+
    def set_device(self, device: torch.device) -> None:
        """Set GPU device to use.

        Args:
            device: Device ID.

        """
        # Only updates the device used for newly created states/label tensors;
        # it does not move the module parameters themselves.
        self.device = device
+
    def init_state(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Initialize decoder states.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.output_size,
            device=self.device,
        )

        # LSTM additionally carries a cell state; GRU uses None in its place.
        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.output_size,
                device=self.device,
            )

            return (h_n, c_n)

        return (h_n, None)
+
+ def select_state(
+ self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Get specified ID state from decoder hidden states.
+
+ Args:
+ states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+ idx: State ID to extract.
+
+ Returns:
+ : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)
+
+ """
+ return (
+ states[0][:, idx : idx + 1, :],
+ states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
+ )
+
+ def create_batch_states(
+ self,
+ new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Create decoder hidden states.
+
+ Args:
+ new_states: Decoder hidden states. [N x ((1, D_dec), (1, D_dec) or None)]
+
+ Returns:
+ states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
+
+ """
+ return (
+ torch.cat([s[0] for s in new_states], dim=1),
+ torch.cat([s[1] for s in new_states], dim=1)
+ if self.dtype == "lstm"
+ else None,
+ )
diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py
new file mode 100644
index 0000000..0cae306
--- /dev/null
+++ b/funasr/models/e2e_asr_transducer.py
@@ -0,0 +1,1013 @@
+"""ESPnet2 ASR Transducer model."""
+
+import logging
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from packaging.version import parse as V
+from typeguard import check_argument_types
+
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
+from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
+from funasr.models.joint_net.joint_network import JointNetwork
+from funasr.modules.nets_utils import get_transducer_task_io
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
if V(torch.__version__) >= V("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Fallback for torch < 1.6: a no-op context manager with the same call
    # signature, so `with autocast(False):` works on every supported version.
    @contextmanager
    def autocast(enabled=True):
        yield
+
+
+class TransducerModel(AbsESPnetModel):
+ """ESPnet2ASRTransducerModel module definition.
+
+ Args:
+ vocab_size: Size of complete vocabulary (w/ EOS and blank included).
+ token_list: List of token
+ frontend: Frontend module.
+ specaug: SpecAugment module.
+ normalize: Normalization module.
+ encoder: Encoder module.
+ decoder: Decoder module.
+ joint_network: Joint Network module.
+ transducer_weight: Weight of the Transducer loss.
+ fastemit_lambda: FastEmit lambda value.
+ auxiliary_ctc_weight: Weight of auxiliary CTC loss.
+ auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
+ auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
+ auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
+ ignore_id: Initial padding ID.
+ sym_space: Space symbol.
+ sym_blank: Blank Symbol
+ report_cer: Whether to report Character Error Rate during validation.
+ report_wer: Whether to report Word Error Rate during validation.
+ extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
+
+ """
+
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: Encoder,
        decoder: RNNTDecoder,
        joint_network: JointNetwork,
        att_decoder: Optional[AbsAttDecoder] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        ignore_id: int = -1,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        extract_feats_in_collect_stats: bool = True,
    ) -> None:
        """Construct an ESPnetASRTransducerModel object."""
        super().__init__()

        assert check_argument_types()

        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
        self.blank_id = 0
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        # Copy so later caller-side mutation cannot affect this model.
        self.token_list = token_list.copy()

        self.sym_space = sym_space
        self.sym_blank = sym_blank

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize

        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network

        # Both are created lazily on first use in _calc_transducer_loss().
        self.criterion_transducer = None
        self.error_calculator = None

        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0

        # Auxiliary projection heads exist only when their loss is enabled.
        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate

        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing

        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda

        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight

        self.report_cer = report_cer
        self.report_wer = report_wer

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Forward architecture and compute loss(es).

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            loss: Main loss value.
            stats: Task statistics.
            weight: Task weights.

        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        batch_size = speech.shape[0]
        # Trim label padding beyond the longest sequence (for data-parallel).
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )

        # 3. Decoder
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)

        # 4. Joint Network: broadcast encoder (time) against decoder (label)
        # axes to form the (B, T, U, V) lattice.
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )

        # 5. Losses
        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )

        # Auxiliary losses stay plain floats (0.0) when disabled, which the
        # `> 0.0` guards below rely on.
        loss_ctc, loss_lm = 0.0, 0.0

        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )

        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)

        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_lm_loss_weight * loss_lm
        )

        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        return loss, stats, weight
+
+ def collect_feats(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ """Collect features sequences and features lengths sequences.
+
+ Args:
+ speech: Speech sequences. (B, S)
+ speech_lengths: Speech sequences lengths. (B,)
+ text: Label ID sequences. (B, L)
+ text_lengths: Label ID sequences lengths. (B,)
+ kwargs: Contains "utts_id".
+
+ Return:
+ {}: "feats": Features sequences. (B, T, D_feats),
+ "feats_lengths": Features sequences lengths. (B,)
+
+ """
+ if self.extract_feats_in_collect_stats:
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ else:
+ # Generate dummy stats if extract_feats_in_collect_stats is False
+ logging.warning(
+ "Generating dummy stats for feats and feats_lengths, "
+ "because encoder_conf.extract_feats_in_collect_stats is "
+ f"{self.extract_feats_in_collect_stats}"
+ )
+
+ feats, feats_lengths = speech, speech_lengths
+
+ return {"feats": feats, "feats_lengths": feats_lengths}
+
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder speech sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)

        Return:
            encoder_out: Encoder outputs. (B, T, D_enc)
            encoder_out_lens: Encoder outputs lengths. (B,)

        """
        # Front-end processing is forced to fp32 even when AMP is active.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation (training mode only)
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # 4. Forward encoder
        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        return encoder_out, encoder_out_lens
+
+ def _extract_feats(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Extract features sequences and features sequences lengths.
+
+ Args:
+ speech: Speech sequences. (B, S)
+ speech_lengths: Speech sequences lengths. (B,)
+
+ Return:
+ feats: Features sequences. (B, T, D_feats)
+ feats_lengths: Features sequences lengths. (B,)
+
+ """
+ assert speech_lengths.dim() == 1, speech_lengths.shape
+
+ # for data-parallel
+ speech = speech[:, : speech_lengths.max()]
+
+ if self.frontend is not None:
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ feats, feats_lengths = speech, speech_lengths
+
+ return feats, feats_lengths
+
    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer.
            wer_transducer: Word Error Rate for Transducer.

        """
        if self.criterion_transducer is None:
            try:
                # warp-rnnt is an optional compiled dependency; import it
                # lazily so the model can be built without it installed.
                from warp_rnnt import rnnt_loss as RNNTLoss
                self.criterion_transducer = RNNTLoss

            except ImportError:
                # NOTE(review): exit(1) aborts the whole process here;
                # raising might be preferable, but is kept as-is.
                logging.error(
                    "warp-rnnt was not installed."
                    "Please consult the installation documentation."
                )
                exit(1)

        # warp_rnnt.rnnt_loss expects log-probabilities of the joint output.
        log_probs = torch.log_softmax(joint_out, dim=-1)

        loss_transducer = self.criterion_transducer(
            log_probs,
            target,
            t_len,
            u_len,
            reduction="mean",
            blank=self.blank_id,
            fastemit_lambda=self.fastemit_lambda,
            gather=True,
        )

        # CER/WER are only computed at validation time (requires decoding).
        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                from espnet2.asr_transducer.error_calculator import ErrorCalculator

                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )

            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)

            return loss_transducer, cer_transducer, wer_transducer

        return loss_transducer, None, None
+
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_ctc: CTC loss value.

        """
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        # ctc_loss expects (T, B, V) log-probabilities, hence the transpose.
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)

        # Flatten out padding; relies on padding/blank ID being 0.
        target_mask = target != 0
        ctc_target = target[target_mask].cpu()

        # Deterministic cuDNN kernels for a reproducible CTC backward pass.
        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
                ctc_target,
                t_len,
                u_len,
                zero_infinity=True,
                reduction="sum",
            )
        loss_ctc /= target.size(0)

        return loss_ctc
+
    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute LM loss (label-smoothed KL on decoder outputs).

        Args:
            decoder_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)

        Return:
            loss_lm: LM loss value.

        """
        # Drop the last decoder frame so predictions align with targets;
        # assumes U == L + 1 (blank-prepended decoder input) — TODO confirm.
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
        lm_target = target.view(-1).type(torch.int64)

        with torch.no_grad():
            # Build the label-smoothed target distribution in place.
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))

            # Ignore blank ID (0)
            ignore = lm_target == 0
            lm_target = lm_target.masked_fill(ignore, 0)

            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))

        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
            reduction="none",
        )
        # Zero out the ignored (blank) rows, then normalize by batch size.
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
            0
        )

        return loss_lm
+
+class UnifiedTransducerModel(AbsESPnetModel):
+ """ESPnet2ASRTransducerModel module definition.
+ Args:
+ vocab_size: Size of complete vocabulary (w/ EOS and blank included).
+ token_list: List of token
+ frontend: Frontend module.
+ specaug: SpecAugment module.
+ normalize: Normalization module.
+ encoder: Encoder module.
+ decoder: Decoder module.
+ joint_network: Joint Network module.
+ transducer_weight: Weight of the Transducer loss.
+ fastemit_lambda: FastEmit lambda value.
+ auxiliary_ctc_weight: Weight of auxiliary CTC loss.
+ auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
+ auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
+ auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
+ ignore_id: Initial padding ID.
+ sym_space: Space symbol.
+ sym_blank: Blank Symbol
+ report_cer: Whether to report Character Error Rate during validation.
+ report_wer: Whether to report Word Error Rate during validation.
+ extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
+ """
+
    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: Encoder,
        decoder: RNNTDecoder,
        joint_network: JointNetwork,
        att_decoder: Optional[AbsAttDecoder] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_att_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        ignore_id: int = -1,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        sym_sos: str = "<sos/eos>",
        sym_eos: str = "<sos/eos>",
        extract_feats_in_collect_stats: bool = True,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
    ) -> None:
        """Construct an ESPnetASRTransducerModel object."""
        super().__init__()

        assert check_argument_types()

        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
        self.blank_id = 0

        # sos/eos fall back to the last vocabulary entry when the symbol is
        # absent from the token list.
        if sym_sos in token_list:
            self.sos = token_list.index(sym_sos)
        else:
            self.sos = vocab_size - 1
        if sym_eos in token_list:
            self.eos = token_list.index(sym_eos)
        else:
            self.eos = vocab_size - 1

        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.token_list = token_list.copy()

        self.sym_space = sym_space
        self.sym_blank = sym_blank

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize

        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network

        # Created lazily on first use in _calc_transducer_loss().
        self.criterion_transducer = None
        self.error_calculator = None

        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_att = auxiliary_att_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0

        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate

        if self.use_auxiliary_att:
            self.att_decoder = att_decoder

            # NOTE(review): LabelSmoothingLoss is not imported anywhere in
            # this module — constructing with auxiliary_att_weight > 0 will
            # raise NameError. Confirm and add the missing import.
            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )

        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing

        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda

        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_att_weight = auxiliary_att_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight

        self.report_cer = report_cer
        self.report_wer = report_wer

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Forward architecture and compute loss(es).
+ Args:
+ speech: Speech sequences. (B, S)
+ speech_lengths: Speech sequences lengths. (B,)
+ text: Label ID sequences. (B, L)
+ text_lengths: Label ID sequences lengths. (B,)
+ kwargs: Contains "utts_id".
+ Return:
+ loss: Main loss value.
+ stats: Task statistics.
+ weight: Task weights.
+ """
+ assert text_lengths.dim() == 1, text_lengths.shape
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ == text.shape[0]
+ == text_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+
+ batch_size = speech.shape[0]
+ text = text[:, : text_lengths.max()]
+ #print(speech.shape)
+ # 1. Encoder
+ encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ loss_att, loss_att_chunk = 0.0, 0.0
+
+ if self.use_auxiliary_att:
+ loss_att, _ = self._calc_att_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+ loss_att_chunk, _ = self._calc_att_loss(
+ encoder_out_chunk, encoder_out_lens, text, text_lengths
+ )
+
+ # 2. Transducer-related I/O preparation
+ decoder_in, target, t_len, u_len = get_transducer_task_io(
+ text,
+ encoder_out_lens,
+ ignore_id=self.ignore_id,
+ )
+
+ # 3. Decoder
+ self.decoder.set_device(encoder_out.device)
+ decoder_out = self.decoder(decoder_in, u_len)
+
+ # 4. Joint Network
+ joint_out = self.joint_network(
+ encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
+ )
+
+ joint_out_chunk = self.joint_network(
+ encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1)
+ )
+
+ # 5. Losses
+ loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss(
+ encoder_out,
+ joint_out,
+ target,
+ t_len,
+ u_len,
+ )
+
+ loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss(
+ encoder_out_chunk,
+ joint_out_chunk,
+ target,
+ t_len,
+ u_len,
+ )
+
+ loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0
+
+ if self.use_auxiliary_ctc:
+ loss_ctc = self._calc_ctc_loss(
+ encoder_out,
+ target,
+ t_len,
+ u_len,
+ )
+ loss_ctc_chunk = self._calc_ctc_loss(
+ encoder_out_chunk,
+ target,
+ t_len,
+ u_len,
+ )
+
+ if self.use_auxiliary_lm_loss:
+ loss_lm = self._calc_lm_loss(decoder_out, target)
+
+ loss_trans = loss_trans_utt + loss_trans_chunk
+ loss_ctc = loss_ctc + loss_ctc_chunk
+ loss_ctc = loss_att + loss_att_chunk
+
+ loss = (
+ self.transducer_weight * loss_trans
+ + self.auxiliary_ctc_weight * loss_ctc
+ + self.auxiliary_att_weight * loss_att
+ + self.auxiliary_lm_loss_weight * loss_lm
+ )
+
+ stats = dict(
+ loss=loss.detach(),
+ loss_transducer=loss_trans_utt.detach(),
+ loss_transducer_chunk=loss_trans_chunk.detach(),
+ aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
+ aux_ctc_loss_chunk=loss_ctc_chunk.detach() if loss_ctc_chunk > 0.0 else None,
+ aux_att_loss=loss_att.detach() if loss_att > 0.0 else None,
+ aux_att_loss_chunk=loss_att_chunk.detach() if loss_att_chunk > 0.0 else None,
+ aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
+ cer_transducer=cer_trans,
+ wer_transducer=wer_trans,
+ cer_transducer_chunk=cer_trans_chunk,
+ wer_transducer_chunk=wer_trans_chunk,
+ )
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def collect_feats(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Dict[str, torch.Tensor]:
+ """Collect features sequences and features lengths sequences.
+ Args:
+ speech: Speech sequences. (B, S)
+ speech_lengths: Speech sequences lengths. (B,)
+ text: Label ID sequences. (B, L)
+ text_lengths: Label ID sequences lengths. (B,)
+ kwargs: Contains "utts_id".
+ Return:
+ {}: "feats": Features sequences. (B, T, D_feats),
+ "feats_lengths": Features sequences lengths. (B,)
+ """
+ if self.extract_feats_in_collect_stats:
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ else:
+ # Generate dummy stats if extract_feats_in_collect_stats is False
+ logging.warning(
+ "Generating dummy stats for feats and feats_lengths, "
+ "because encoder_conf.extract_feats_in_collect_stats is "
+ f"{self.extract_feats_in_collect_stats}"
+ )
+
+ feats, feats_lengths = speech, speech_lengths
+
+ return {"feats": feats, "feats_lengths": feats_lengths}
+
    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encoder speech sequences.
        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
        Return:
            encoder_out: Full-context encoder outputs. (B, T, D_enc)
            encoder_out_chunk: Chunk-wise (streaming) encoder outputs.
            encoder_out_lens: Encoder outputs lengths. (B,)
        """
        # Front-end processing is forced to fp32 even when AMP is active.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation (training mode only)
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # 4. Forward encoder; this encoder returns both a full-context and a
        # chunk-wise output sharing the same lengths.
        encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(feats, feats_lengths)

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        return encoder_out, encoder_out_chunk, encoder_out_lens
+
+ def _extract_feats(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Extract features sequences and features sequences lengths.
+ Args:
+ speech: Speech sequences. (B, S)
+ speech_lengths: Speech sequences lengths. (B,)
+ Return:
+ feats: Features sequences. (B, T, D_feats)
+ feats_lengths: Features sequences lengths. (B,)
+ """
+ assert speech_lengths.dim() == 1, speech_lengths.shape
+
+ # for data-parallel
+ speech = speech[:, : speech_lengths.max()]
+
+ if self.frontend is not None:
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ feats, feats_lengths = speech, speech_lengths
+
+ return feats, feats_lengths
+
+ def _calc_transducer_loss(
+ self,
+ encoder_out: torch.Tensor,
+ joint_out: torch.Tensor,
+ target: torch.Tensor,
+ t_len: torch.Tensor,
+ u_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
+ """Compute Transducer loss.
+ Args:
+ encoder_out: Encoder output sequences. (B, T, D_enc)
+ joint_out: Joint Network output sequences (B, T, U, D_joint)
+ target: Target label ID sequences. (B, L)
+ t_len: Encoder output sequences lengths. (B,)
+ u_len: Target label ID sequences lengths. (B,)
+ Return:
+ loss_transducer: Transducer loss value.
+ cer_transducer: Character error rate for Transducer.
+ wer_transducer: Word Error Rate for Transducer.
+ """
+ if self.criterion_transducer is None:
+ try:
+ # from warprnnt_pytorch import RNNTLoss
+ # self.criterion_transducer = RNNTLoss(
+ # reduction="mean",
+ # fastemit_lambda=self.fastemit_lambda,
+ # )
+ from warp_rnnt import rnnt_loss as RNNTLoss
+ self.criterion_transducer = RNNTLoss
+
+ except ImportError:
+ logging.error(
+ "warp-rnnt was not installed."
+ "Please consult the installation documentation."
+ )
+ exit(1)
+
+ # loss_transducer = self.criterion_transducer(
+ # joint_out,
+ # target,
+ # t_len,
+ # u_len,
+ # )
+ log_probs = torch.log_softmax(joint_out, dim=-1)
+
+ loss_transducer = self.criterion_transducer(
+ log_probs,
+ target,
+ t_len,
+ u_len,
+ reduction="mean",
+ blank=self.blank_id,
+ fastemit_lambda=self.fastemit_lambda,
+ gather=True,
+ )
+
+ if not self.training and (self.report_cer or self.report_wer):
+ if self.error_calculator is None:
+ self.error_calculator = ErrorCalculator(
+ self.decoder,
+ self.joint_network,
+ self.token_list,
+ self.sym_space,
+ self.sym_blank,
+ report_cer=self.report_cer,
+ report_wer=self.report_wer,
+ )
+
+ cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
+ return loss_transducer, cer_transducer, wer_transducer
+
+ return loss_transducer, None, None
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ target: torch.Tensor,
+ t_len: torch.Tensor,
+ u_len: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute CTC loss.
+ Args:
+ encoder_out: Encoder output sequences. (B, T, D_enc)
+ target: Target label ID sequences. (B, L)
+ t_len: Encoder output sequences lengths. (B,)
+ u_len: Target label ID sequences lengths. (B,)
+ Return:
+ loss_ctc: CTC loss value.
+ """
+ ctc_in = self.ctc_lin(
+ torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
+ )
+ ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)
+
+ target_mask = target != 0
+ ctc_target = target[target_mask].cpu()
+
+ with torch.backends.cudnn.flags(deterministic=True):
+ loss_ctc = torch.nn.functional.ctc_loss(
+ ctc_in,
+ ctc_target,
+ t_len,
+ u_len,
+ zero_infinity=True,
+ reduction="sum",
+ )
+ loss_ctc /= target.size(0)
+
+ return loss_ctc
+
+ def _calc_lm_loss(
+ self,
+ decoder_out: torch.Tensor,
+ target: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute LM loss.
+ Args:
+ decoder_out: Decoder output sequences. (B, U, D_dec)
+ target: Target label ID sequences. (B, L)
+ Return:
+ loss_lm: LM loss value.
+ """
+ lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
+ lm_target = target.view(-1).type(torch.int64)
+
+ with torch.no_grad():
+ true_dist = lm_loss_in.clone()
+ true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))
+
+ # Ignore blank ID (0)
+ ignore = lm_target == 0
+ lm_target = lm_target.masked_fill(ignore, 0)
+
+ true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))
+
+ loss_lm = torch.nn.functional.kl_div(
+ torch.log_softmax(lm_loss_in, dim=1),
+ true_dist,
+ reduction="none",
+ )
+ loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(
+ 0
+ )
+
+ return loss_lm
+
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Compute the auxiliary attention-decoder (CE) loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            encoder_out_lens: Encoder output lengths. (B,)
            ys_pad: Padded label ID sequences. (B, L)
            ys_pad_lens: Label sequence lengths. (B,)

        Returns:
            loss_att: Attention decoder loss.
            acc_att: Attention decoder accuracy.

        NOTE(review): add_sos_eos and th_accuracy are not imported in this
        module and would raise NameError when this path runs — confirm the
        missing imports.
        """
        # Optional language-token prefixing (lang_token_id is set elsewhere,
        # if at all).
        if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
            ys_pad = torch.cat(
                [
                    self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
                    ys_pad,
                ],
                dim=1,
            )
            ys_pad_lens += 1

        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.att_decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        return loss_att, acc_att
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 7c7f661..9777cee 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -8,6 +8,7 @@
from typing import Optional
from typing import Tuple
from typing import Union
+from typing import Dict
import torch
from torch import nn
@@ -18,6 +19,7 @@
from funasr.modules.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
+ RelPositionMultiHeadedAttentionChunk,
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
from funasr.modules.embedding import (
@@ -25,16 +27,23 @@
ScaledPositionalEncoding, # noqa: H301
RelPositionalEncoding, # noqa: H301
LegacyRelPositionalEncoding, # noqa: H301
+ StreamingRelPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import get_activation
from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.nets_utils import (
+ TooShortUttError,
+ check_short_utt,
+ make_chunk_mask,
+ make_source_mask,
+)
from funasr.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
-from funasr.modules.repeat import repeat
+from funasr.modules.repeat import repeat, MultiBlocks
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
@@ -42,6 +51,8 @@
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.subsampling import Conv2dSubsamplingPad
+from funasr.modules.subsampling import StreamingConvInput
+
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
@@ -275,6 +286,188 @@
return (x, pos_emb), mask
return x, mask
+
class ChunkEncoderLayer(torch.nn.Module):
    """Chunk Conformer module definition.

    Macaron-style Conformer block (FFN -> MHSA -> Conv -> FFN, each with a
    residual connection) that additionally supports chunk-by-chunk streaming
    inference through `chunk_forward` and a per-layer cache.

    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_args: Dict = {},
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a Conformer object."""
        super().__init__()

        self.self_att = self_att

        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        # Macaron-style half-step residual scaling applied to both FFNs.
        self.feed_forward_scale = 0.5

        self.conv_mod = conv_mod

        # One pre-norm per sub-module plus a final post-norm.
        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)

        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)

        self.dropout = torch.nn.Dropout(dropout_rate)

        self.block_size = block_size
        # Streaming cache, set by reset_streaming_cache():
        #   cache[0]: attention left-context frames (1, left_context, D)
        #   cache[1]: convolution left states (1, D, kernel_size - 1)
        self.cache = None

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.
        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        self.cache = [
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences (training / full-utterance path).
        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)
        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        # 1. Macaron feed-forward with half-step residual.
        residual = x

        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(
            self.feed_forward_macaron(x)
        )

        # 2. Multi-headed self-attention; chunk_mask restricts attention
        # to the chunked receptive field when chunked training is enabled.
        residual = x
        x = self.norm_self_att(x)
        x_q = x
        x = residual + self.dropout(
            self.self_att(
                x_q,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )

        # 3. Convolution module (cache output unused in the training path).
        residual = x

        x = self.norm_conv(x)
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)
        residual = x

        # 4. Second feed-forward, again with half-step residual.
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)
        return x, mask, pos_enc

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk of input sequence (streaming inference path).

        Same computation as `forward` but without dropout, consuming and
        updating `self.cache` so consecutive chunks see left context.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.
        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x

        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)

        residual = x
        x = self.norm_self_att(x)
        # Attention keys/values include the cached left-context frames.
        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key

        # Save the frames the next chunk will need, excluding any
        # look-ahead (right-context) frames at the tail.
        # NOTE(review): when left_context == 0, `key[:, -0:, :]` selects the
        # entire key — confirm callers always use left_context > 0 when the
        # cache is active.
        if right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        else:
            att_cache = key[:, -left_context:, :]
        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )

        residual = x
        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(
            x, cache=self.cache[1], right_context=right_context
        )
        x = residual + x
        residual = x

        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)

        x = self.norm_final(x)
        # Commit both caches only after the whole block succeeded.
        self.cache = [att_cache, conv_cache]

        return x, pos_enc
class ConformerEncoder(AbsEncoder):
@@ -604,3 +797,442 @@
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
+
+
class CausalConvolution(torch.nn.Module):
    """Conformer convolution module with optional causal (streaming) padding.

    Pipeline: pointwise conv (2x channel expansion) -> GLU -> depthwise
    conv -> norm + activation -> pointwise conv. With ``causal=True`` the
    depthwise convolution is padded on the left only, so no future frame is
    consumed — a requirement for chunk-by-chunk streaming.

    Args:
        channels: The number of channels.
        kernel_size: Size of the convolving kernel.
        activation: Type of activation function.
        norm_args: Normalization module arguments.
        causal: Whether to use causal convolution (set to True if streaming).
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        activation: torch.nn.Module = torch.nn.ReLU(),
        norm_args: Dict = {},
        causal: bool = False,
    ) -> None:
        """Construct a CausalConvolution object."""
        super().__init__()

        # Kernel must be odd so symmetric (non-causal) padding is exact.
        assert (kernel_size - 1) % 2 == 0

        self.kernel_size = kernel_size

        self.pointwise_conv1 = torch.nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        # Causal mode: pad the full receptive field manually on the left.
        # Non-causal mode: let Conv1d pad symmetrically on both sides.
        if causal:
            self.lorder, dw_padding = kernel_size - 1, 0
        else:
            self.lorder, dw_padding = 0, (kernel_size - 1) // 2

        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=dw_padding,
            groups=channels,
        )
        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.
        Args:
            x: Input sequences. (B, T, D_hidden)
            cache: Left-context input cache from the previous chunk.
                (1, D_hidden, kernel_size - 1), or None on the first chunk.
            right_context: Number of look-ahead frames in right context.
        Returns:
            out: Output sequences. (B, T, D_hidden)
            cache: Updated cache for the next chunk (unchanged input value
                when the module is non-causal).
        """
        # (B, T, D) -> (B, 2D, T); GLU halves the channels back to D.
        hidden = torch.nn.functional.glu(
            self.pointwise_conv1(x.transpose(1, 2)), dim=1
        )

        if self.lorder > 0:
            # Prepend left context: zeros on the first chunk, otherwise the
            # cached tail of the previous chunk.
            if cache is None:
                hidden = torch.nn.functional.pad(
                    hidden, (self.lorder, 0), "constant", 0.0
                )
            else:
                hidden = torch.cat([cache, hidden], dim=2)

            # Cache the frames the next chunk needs, skipping look-ahead
            # frames at the right edge.
            if right_context > 0:
                cache = hidden[:, :, -(self.lorder + right_context) : -right_context]
            else:
                cache = hidden[:, :, -self.lorder :]

        hidden = self.depthwise_conv(hidden)
        hidden = self.activation(self.norm(hidden))

        return self.pointwise_conv2(hidden).transpose(1, 2), cache
+
class ConformerChunkEncoder(AbsEncoder):
    """Streaming-capable Conformer encoder.

    Supports three training modes and one inference mode:
      * full-context (default): standard full-utterance encoding;
      * dynamic-chunk training: a random chunk size is drawn per batch;
      * unified training: one forward pass returns both a full-context and
        a chunked encoding of the same input;
      * chunk-by-chunk streaming inference via `chunk_forward`.

    Args:
        input_size: Input feature size.
        output_size: Encoder output (and internal block) size.
        attention_heads / linear_units / num_blocks: Block hyper-parameters.
        dynamic_chunk_training: Enable random-chunk training.
        unified_model_training: Enable joint full/chunked training.
        default_chunk_size / jitter_range: Chunk size (+- jitter) used in
            unified training.
        time_reduction_factor: Output frame subsampling applied after the
            encoder blocks.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        embed_vgg_like: bool = False,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        norm_type: str = "layer_norm",
        cnn_module_kernel: int = 31,
        conv_mod_norm_eps: float = 0.00001,
        conv_mod_norm_momentum: float = 0.1,
        simplified_att_score: bool = False,
        dynamic_chunk_training: bool = False,
        short_chunk_threshold: float = 0.75,
        short_chunk_size: int = 25,
        left_chunk_size: int = 0,
        time_reduction_factor: int = 1,
        unified_model_training: bool = False,
        default_chunk_size: int = 16,
        jitter_range: int = 4,
        subsampling_factor: int = 1,
    ) -> None:
        """Construct an Encoder object."""
        super().__init__()

        assert check_argument_types()

        # NOTE(review): output_size is supplied both as the second
        # positional argument and as the output_size keyword — confirm
        # StreamingConvInput's signature to make sure this is intended.
        self.embed = StreamingConvInput(
            input_size,
            output_size,
            subsampling_factor,
            vgg_like=embed_vgg_like,
            output_size=output_size,
        )

        self.pos_enc = StreamingRelPositionalEncoding(
            output_size,
            positional_dropout_rate,
        )

        activation = get_activation(
            activation_type
        )

        # NOTE(review): the feed-forward modules receive
        # positional_dropout_rate rather than dropout_rate — verify this
        # is intentional.
        pos_wise_args = (
            output_size,
            linear_units,
            positional_dropout_rate,
            activation,
        )

        conv_mod_norm_args = {
            "eps": conv_mod_norm_eps,
            "momentum": conv_mod_norm_momentum,
        }

        # Convolution is causal whenever any chunked training mode is on.
        conv_mod_args = (
            output_size,
            cnn_module_kernel,
            activation,
            conv_mod_norm_args,
            dynamic_chunk_training or unified_model_training,
        )

        mult_att_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            simplified_att_score,
        )


        # Each lambda builds a fresh layer (fresh parameters per block).
        fn_modules = []
        for _ in range(num_blocks):
            module = lambda: ChunkEncoderLayer(
                output_size,
                RelPositionMultiHeadedAttentionChunk(*mult_att_args),
                PositionwiseFeedForward(*pos_wise_args),
                PositionwiseFeedForward(*pos_wise_args),
                CausalConvolution(*conv_mod_args),
                dropout_rate=dropout_rate,
            )
            fn_modules.append(module)

        self.encoders = MultiBlocks(
            [fn() for fn in fn_modules],
            output_size,
        )

        self._output_size = output_size

        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
        self.short_chunk_size = short_chunk_size
        self.left_chunk_size = left_chunk_size

        self.unified_model_training = unified_model_training
        self.default_chunk_size = default_chunk_size
        self.jitter_range = jitter_range

        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the corresponding number of sample for a given chunk size, in frames.
        Where size is the number of features frames after applying subsampling.
        Args:
            size: Number of frames after subsampling.
            hop_length: Frontend's hop length
        Returns:
            : Number of raw samples
        """
        return self.embed.get_size_before_subsampling(size) * hop_length

    def get_encoder_input_size(self, size: int) -> int:
        """Return the number of feature frames (before subsampling) needed
        to produce `size` frames after subsampling.
        Args:
            size: Number of frames after subsampling.
        Returns:
            : Number of frames before subsampling.
        """
        return self.embed.get_size_before_subsampling(size)


    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.
        Args:
            left_context: Number of frames in left context.
            device: Device ID.
        """
        return self.encoders.reset_streaming_cache(left_context, device)

    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        NOTE(review): in unified training mode this returns a 3-tuple
        (x_utt, x_chunk, olens); otherwise a 2-tuple (x, olens). Callers
        must be aware of the mode-dependent arity.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
            x_len: Encoder outputs lengths. (B,)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)

        if self.unified_model_training:
            # Jitter the chunk size uniformly in [default - jitter, default + jitter].
            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
            # Two passes over the same embedded input: full context and chunked.
            x_utt = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=None,
            )
            x_chunk = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )

            # Lengths = number of unmasked frames (mask marks padding).
            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                x_utt = x_utt[:,::self.time_reduction_factor,:]
                x_chunk = x_chunk[:,::self.time_reduction_factor,:]
                olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1

            return x_utt, x_chunk, olens

        elif self.dynamic_chunk_training:
            max_len = x.size(1)
            chunk_size = torch.randint(1, max_len, (1,)).item()

            # With probability ~(1 - threshold) train full-context, else a
            # short chunk in [1, short_chunk_size].
            if chunk_size > (max_len * self.short_chunk_threshold):
                chunk_size = max_len
            else:
                chunk_size = (chunk_size % self.short_chunk_size) + 1

            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)

            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
        else:
            # Full-context path: no chunk restriction.
            x, mask = self.embed(x, mask, None)
            pos_enc = self.pos_enc(x)
            chunk_mask = None
        # Shared encoder pass for the dynamic-chunk and full-context branches.
        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]
            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1

        return x, olens

    def simu_chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Simulate streaming by encoding the full utterance with a chunk mask.
        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
            chunk_size: Number of (subsampled) frames per chunk.
            left_context: Number of frames in left context (unused here;
                chunking is controlled by self.left_chunk_size).
            right_context: Number of frames in right context (unused here).
        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)

        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)
        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )

        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )
        # NOTE(review): olens is computed but never used or returned here.
        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]

        return x

    def chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        processed_frames: torch.tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode input sequences as chunks (true streaming inference).
        Args:
            x: Encoder input features. (1, T_in, F)
            x_len: Encoder input features lengths. (1,)
            processed_frames: Number of frames already seen.
            chunk_size: Number of (subsampled) frames per chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.
        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        mask = make_source_mask(x_len)
        x, mask = self.embed(x, mask, None)

        if left_context > 0:
            # Mask out left-context slots that precede the frames actually
            # processed so far (early chunks have a partially empty cache).
            processed_mask = (
                torch.arange(left_context, device=x.device)
                .view(1, left_context)
                .flip(1)
            )
            processed_mask = processed_mask >= processed_frames
            mask = torch.cat([processed_mask, mask], dim=1)
        pos_enc = self.pos_enc(x, left_context=left_context)
        x = self.encoders.chunk_forward(
            x,
            pos_enc,
            mask,
            chunk_size=chunk_size,
            left_context=left_context,
            right_context=right_context,
        )

        # Drop look-ahead frames: they belong to the next chunk's output.
        if right_context > 0:
            x = x[:, 0:-right_context, :]

        if self.time_reduction_factor > 1:
            x = x[:,::self.time_reduction_factor,:]
        return x
diff --git a/funasr/models/joint_net/joint_network.py b/funasr/models/joint_net/joint_network.py
new file mode 100644
index 0000000..ed827c4
--- /dev/null
+++ b/funasr/models/joint_net/joint_network.py
@@ -0,0 +1,61 @@
+"""Transducer joint network implementation."""
+
+import torch
+
+from funasr.modules.nets_utils import get_activation
+
+
class JointNetwork(torch.nn.Module):
    """Transducer joint network module.

    Combines encoder and decoder representations by projecting both into a
    shared joint space, applying a non-linearity, and projecting onto the
    output vocabulary.

    Args:
        output_size: Output size (vocabulary size).
        encoder_size: Encoder output size.
        decoder_size: Decoder output size.
        joint_space_size: Joint space size.
        joint_activation_type: Type of activation for joint network.

    """

    def __init__(
        self,
        output_size: int,
        encoder_size: int,
        decoder_size: int,
        joint_space_size: int = 256,
        joint_activation_type: str = "tanh",
    ) -> None:
        """Construct a JointNetwork object."""
        super().__init__()

        # Projections into the joint space; only one of the two carries a
        # bias, since the biases would simply add together.
        self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size)
        self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False)

        self.lin_out = torch.nn.Linear(joint_space_size, output_size)

        self.joint_activation = get_activation(
            joint_activation_type
        )

    def forward(
        self,
        enc_out: torch.Tensor,
        dec_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        """Joint computation of encoder and decoder hidden state sequences.

        Args:
            enc_out: Expanded encoder output state sequences. (B, T, 1, D_enc)
            dec_out: Expanded decoder output state sequences. (B, 1, U, D_dec)
            project_input: Whether the inputs still need projection into the
                joint space; pass False when the caller pre-projected them.

        Returns:
            joint_out: Joint output state sequences. (B, T, U, D_out)

        """
        if project_input:
            joint_in = self.lin_enc(enc_out) + self.lin_dec(dec_out)
        else:
            joint_in = enc_out + dec_out

        return self.lin_out(self.joint_activation(joint_in))
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index 31d5a87..6202079 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -11,7 +11,7 @@
import numpy
import torch
from torch import nn
-
+from typing import Optional, Tuple
class MultiHeadedAttention(nn.Module):
"""Multi-Head Attention layer.
@@ -741,3 +741,221 @@
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
return att_outs
+
class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
    """Multi-headed attention with relative positional encoding (chunk version).

    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size.
        dropout_rate: Dropout rate.
        simplified_attention_score: Use the simplified positional score
            (https://github.com/k2-fsa/icefall/pull/458) instead of the
            Transformer-XL style score with learned position biases.
    """

    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct an MultiHeadedAttention object."""
        super().__init__()

        self.d_k = embed_size // num_heads
        self.num_heads = num_heads

        # Fix: the original message was an unformatted tuple whose %d
        # placeholders were never interpolated; build a real message.
        assert self.d_k * num_heads == embed_size, (
            f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})"
        )

        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)

        self.linear_out = torch.nn.Linear(embed_size, embed_size)

        if simplified_attention_score:
            # One scalar positional score per head.
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)

            self.compute_att_score = self.compute_simplified_attention_score
        else:
            # Transformer-XL style: project positions and add learned
            # content/position biases to the query.
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)

            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)

            self.compute_att_score = self.compute_attention_score

        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # Last attention weights, kept for inspection/visualization.
        self.attn = None

    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute relative positional encoding.

        Re-indexes the (T_1, 2 * T_1 - 1) relative-position axis onto
        (T_1, T_2) absolute key positions via a zero-copy strided view.

        Args:
            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
            left_context: Number of frames in left context.
        Returns:
            x: Output sequence. (B, H, T_1, T_2)
        """
        batch_size, n_heads, time1, n = x.shape
        time2 = time1 + left_context

        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()

        # Shrinking the time1 stride by one element shifts each row left by
        # its row index; the offset starts row 0 at relative position 0.
        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=(n_stride * (time1 - 1)),
        )

    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.
        Reference: https://github.com/k2-fsa/icefall/pull/458
        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        # (B, 2*T_1-1, H): one positional score per head and relative offset.
        pos_enc = self.linear_pos(pos_enc)

        # Content-content term.
        matrix_ac = torch.matmul(query, key.transpose(2, 3))

        # Broadcast the per-offset scores over queries, then re-index to
        # absolute key positions.
        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation (Transformer-XL style).
        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        # (B, 2*T_1-1, H, d_k): per-head projection of relative positions.
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)

        # Add learned biases to the query in (B, T_1, H, d_k) layout.
        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)

        # Content-content term (terms (a) + (c) in Transformer-XL).
        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))

        # Content-position term (terms (b) + (d)), re-indexed to absolute positions.
        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.
        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)
        """
        n_batch = query.size(0)

        q = (
            self.linear_q(query)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(key)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.linear_v(value)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.
        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
        Returns:
            attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
        """
        batch_size = scores.size(0)
        # Broadcast the source mask over heads and queries; merge in the
        # chunk mask so masked positions get -inf before the softmax.
        mask = mask.unsqueeze(1).unsqueeze(2)
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        scores = scores.masked_fill(mask, float("-inf"))
        # Second masked_fill zeroes rows that were fully masked (softmax of
        # all -inf yields NaN otherwise).
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)

        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)

        # (B, H, T_1, d_k) -> (B, T_1, H * d_k) -> output projection.
        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )

        return attn_output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.
        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_1)
            left_context: Number of frames in left context.
        Returns:
            : Output tensor. (B, T_1, H * d_k)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
diff --git a/funasr/modules/beam_search/beam_search_transducer.py b/funasr/modules/beam_search/beam_search_transducer.py
new file mode 100644
index 0000000..3eb8e08
--- /dev/null
+++ b/funasr/modules/beam_search/beam_search_transducer.py
@@ -0,0 +1,704 @@
+"""Search algorithms for Transducer models."""
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from funasr.models.joint_net.joint_network import JointNetwork
+
+
+@dataclass
+class Hypothesis:
+    """Default hypothesis definition for Transducer search algorithms.
+
+    Args:
+        score: Total log-probability.
+        yseq: Label sequence as integer ID sequence.
+        dec_state: RNNDecoder or StatelessDecoder state.
+                     ((N, 1, D_dec), (N, 1, D_dec) or None) or None
+        lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None
+
+    """
+
+    score: float  # accumulated log-probability of yseq
+    yseq: List[int]  # emitted label IDs; index 0 holds the initial blank symbol
+    dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None
+    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None
+
+
+@dataclass
+class ExtendedHypothesis(Hypothesis):
+    """Extended hypothesis definition for NSC beam search and mAES.
+
+    Args:
+        : Hypothesis dataclass arguments.
+        dec_out: Decoder output sequence. (B, D_dec)
+        lm_score: Log-probabilities of the LM for given label. (vocab_size)
+
+    """
+
+    # Both stay None until the decoder / LM have scored this hypothesis.
+    dec_out: Optional[torch.Tensor] = None
+    lm_score: Optional[torch.Tensor] = None
+
+
+class BeamSearchTransducer:
+    """Beam search implementation for Transducer.
+
+    Args:
+        decoder: Decoder module.
+        joint_network: Joint network module.
+        beam_size: Size of the beam.
+        lm: LM class.
+        lm_weight: LM weight for soft fusion.
+        search_type: Search algorithm to use during inference.
+        max_sym_exp: Number of maximum symbol expansions at each time step. (TSD)
+        u_max: Maximum expected target sequence length. (ALSD)
+        nstep: Number of maximum expansion steps at each time step. (mAES)
+        expansion_gamma: Allowed logp difference for prune-by-value method. (mAES)
+        expansion_beta:
+            Number of additional candidates for expanded hypotheses selection. (mAES)
+        score_norm: Normalize final scores by length.
+        nbest: Number of final hypothesis.
+        streaming: Whether to perform chunk-by-chunk beam search.
+
+    """
+
+    def __init__(
+        self,
+        decoder,
+        joint_network: JointNetwork,
+        beam_size: int,
+        lm: Optional[torch.nn.Module] = None,
+        lm_weight: float = 0.1,
+        search_type: str = "default",
+        max_sym_exp: int = 3,
+        u_max: int = 50,
+        nstep: int = 2,
+        expansion_gamma: float = 2.3,
+        expansion_beta: int = 2,
+        score_norm: bool = False,
+        nbest: int = 1,
+        streaming: bool = False,
+    ) -> None:
+        """Construct a BeamSearchTransducer object."""
+        super().__init__()
+
+        self.decoder = decoder
+        self.joint_network = joint_network
+
+        self.vocab_size = decoder.vocab_size
+
+        assert beam_size <= self.vocab_size, (
+            "beam_size (%d) should be smaller than or equal to vocabulary size (%d)."
+            % (
+                beam_size,
+                self.vocab_size,
+            )
+        )
+        self.beam_size = beam_size
+
+        # Bind the search algorithm once at construction time; each branch
+        # validates its own hyper-parameters.
+        if search_type == "default":
+            self.search_algorithm = self.default_beam_search
+        elif search_type == "tsd":
+            assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % (
+                max_sym_exp
+            )
+            self.max_sym_exp = max_sym_exp
+
+            self.search_algorithm = self.time_sync_decoding
+        elif search_type == "alsd":
+            assert not streaming, "ALSD is not available in streaming mode."
+
+            assert u_max >= 0, "u_max should be a positive integer, a portion of max_T."
+            self.u_max = u_max
+
+            self.search_algorithm = self.align_length_sync_decoding
+        elif search_type == "maes":
+            assert self.vocab_size >= beam_size + expansion_beta, (
+                "beam_size (%d) + expansion_beta (%d) "
+                " should be smaller than or equal to vocab size (%d)."
+                % (beam_size, expansion_beta, self.vocab_size)
+            )
+            self.max_candidates = beam_size + expansion_beta
+
+            self.nstep = nstep
+            self.expansion_gamma = expansion_gamma
+
+            self.search_algorithm = self.modified_adaptive_expansion_search
+        else:
+            raise NotImplementedError(
+                "Specified search type (%s) is not supported." % search_type
+            )
+
+        # Shallow fusion relies on recurrent LM state handling.
+        self.use_lm = lm is not None
+
+        if self.use_lm:
+            assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported."
+
+            # NOTE(review): assumes <sos> occupies the last vocabulary index —
+            # confirm against the token list used for the LM.
+            self.sos = self.vocab_size - 1
+
+        self.lm = lm
+        self.lm_weight = lm_weight
+
+        self.score_norm = score_norm
+        self.nbest = nbest
+
+        # Start with empty decoder/search caches (also used for streaming).
+        self.reset_inference_cache()
+
+    def __call__(
+        self,
+        enc_out: torch.Tensor,
+        is_final: bool = True,
+    ) -> List[Hypothesis]:
+        """Perform beam search.
+
+        Args:
+            enc_out: Encoder output sequence. (T, D_enc)
+            is_final: Whether enc_out is the final chunk of data.
+
+        Returns:
+            nbest_hyps: N-best decoding results
+
+        """
+        self.decoder.set_device(enc_out.device)
+
+        hyps = self.search_algorithm(enc_out)
+
+        if is_final:
+            self.reset_inference_cache()
+
+            return self.sort_nbest(hyps)
+
+        # Streaming path: keep the live hypotheses so that the next chunk
+        # resumes the search from them instead of restarting.
+        self.search_cache = hyps
+
+        return hyps
+
+    def reset_inference_cache(self) -> None:
+        """Reset cache for decoder scoring and streaming."""
+        # Clearing both caches makes the next __call__ start a fresh utterance.
+        self.decoder.score_cache = {}
+        self.search_cache = None
+
+    def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
+        """Sort in-place hypotheses by score or score given sequence length.
+
+        Args:
+            hyps: Hypothesis.
+
+        Return:
+            hyps: Sorted hypothesis.
+
+        """
+        # Length normalization divides by the label count, which includes the
+        # initial blank symbol at yseq[0].
+        if self.score_norm:
+            hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True)
+        else:
+            hyps.sort(key=lambda x: x.score, reverse=True)
+
+        return hyps[: self.nbest]
+
+    def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
+        """Recombine hypotheses with same label ID sequence.
+
+        Args:
+            hyps: Hypotheses.
+
+        Returns:
+            final: Recombined hypotheses.
+
+        """
+        final = {}
+
+        for hyp in hyps:
+            str_yseq = "_".join(map(str, hyp.yseq))
+
+            if str_yseq in final:
+                # Duplicate label sequence: merge probabilities in log space.
+                # The first-seen hypothesis object (and its states) is kept.
+                final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score)
+            else:
+                final[str_yseq] = hyp
+
+        return [*final.values()]
+
+    def select_k_expansions(
+        self,
+        hyps: List[ExtendedHypothesis],
+        topk_idx: torch.Tensor,
+        topk_logp: torch.Tensor,
+    ) -> List[ExtendedHypothesis]:
+        """Return K hypotheses candidates for expansion from a list of hypothesis.
+
+        K candidates are selected according to the extended hypotheses probabilities
+        and a prune-by-value method. Where K is equal to beam_size + beta.
+
+        Args:
+            hyps: Hypotheses.
+            topk_idx: Indices of candidates hypothesis.
+            topk_logp: Log-probabilities of candidates hypothesis.
+
+        Returns:
+            k_expansions: Best K expansion hypotheses candidates.
+
+        """
+        k_expansions = []
+
+        for i, hyp in enumerate(hyps):
+            # Candidate (label ID, total score) pairs for hypothesis i.
+            hyp_i = [
+                (int(k), hyp.score + float(v))
+                for k, v in zip(topk_idx[i], topk_logp[i])
+            ]
+            k_best_exp = max(hyp_i, key=lambda x: x[1])[1]
+
+            # Prune-by-value: keep only candidates whose score is within
+            # expansion_gamma of the best expansion, best first.
+            k_expansions.append(
+                sorted(
+                    filter(
+                        lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i
+                    ),
+                    key=lambda x: x[1],
+                    reverse=True,
+                )
+            )
+
+        return k_expansions
+
+ def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor:
+ """Make batch of inputs with left padding for LM scoring.
+
+ Args:
+ hyps_seq: Hypothesis sequences.
+
+ Returns:
+ : Padded batch of sequences.
+
+ """
+ max_len = max([len(h) for h in hyps_seq])
+
+ return torch.LongTensor(
+ [[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq],
+ device=self.decoder.device,
+ )
+
+ def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]:
+ """Beam search implementation without prefix search.
+
+ Modified from https://arxiv.org/pdf/1211.3711.pdf
+
+ Args:
+ enc_out: Encoder output sequence. (T, D)
+
+ Returns:
+ nbest_hyps: N-best hypothesis.
+
+ """
+ beam_k = min(self.beam_size, (self.vocab_size - 1))
+ max_t = len(enc_out)
+
+ if self.search_cache is not None:
+ kept_hyps = self.search_cache
+ else:
+ kept_hyps = [
+ Hypothesis(
+ score=0.0,
+ yseq=[0],
+ dec_state=self.decoder.init_state(1),
+ )
+ ]
+
+ for t in range(max_t):
+ hyps = kept_hyps
+ kept_hyps = []
+
+ while True:
+ max_hyp = max(hyps, key=lambda x: x.score)
+ hyps.remove(max_hyp)
+
+ label = torch.full(
+ (1, 1),
+ max_hyp.yseq[-1],
+ dtype=torch.long,
+ device=self.decoder.device,
+ )
+ dec_out, state = self.decoder.score(
+ label,
+ max_hyp.yseq,
+ max_hyp.dec_state,
+ )
+
+ logp = torch.log_softmax(
+ self.joint_network(enc_out[t : t + 1, :], dec_out),
+ dim=-1,
+ ).squeeze(0)
+ top_k = logp[1:].topk(beam_k, dim=-1)
+
+ kept_hyps.append(
+ Hypothesis(
+ score=(max_hyp.score + float(logp[0:1])),
+ yseq=max_hyp.yseq,
+ dec_state=max_hyp.dec_state,
+ lm_state=max_hyp.lm_state,
+ )
+ )
+
+ if self.use_lm:
+ lm_scores, lm_state = self.lm.score(
+ torch.LongTensor(
+ [self.sos] + max_hyp.yseq[1:], device=self.decoder.device
+ ),
+ max_hyp.lm_state,
+ None,
+ )
+ else:
+ lm_state = max_hyp.lm_state
+
+ for logp, k in zip(*top_k):
+ score = max_hyp.score + float(logp)
+
+ if self.use_lm:
+ score += self.lm_weight * lm_scores[k + 1]
+
+ hyps.append(
+ Hypothesis(
+ score=score,
+ yseq=max_hyp.yseq + [int(k + 1)],
+ dec_state=state,
+ lm_state=lm_state,
+ )
+ )
+
+ hyps_max = float(max(hyps, key=lambda x: x.score).score)
+ kept_most_prob = sorted(
+ [hyp for hyp in kept_hyps if hyp.score > hyps_max],
+ key=lambda x: x.score,
+ )
+ if len(kept_most_prob) >= self.beam_size:
+ kept_hyps = kept_most_prob
+ break
+
+ return kept_hyps
+
+    def align_length_sync_decoding(
+        self,
+        enc_out: torch.Tensor,
+    ) -> List[Hypothesis]:
+        """Alignment-length synchronous beam search implementation.
+
+        Based on https://ieeexplore.ieee.org/document/9053040
+
+        Args:
+            enc_out: Encoder output sequences. (T, D)
+
+        Returns:
+            nbest_hyps: N-best hypothesis.
+
+        """
+        t_max = int(enc_out.size(0))
+        u_max = min(self.u_max, (t_max - 1))
+
+        B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))]
+        final = []
+
+        if self.use_lm:
+            B[0].lm_state = self.lm.zero_state()
+
+        # i indexes the anti-diagonal t + u of the (time, label) lattice.
+        for i in range(t_max + u_max):
+            A = []
+
+            B_ = []
+            B_enc_out = []
+            for hyp in B:
+                u = len(hyp.yseq) - 1
+                t = i - u
+
+                if t > (t_max - 1):
+                    continue
+
+                B_.append(hyp)
+                B_enc_out.append((t, enc_out[t]))
+
+            if B_:
+                beam_enc_out = torch.stack([b[1] for b in B_enc_out])
+                beam_dec_out, beam_state = self.decoder.batch_score(B_)
+
+                beam_logp = torch.log_softmax(
+                    self.joint_network(beam_enc_out, beam_dec_out),
+                    dim=-1,
+                )
+                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)
+
+                if self.use_lm:
+                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
+                        self.create_lm_batch_inputs([b.yseq for b in B_]),
+                        [b.lm_state for b in B_],
+                        None,
+                    )
+
+                # NOTE(review): this enumerate reuses the name "i", shadowing
+                # the outer lattice index; harmless (the outer loop reassigns
+                # from range) but confusing to readers.
+                for i, hyp in enumerate(B_):
+                    new_hyp = Hypothesis(
+                        score=(hyp.score + float(beam_logp[i, 0])),
+                        yseq=hyp.yseq[:],
+                        dec_state=hyp.dec_state,
+                        lm_state=hyp.lm_state,
+                    )
+
+                    A.append(new_hyp)
+
+                    # Blank taken at the last frame finalizes the hypothesis.
+                    if B_enc_out[i][0] == (t_max - 1):
+                        final.append(new_hyp)
+
+                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
+                        new_hyp = Hypothesis(
+                            score=(hyp.score + float(logp)),
+                            yseq=(hyp.yseq[:] + [int(k)]),
+                            dec_state=self.decoder.select_state(beam_state, i),
+                            lm_state=hyp.lm_state,
+                        )
+
+                        if self.use_lm:
+                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
+                            new_hyp.lm_state = beam_lm_states[i]
+
+                        A.append(new_hyp)
+
+                B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
+                B = self.recombine_hyps(B)
+
+        if final:
+            return final
+
+        return B
+
+    def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
+        """Time synchronous beam search implementation.
+
+        Based on https://ieeexplore.ieee.org/document/9053040
+
+        Args:
+            enc_out: Encoder output sequence. (T, D)
+
+        Returns:
+            nbest_hyps: N-best hypothesis.
+
+        """
+        if self.search_cache is not None:
+            B = self.search_cache
+        else:
+            B = [
+                Hypothesis(
+                    yseq=[0],
+                    score=0.0,
+                    dec_state=self.decoder.init_state(1),
+                )
+            ]
+
+            if self.use_lm:
+                B[0].lm_state = self.lm.zero_state()
+
+        for enc_out_t in enc_out:
+            # A: hypotheses accumulated for this frame,
+            # C: hypotheses being expanded at the current expansion step,
+            # D: expansions produced for the next step.
+            A = []
+            C = B
+
+            enc_out_t = enc_out_t.unsqueeze(0)
+
+            for v in range(self.max_sym_exp):
+                D = []
+
+                beam_dec_out, beam_state = self.decoder.batch_score(C)
+
+                beam_logp = torch.log_softmax(
+                    self.joint_network(enc_out_t, beam_dec_out),
+                    dim=-1,
+                )
+                beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)
+
+                seq_A = [h.yseq for h in A]
+
+                for i, hyp in enumerate(C):
+                    if hyp.yseq not in seq_A:
+                        A.append(
+                            Hypothesis(
+                                score=(hyp.score + float(beam_logp[i, 0])),
+                                yseq=hyp.yseq[:],
+                                dec_state=hyp.dec_state,
+                                lm_state=hyp.lm_state,
+                            )
+                        )
+                    else:
+                        # Same label sequence already in A: merge log-probs.
+                        dict_pos = seq_A.index(hyp.yseq)
+
+                        A[dict_pos].score = np.logaddexp(
+                            A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
+                        )
+
+                if v < (self.max_sym_exp - 1):
+                    if self.use_lm:
+                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
+                            self.create_lm_batch_inputs([c.yseq for c in C]),
+                            [c.lm_state for c in C],
+                            None,
+                        )
+
+                    for i, hyp in enumerate(C):
+                        for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
+                            new_hyp = Hypothesis(
+                                score=(hyp.score + float(logp)),
+                                yseq=(hyp.yseq + [int(k)]),
+                                dec_state=self.decoder.select_state(beam_state, i),
+                                lm_state=hyp.lm_state,
+                            )
+
+                            if self.use_lm:
+                                new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
+                                new_hyp.lm_state = beam_lm_states[i]
+
+                            D.append(new_hyp)
+
+                    C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size]
+
+            B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]
+
+        return B
+
+    def modified_adaptive_expansion_search(
+        self,
+        enc_out: torch.Tensor,
+    ) -> List[ExtendedHypothesis]:
+        """Modified version of Adaptive Expansion Search (mAES).
+
+        Based on AES (https://ieeexplore.ieee.org/document/9250505) and
+        NSC (https://arxiv.org/abs/2201.05420).
+
+        Args:
+            enc_out: Encoder output sequence. (T, D_enc)
+
+        Returns:
+            nbest_hyps: N-best hypothesis.
+
+        """
+        if self.search_cache is not None:
+            kept_hyps = self.search_cache
+        else:
+            # Score the initial blank-only hypothesis once so that dec_out /
+            # lm_score are populated before the first expansion step.
+            init_tokens = [
+                ExtendedHypothesis(
+                    yseq=[0],
+                    score=0.0,
+                    dec_state=self.decoder.init_state(1),
+                )
+            ]
+
+            beam_dec_out, beam_state = self.decoder.batch_score(
+                init_tokens,
+            )
+
+            if self.use_lm:
+                beam_lm_scores, beam_lm_states = self.lm.batch_score(
+                    self.create_lm_batch_inputs([h.yseq for h in init_tokens]),
+                    [h.lm_state for h in init_tokens],
+                    None,
+                )
+
+                lm_state = beam_lm_states[0]
+                lm_score = beam_lm_scores[0]
+            else:
+                lm_state = None
+                lm_score = None
+
+            kept_hyps = [
+                ExtendedHypothesis(
+                    yseq=[0],
+                    score=0.0,
+                    dec_state=self.decoder.select_state(beam_state, 0),
+                    dec_out=beam_dec_out[0],
+                    lm_state=lm_state,
+                    lm_score=lm_score,
+                )
+            ]
+
+        for enc_out_t in enc_out:
+            hyps = kept_hyps
+            kept_hyps = []
+
+            beam_enc_out = enc_out_t.unsqueeze(0)
+
+            # list_b collects blank emissions (hypotheses done with this frame).
+            list_b = []
+            for n in range(self.nstep):
+                beam_dec_out = torch.stack([h.dec_out for h in hyps])
+
+                beam_logp, beam_idx = torch.log_softmax(
+                    self.joint_network(beam_enc_out, beam_dec_out),
+                    dim=-1,
+                ).topk(self.max_candidates, dim=-1)
+
+                k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp)
+
+                list_exp = []
+                for i, hyp in enumerate(hyps):
+                    for k, new_score in k_expansions[i]:
+                        new_hyp = ExtendedHypothesis(
+                            yseq=hyp.yseq[:],
+                            score=new_score,
+                            dec_out=hyp.dec_out,
+                            dec_state=hyp.dec_state,
+                            lm_state=hyp.lm_state,
+                            lm_score=hyp.lm_score,
+                        )
+
+                        if k == 0:
+                            list_b.append(new_hyp)
+                        else:
+                            new_hyp.yseq.append(int(k))
+
+                            if self.use_lm:
+                                new_hyp.score += self.lm_weight * float(hyp.lm_score[k])
+
+                            list_exp.append(new_hyp)
+
+                if not list_exp:
+                    # No non-blank expansion left: finalize this frame early.
+                    kept_hyps = sorted(
+                        self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True
+                    )[: self.beam_size]
+
+                    break
+                else:
+                    beam_dec_out, beam_state = self.decoder.batch_score(
+                        list_exp,
+                    )
+
+                    if self.use_lm:
+                        beam_lm_scores, beam_lm_states = self.lm.batch_score(
+                            self.create_lm_batch_inputs([h.yseq for h in list_exp]),
+                            [h.lm_state for h in list_exp],
+                            None,
+                        )
+
+                    if n < (self.nstep - 1):
+                        for i, hyp in enumerate(list_exp):
+                            hyp.dec_out = beam_dec_out[i]
+                            hyp.dec_state = self.decoder.select_state(beam_state, i)
+
+                            if self.use_lm:
+                                hyp.lm_state = beam_lm_states[i]
+                                hyp.lm_score = beam_lm_scores[i]
+
+                        hyps = list_exp[:]
+                    else:
+                        # Last expansion step: add the blank log-prob so the
+                        # expanded hypotheses are comparable with list_b.
+                        beam_logp = torch.log_softmax(
+                            self.joint_network(beam_enc_out, beam_dec_out),
+                            dim=-1,
+                        )
+
+                        for i, hyp in enumerate(list_exp):
+                            hyp.score += float(beam_logp[i, 0])
+
+                            hyp.dec_out = beam_dec_out[i]
+                            hyp.dec_state = self.decoder.select_state(beam_state, i)
+
+                            if self.use_lm:
+                                hyp.lm_state = beam_lm_states[i]
+                                hyp.lm_score = beam_lm_scores[i]
+
+                        kept_hyps = sorted(
+                            self.recombine_hyps(list_b + list_exp),
+                            key=lambda x: x.score,
+                            reverse=True,
+                        )[: self.beam_size]
+
+        return kept_hyps
diff --git a/funasr/modules/e2e_asr_common.py b/funasr/modules/e2e_asr_common.py
index 92f9079..f430fcb 100644
--- a/funasr/modules/e2e_asr_common.py
+++ b/funasr/modules/e2e_asr_common.py
@@ -6,6 +6,8 @@
"""Common functions for ASR."""
+from typing import List, Optional, Tuple
+
import json
import logging
import sys
@@ -13,7 +15,10 @@
from itertools import groupby
import numpy as np
import six
+import torch
+from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer
+from funasr.models.joint_net.joint_network import JointNetwork
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
"""End detection.
@@ -247,3 +252,148 @@
word_eds.append(editdistance.eval(hyp_words, ref_words))
word_ref_lens.append(len(ref_words))
return float(sum(word_eds)) / sum(word_ref_lens)
+
+class ErrorCalculatorTransducer:
+    """Calculate CER and WER for transducer models.
+    Args:
+        decoder: Decoder module.
+        joint_network: Joint Network module.
+        token_list: List of token units.
+        sym_space: Space symbol.
+        sym_blank: Blank symbol.
+        report_cer: Whether to compute CER.
+        report_wer: Whether to compute WER.
+    """
+
+    def __init__(
+        self,
+        decoder,
+        joint_network: JointNetwork,
+        token_list: List[int],
+        sym_space: str,
+        sym_blank: str,
+        report_cer: bool = False,
+        report_wer: bool = False,
+    ) -> None:
+        """Construct an ErrorCalculatorTransducer object."""
+        super().__init__()
+
+        # beam_size=1 makes validation decoding greedy, keeping it cheap.
+        self.beam_search = BeamSearchTransducer(
+            decoder=decoder,
+            joint_network=joint_network,
+            beam_size=1,
+            search_type="default",
+            score_norm=False,
+        )
+
+        self.decoder = decoder
+
+        self.token_list = token_list
+        self.space = sym_space
+        self.blank = sym_blank
+
+        self.report_cer = report_cer
+        self.report_wer = report_wer
+
+    def __call__(
+        self, encoder_out: torch.Tensor, target: torch.Tensor
+    ) -> Tuple[Optional[float], Optional[float]]:
+        """Calculate sentence-level WER or/and CER score for Transducer model.
+        Args:
+            encoder_out: Encoder output sequences. (B, T, D_enc)
+            target: Target label ID sequences. (B, L)
+        Returns:
+            : Sentence-level CER score (None when report_cer is False).
+            : Sentence-level WER score (None when report_wer is False).
+        """
+        cer, wer = None, None
+
+        batchsize = int(encoder_out.size(0))
+
+        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)
+
+        # Decode each utterance independently; yseq[1:] drops the leading blank.
+        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
+        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]
+
+        char_pred, char_target = self.convert_to_char(pred, target)
+
+        if self.report_cer:
+            cer = self.calculate_cer(char_pred, char_target)
+
+        if self.report_wer:
+            wer = self.calculate_wer(char_pred, char_target)
+
+        return cer, wer
+
+    def convert_to_char(
+        self, pred: torch.Tensor, target: torch.Tensor
+    ) -> Tuple[List, List]:
+        """Convert label ID sequences to character sequences.
+        Args:
+            pred: Prediction label ID sequences. (B, U)
+            target: Target label ID sequences. (B, L)
+        Returns:
+            char_pred: Prediction character sequences. (B, ?)
+            char_target: Target character sequences. (B, ?)
+        """
+        char_pred, char_target = [], []
+
+        for i, pred_i in enumerate(pred):
+            char_pred_i = [self.token_list[int(h)] for h in pred_i]
+            # NOTE(review): target rows may still contain the padding ID and
+            # are mapped through token_list as-is — confirm padding is
+            # stripped or mapped to the blank symbol upstream.
+            char_target_i = [self.token_list[int(r)] for r in target[i]]
+
+            # Space symbols become literal spaces; blank symbols are removed.
+            char_pred_i = "".join(char_pred_i).replace(self.space, " ")
+            char_pred_i = char_pred_i.replace(self.blank, "")
+
+            char_target_i = "".join(char_target_i).replace(self.space, " ")
+            char_target_i = char_target_i.replace(self.blank, "")
+
+            char_pred.append(char_pred_i)
+            char_target.append(char_target_i)
+
+        return char_pred, char_target
+
+    def calculate_cer(
+        self, char_pred: torch.Tensor, char_target: torch.Tensor
+    ) -> float:
+        """Calculate sentence-level CER score.
+        Args:
+            char_pred: Prediction character sequences. (B, ?)
+            char_target: Target character sequences. (B, ?)
+        Returns:
+            : Average sentence-level CER score.
+        """
+        import editdistance
+
+        distances, lens = [], []
+
+        # Spaces are removed so the edit distance is over characters only.
+        for i, char_pred_i in enumerate(char_pred):
+            pred = char_pred_i.replace(" ", "")
+            target = char_target[i].replace(" ", "")
+            distances.append(editdistance.eval(pred, target))
+            lens.append(len(target))
+
+        return float(sum(distances)) / sum(lens)
+
+ def calculate_wer(
+ self, char_pred: torch.Tensor, char_target: torch.Tensor
+ ) -> float:
+ """Calculate sentence-level WER score.
+ Args:
+ char_pred: Prediction character sequences. (B, ?)
+ char_target: Target character sequences. (B, ?)
+ Returns:
+ : Average sentence-level WER score
+ """
+ import editdistance
+
+ distances, lens = [], []
+
+ for i, char_pred_i in enumerate(char_pred):
+ pred = char_pred_i.replace("鈻�", " ").split()
+ target = char_target[i].replace("鈻�", " ").split()
+
+ distances.append(editdistance.eval(pred, target))
+ lens.append(len(target))
+
+ return float(sum(distances)) / sum(lens)
diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py
index 4b292a7..c347e24 100644
--- a/funasr/modules/embedding.py
+++ b/funasr/modules/embedding.py
@@ -440,4 +440,79 @@
outputs = F.pad(outputs, (pad_left, pad_right))
outputs = outputs.transpose(1, 2)
return outputs
-
+
+class StreamingRelPositionalEncoding(torch.nn.Module):
+    """Relative positional encoding.
+    Args:
+        size: Module size.
+        max_len: Maximum input length.
+        dropout_rate: Dropout rate.
+    """
+
+    def __init__(
+        self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
+    ) -> None:
+        """Construct a RelativePositionalEncoding object."""
+        super().__init__()
+
+        self.size = size
+
+        self.pe = None
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+        # Pre-compute the table for max_len; extend_pe grows it on demand.
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+        # NOTE(review): relies on a module-level `_pre_hook` (and `math`)
+        # defined elsewhere in this file — confirm both exist there.
+        self._register_load_state_dict_pre_hook(_pre_hook)
+
+    def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
+        """Reset positional encoding.
+        Args:
+            x: Input sequences. (B, T, ?)
+            left_context: Number of frames in left context.
+        """
+        time1 = x.size(1) + left_context
+
+        # Reuse the cached table when it already covers 2 * time1 - 1 relative
+        # positions; only move it to the right device/dtype if needed.
+        if self.pe is not None:
+            if self.pe.size(1) >= time1 * 2 - 1:
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(device=x.device, dtype=x.dtype)
+                return
+
+        # Build sinusoidal tables for positive (past) and negative (future)
+        # relative offsets separately.
+        pe_positive = torch.zeros(time1, self.size)
+        pe_negative = torch.zeros(time1, self.size)
+
+        position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.size, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.size)
+        )
+
+        # Positive half is flipped so that the center index of the final
+        # concatenated table corresponds to relative position 0.
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+
+        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
+            dtype=x.dtype, device=x.device
+        )
+
+    def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+        """Compute positional encoding.
+        Args:
+            x: Input sequences. (B, T, ?)
+            left_context: Number of frames in left context.
+        Returns:
+            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?)
+        """
+        self.extend_pe(x, left_context=left_context)
+
+        time1 = x.size(1) + left_context
+
+        # Slice the window of relative positions [-(T-1), time1-1] around the
+        # center of the cached table (index pe.size(1) // 2 is offset 0).
+        pos_enc = self.pe[
+            :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
+        ]
+        pos_enc = self.dropout(pos_enc)
+
+        return pos_enc
diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py
index 6d77d69..5d4fe1c 100644
--- a/funasr/modules/nets_utils.py
+++ b/funasr/modules/nets_utils.py
@@ -3,7 +3,7 @@
"""Network related utility tools."""
import logging
-from typing import Dict
+from typing import Dict, List, Tuple
import numpy as np
import torch
@@ -506,3 +506,196 @@
}
return activation_funcs[act]()
+
+class TooShortUttError(Exception):
+    """Raised when the utt is too short for subsampling.
+
+    Args:
+        message: Error message to display.
+        actual_size: The size that cannot pass the subsampling.
+        limit: The size limit for subsampling.
+
+    """
+
+    def __init__(self, message: str, actual_size: int, limit: int) -> None:
+        """Construct a TooShortUttError module."""
+        super().__init__(message)
+
+        # Kept as attributes so callers can report both sizes programmatically.
+        self.actual_size = actual_size
+        self.limit = limit
+
+
+def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
+    """Check if the input is too short for subsampling.
+
+    Args:
+        sub_factor: Subsampling factor for Conv2DSubsampling.
+        size: Input size.
+
+    Returns:
+        : Whether an error should be sent.
+        : Size limit for specified subsampling factor.
+
+    """
+    # NOTE(review): the factor-2 branch guards on size < 3 but reports a
+    # limit of 7 (same as factor 4) — looks copied from the factor-4 case;
+    # confirm the intended limit before relying on the reported value.
+    if sub_factor == 2 and size < 3:
+        return True, 7
+    elif sub_factor == 4 and size < 7:
+        return True, 7
+    elif sub_factor == 6 and size < 11:
+        return True, 11
+
+    return False, -1
+
+
+def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
+    """Get conv2D second layer parameters for given subsampling factor.
+
+    Args:
+        sub_factor: Subsampling factor (1/X).
+        input_size: Input size.
+
+    Returns:
+        : Kernel size for second convolution.
+        : Stride for second convolution.
+        : Conv2DSubsampling output size.
+
+    """
+    # The first conv layer always halves the feature dimension; the second
+    # layer's (kernel, stride) pair determines the remaining reduction.
+    if sub_factor == 2:
+        return 3, 1, (((input_size - 1) // 2 - 2))
+    elif sub_factor == 4:
+        return 3, 2, (((input_size - 1) // 2 - 1) // 2)
+    elif sub_factor == 6:
+        return 5, 3, (((input_size - 1) // 2 - 2) // 3)
+    else:
+        raise ValueError(
+            "subsampling_factor parameter should be set to either 2, 4 or 6."
+        )
+
+
+def make_chunk_mask(
+    size: int,
+    chunk_size: int,
+    left_chunk_size: int = 0,
+    device: torch.device = None,
+) -> torch.Tensor:
+    """Create chunk mask for the subsequent steps (size, size).
+
+    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+    Args:
+        size: Size of the source mask.
+        chunk_size: Number of frames in chunk.
+        left_chunk_size: Size of the left context in chunks (0 means full context).
+        device: Device for the mask tensor.
+
+    Returns:
+        mask: Chunk mask. (size, size) — True marks positions to be MASKED
+        (i.e. not attended to); visible positions are False.
+
+    """
+    mask = torch.zeros(size, size, device=device, dtype=torch.bool)
+
+    for i in range(size):
+        # Frame i may attend from `start` up to the end of its own chunk.
+        if left_chunk_size <= 0:
+            start = 0
+        else:
+            start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
+
+        end = min((i // chunk_size + 1) * chunk_size, size)
+        mask[i, start:end] = True
+
+    # Invert: the attention code expects True where attention is forbidden.
+    return ~mask
+
+def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
+    """Create source mask for given lengths.
+
+    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
+
+    Args:
+        lengths: Sequence lengths. (B,)
+
+    Returns:
+        : Mask for the sequence lengths. (B, max_len) — True at PADDED
+        positions (index >= length), False at valid frames.
+
+    """
+    max_len = lengths.max()
+    batch_size = lengths.size(0)
+
+    expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
+
+    return expanded_lengths >= lengths.unsqueeze(1)
+
+
+def get_transducer_task_io(
+    labels: torch.Tensor,
+    encoder_out_lens: torch.Tensor,
+    ignore_id: int = -1,
+    blank_id: int = 0,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Get Transducer loss I/O.
+
+    Args:
+        labels: Label ID sequences. (B, L)
+        encoder_out_lens: Encoder output lengths. (B,)
+        ignore_id: Padding symbol ID.
+        blank_id: Blank symbol ID.
+
+    Returns:
+        decoder_in: Decoder inputs. (B, U)
+        target: Target label ID sequences. (B, U)
+        t_len: Time lengths. (B,)
+        u_len: Label lengths. (B,)
+
+    """
+
+    def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
+        """Create padded batch of labels from a list of labels sequences.
+
+        Args:
+            labels: Labels sequences. [B x (?)]
+            padding_value: Padding value.
+
+        Returns:
+            labels: Batch of padded labels sequences. (B,)
+
+        """
+        batch_size = len(labels)
+
+        padded = (
+            labels[0]
+            .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
+            .fill_(padding_value)
+        )
+
+        for i in range(batch_size):
+            padded[i, : labels[i].size(0)] = labels[i]
+
+        return padded
+
+    device = labels.device
+
+    # Strip the ignore_id padding from each row before re-padding with blanks.
+    labels_unpad = [y[y != ignore_id] for y in labels]
+    blank = labels[0].new([blank_id])
+
+    # Decoder input is the label sequence shifted right by one blank symbol.
+    decoder_in = pad_list(
+        [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
+    ).to(device)
+
+    # Transducer losses typically expect int32 targets/lengths.
+    target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
+
+    encoder_out_lens = list(map(int, encoder_out_lens))
+    t_len = torch.IntTensor(encoder_out_lens).to(device)
+
+    u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
+
+    return decoder_in, target, t_len, u_len
+
+def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
+    """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
+    # NOTE(review): assumes pad_len >= t.size(dim); a smaller pad_len would
+    # produce a negative-sized zero block — confirm callers guarantee this.
+    if t.size(dim) == pad_len:
+        return t
+    else:
+        pad_size = list(t.shape)
+        pad_size[dim] = pad_len - t.size(dim)
+        return torch.cat(
+            [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
+        )
diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py
index a3d2676..2b2dac8 100644
--- a/funasr/modules/repeat.py
+++ b/funasr/modules/repeat.py
@@ -6,6 +6,8 @@
"""Repeat the same layer definition."""
+from typing import Dict, List, Optional
+
import torch
@@ -31,3 +33,92 @@
"""
return MultiSequential(*[fn(n) for n in range(N)])
+
+
+class MultiBlocks(torch.nn.Module):
+    """MultiBlocks definition.
+    Args:
+        block_list: Individual blocks of the encoder architecture.
+        output_size: Architecture output size.
+        norm_class: Normalization module class.
+    """
+
+    def __init__(
+        self,
+        block_list: List[torch.nn.Module],
+        output_size: int,
+        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+    ) -> None:
+        """Construct a MultiBlocks object."""
+        super().__init__()
+
+        self.blocks = torch.nn.ModuleList(block_list)
+        # Final normalization applied after the last block in forward passes.
+        self.norm_blocks = norm_class(output_size)
+
+        self.num_blocks = len(block_list)
+
+    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+        """Initialize/Reset encoder streaming cache.
+        Args:
+            left_context: Number of left frames during chunk-by-chunk inference.
+            device: Device to use for cache tensor.
+        """
+        # Delegate to each block; every block keeps its own streaming state.
+        for idx in range(self.num_blocks):
+            self.blocks[idx].reset_streaming_cache(left_context, device)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Forward each block of the encoder architecture.
+ Args:
+ x: MultiBlocks input sequences. (B, T, D_block_1)
+ pos_enc: Positional embedding sequences.
+ mask: Source mask. (B, T)
+ chunk_mask: Chunk mask. (T_2, T_2)
+ Returns:
+ x: Output sequences. (B, T, D_block_N)
+ """
+ for block_index, block in enumerate(self.blocks):
+ x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask)
+
+ x = self.norm_blocks(x)
+
+ return x
+
+ def chunk_forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_size: int = 0,
+ left_context: int = 0,
+ right_context: int = 0,
+ ) -> torch.Tensor:
+ """Forward each block of the encoder architecture.
+ Args:
+ x: MultiBlocks input sequences. (B, T, D_block_1)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
+ mask: Source mask. (B, T_2)
+ left_context: Number of frames in left context.
+ right_context: Number of frames in right context.
+ Returns:
+ x: MultiBlocks output sequences. (B, T, D_block_N)
+ """
+ for block_idx, block in enumerate(self.blocks):
+ x, pos_enc = block.chunk_forward(
+ x,
+ pos_enc,
+ mask,
+ chunk_size=chunk_size,
+ left_context=left_context,
+ right_context=right_context,
+ )
+
+ x = self.norm_blocks(x)
+
+ return x
diff --git a/funasr/modules/subsampling.py b/funasr/modules/subsampling.py
index d492ccf..623be65 100644
--- a/funasr/modules/subsampling.py
+++ b/funasr/modules/subsampling.py
@@ -11,6 +11,10 @@
from funasr.modules.embedding import PositionalEncoding
import logging
from funasr.modules.streaming_utils.utils import sequence_mask
+from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
+from typing import Optional, Tuple, Union
+import math
+
class TooShortUttError(Exception):
"""Raised when the utt is too short for subsampling.
@@ -407,3 +411,201 @@
var_dict_tf[name_tf].shape))
return var_dict_torch_update
+class StreamingConvInput(torch.nn.Module):
+    """Streaming ConvInput module definition.
+
+    Convolutional frontend for the streaming encoder. Depending on
+    ``vgg_like`` and ``subsampling_factor`` it builds one of four conv
+    stacks; time subsampling is performed by pooling/stride, while the
+    frequency axis is always reduced so ``output_proj`` matches the
+    flattened (channels * reduced-frequency) dimension.
+
+    Args:
+        input_size: Input size.
+        conv_size: Convolution size.
+        subsampling_factor: Subsampling factor.
+        vgg_like: Whether to use a VGG-like network.
+        output_size: Block output dimension.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        conv_size: Union[int, Tuple],
+        subsampling_factor: int = 4,
+        vgg_like: bool = True,
+        output_size: Optional[int] = None,
+    ) -> None:
+        """Construct a ConvInput object."""
+        super().__init__()
+        if vgg_like:
+            # VGG-like branches expect conv_size to be a (conv_size1, conv_size2) pair.
+            if subsampling_factor == 1:
+                conv_size1, conv_size2 = conv_size
+
+                # No time subsampling: both MaxPool2d layers pool only the
+                # frequency axis (kernel (1, 2)).
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((1, 2)),
+                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((1, 2)),
+                )
+
+                # Frequency axis halved twice by the two pools.
+                output_proj = conv_size2 * ((input_size // 2) // 2)
+
+                self.subsampling_factor = 1
+
+                self.stride_1 = 1
+
+                self.create_new_mask = self.create_new_vgg_mask
+
+            else:
+                conv_size1, conv_size2 = conv_size
+
+                # Time subsampling is split across the two pools:
+                # first pool strides time by subsampling_factor/2, second by 2.
+                kernel_1 = int(subsampling_factor / 2)
+
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((kernel_1, 2)),
+                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
+                    torch.nn.ReLU(),
+                    torch.nn.MaxPool2d((2, 2)),
+                )
+
+                output_proj = conv_size2 * ((input_size // 2) // 2)
+
+                self.subsampling_factor = subsampling_factor
+
+                self.create_new_mask = self.create_new_vgg_mask
+
+                self.stride_1 = kernel_1
+
+        else:
+            # Conv2d branches: conv_size is a single channel count.
+            if subsampling_factor == 1:
+                # Stride [1, 2]: no time subsampling, frequency halved per conv.
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
+                    torch.nn.ReLU(),
+                )
+
+                output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
+
+                self.subsampling_factor = subsampling_factor
+                self.kernel_2 = 3
+                self.stride_2 = 1
+
+                self.create_new_mask = self.create_new_conv2d_mask
+
+            else:
+                # Derive second conv's kernel/stride and the resulting
+                # frequency size from the requested subsampling factor.
+                kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
+                    subsampling_factor,
+                    input_size,
+                )
+
+                self.conv = torch.nn.Sequential(
+                    torch.nn.Conv2d(1, conv_size, 3, 2),
+                    torch.nn.ReLU(),
+                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
+                    torch.nn.ReLU(),
+                )
+
+                output_proj = conv_size * conv_2_output_size
+
+                self.subsampling_factor = subsampling_factor
+                self.kernel_2 = kernel_2
+                self.stride_2 = stride_2
+
+                self.create_new_mask = self.create_new_conv2d_mask
+
+        self.vgg_like = vgg_like
+        # Minimum number of input frames this frontend can handle; the value 7
+        # is hard-coded here — presumably the receptive-field minimum, confirm.
+        self.min_frame_length = 7
+
+        # Optional linear projection from the flattened conv output to output_size.
+        if output_size is not None:
+            self.output = torch.nn.Linear(output_proj, output_size)
+            self.output_size = output_size
+        else:
+            self.output = None
+            self.output_size = output_proj
+
+    def forward(
+        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encode input sequences.
+        Args:
+            x: ConvInput input sequences. (B, T, D_feats)
+            mask: Mask of input sequences. (B, 1, T)
+            chunk_size: Streaming chunk size; used as an int despite the
+                Optional[torch.Tensor] annotation — confirm intended type.
+        Returns:
+            x: ConvInput output sequences. (B, sub(T), D_out)
+            mask: Mask of output sequences. (B, 1, sub(T))
+        """
+        if mask is not None:
+            mask = self.create_new_mask(mask)
+            # Longest valid (unmasked) length in the batch after subsampling.
+            # NOTE(review): counts zero entries per row — assumes padding
+            # positions are non-zero in `mask`; confirm the mask convention.
+            olens = max(mask.eq(0).sum(1))
+
+        b, t, f = x.size()
+        x = x.unsqueeze(1) # (b. 1. t. f)
+
+        if chunk_size is not None:
+            # Pad T up to a multiple of (chunk_size * subsampling_factor),
+            # then fold each chunk into the batch dimension so the conv sees
+            # fixed-length chunks (streaming-equivalent computation).
+            max_input_length = int(
+                chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
+            )
+            x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
+            x = list(x)
+            x = torch.stack(x, dim=0)
+            N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
+            x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
+
+        x = self.conv(x)
+
+        # Flatten (channels, freq) back into one feature dimension.
+        _, c, _, f = x.size()
+        if chunk_size is not None:
+            # NOTE(review): `olens` is only defined when mask is not None;
+            # calling with mask=None and chunk_size set raises NameError here.
+            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
+        else:
+            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
+
+        if self.output is not None:
+            x = self.output(x)
+
+        # NOTE(review): if mask is None this line raises (mask[...] on None);
+        # the Optional annotation on `mask` is not honored by this return.
+        return x, mask[:,:olens][:,:x.size(1)]
+
+    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
+        """Create a new mask for VGG output sequences.
+
+        Mirrors the two pooling stages: stride subsampling_factor//2, then 2.
+
+        Args:
+            mask: Mask of input sequences. (B, T)
+        Returns:
+            mask: Mask of output sequences. (B, sub(T))
+        """
+        if self.subsampling_factor > 1:
+            # Trim so T is divisible before each strided slice.
+            vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 ))
+            mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]
+
+            vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
+            mask = mask[:, :vgg2_t_len][:, ::2]
+        else:
+            mask = mask
+
+        return mask
+
+    def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
+        """Create new conformer mask for Conv2d output sequences.
+
+        Mirrors the two unpadded strided convs: kernel 3/stride 2 first,
+        then kernel_2/stride_2.
+
+        Args:
+            mask: Mask of input sequences. (B, T)
+        Returns:
+            mask: Mask of output sequences. (B, sub(T))
+        """
+        if self.subsampling_factor > 1:
+            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
+        else:
+            return mask
+
+    def get_size_before_subsampling(self, size: int) -> int:
+        """Return the original size before subsampling for a given size.
+        Args:
+            size: Number of frames after subsampling.
+        Returns:
+            : Number of frames before subsampling.
+        """
+        return size * self.subsampling_factor
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 52a0ce7..d52c9c3 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -38,13 +38,16 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+from funasr.models.decoder.rnnt_decoder import RNNTDecoder
+from funasr.models.joint_net.joint_network import JointNetwork
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer
from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
+from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
@@ -151,6 +154,7 @@
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
+ chunk_conformer=ConformerChunkEncoder,
),
type_check=AbsEncoder,
default="rnn",
@@ -208,6 +212,16 @@
type_check=AbsDecoder,
default="rnn",
)
+
+rnnt_decoder_choices = ClassChoices(
+ "rnnt_decoder",
+ classes=dict(
+ rnnt=RNNTDecoder,
+ ),
+ type_check=RNNTDecoder,
+ default="rnnt",
+)
+
predictor_choices = ClassChoices(
name="predictor",
classes=dict(
@@ -1332,3 +1346,378 @@
) -> Tuple[str, ...]:
retval = ("speech", "text")
return retval
+
+
+class ASRTransducerTask(AbsTask):
+    """ASR Transducer Task definition."""
+
+    # Single optimizer for the whole model.
+    num_optimizers: int = 1
+
+    # Configurable components; each adds --<name> / --<name>_conf CLI options.
+    # NOTE(review): decoder_choices (attention decoder) is not listed here,
+    # yet build_model reads args.decoder / args.att_decoder / args.decoder_conf
+    # — confirm those arguments are registered elsewhere.
+    class_choices_list = [
+        frontend_choices,
+        specaug_choices,
+        normalize_choices,
+        encoder_choices,
+        rnnt_decoder_choices,
+    ]
+
+    trainer = Trainer
+
+    @classmethod
+    def add_task_arguments(cls, parser: argparse.ArgumentParser):
+        """Add Transducer task arguments.
+        Args:
+            cls: ASRTransducerTask object.
+            parser: Transducer arguments parser.
+        """
+        group = parser.add_argument_group(description="Task related.")
+
+        # required = parser.get_default("required")
+        # required += ["token_list"]
+
+        group.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="Integer-string mapper for tokens.",
+        )
+        group.add_argument(
+            "--split_with_space",
+            type=str2bool,
+            default=True,
+            help="whether to split text using <space>",
+        )
+        group.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of dimensions for input features.",
+        )
+        group.add_argument(
+            "--init",
+            type=str_or_none,
+            default=None,
+            help="Type of model initialization to use.",
+        )
+        group.add_argument(
+            "--model_conf",
+            action=NestedDictAction,
+            default=get_default_kwargs(TransducerModel),
+            help="The keyword arguments for the model class.",
+        )
+        # group.add_argument(
+        #     "--encoder_conf",
+        #     action=NestedDictAction,
+        #     default={},
+        #     help="The keyword arguments for the encoder class.",
+        # )
+        group.add_argument(
+            "--joint_network_conf",
+            action=NestedDictAction,
+            default={},
+            help="The keyword arguments for the joint network class.",
+        )
+        group = parser.add_argument_group(description="Preprocess related.")
+        group.add_argument(
+            "--use_preprocessor",
+            type=str2bool,
+            default=True,
+            help="Whether to apply preprocessing to input data.",
+        )
+        group.add_argument(
+            "--token_type",
+            type=str,
+            default="bpe",
+            choices=["bpe", "char", "word", "phn"],
+            help="The type of tokens to use during tokenization.",
+        )
+        group.add_argument(
+            "--bpemodel",
+            type=str_or_none,
+            default=None,
+            help="The path of the sentencepiece model.",
+        )
+        # NOTE(review): the remaining options are added to `parser`, not
+        # `group` — they appear outside the "Preprocess related." group.
+        parser.add_argument(
+            "--non_linguistic_symbols",
+            type=str_or_none,
+            help="The 'non_linguistic_symbols' file path.",
+        )
+        parser.add_argument(
+            "--cleaner",
+            type=str_or_none,
+            choices=[None, "tacotron", "jaconv", "vietnamese"],
+            default=None,
+            help="Text cleaner to use.",
+        )
+        parser.add_argument(
+            "--g2p",
+            type=str_or_none,
+            choices=g2p_choices,
+            default=None,
+            help="g2p method to use if --token_type=phn.",
+        )
+        parser.add_argument(
+            "--speech_volume_normalize",
+            type=float_or_none,
+            default=None,
+            help="Normalization value for maximum amplitude scaling.",
+        )
+        parser.add_argument(
+            "--rir_scp",
+            type=str_or_none,
+            default=None,
+            help="The RIR SCP file path.",
+        )
+        parser.add_argument(
+            "--rir_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of the applied RIR convolution.",
+        )
+        parser.add_argument(
+            "--noise_scp",
+            type=str_or_none,
+            default=None,
+            help="The path of noise SCP file.",
+        )
+        parser.add_argument(
+            "--noise_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability of the applied noise addition.",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of the noise decibel level.",
+        )
+        for class_choices in cls.class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --decoder and --decoder_conf
+            class_choices.add_arguments(group)
+
+    @classmethod
+    def build_collate_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Callable[
+        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+        Tuple[List[str], Dict[str, torch.Tensor]],
+    ]:
+        """Build collate function.
+        Args:
+            cls: ASRTransducerTask object.
+            args: Task arguments.
+            train: Training mode.
+        Return:
+            : Callable collate function.
+        """
+        assert check_argument_types()
+
+        # -1 pads token ids (ignored index); 0.0 pads float features.
+        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+
+    @classmethod
+    def build_preprocess_fn(
+        cls, args: argparse.Namespace, train: bool
+    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+        """Build pre-processing function.
+        Args:
+            cls: ASRTransducerTask object.
+            args: Task arguments.
+            train: Training mode.
+        Return:
+            : Callable pre-processing function.
+        """
+        assert check_argument_types()
+
+        if args.use_preprocessor:
+            # hasattr guards keep this working with configs generated before
+            # these options existed.
+            retval = CommonPreprocessor(
+                train=train,
+                token_type=args.token_type,
+                token_list=args.token_list,
+                bpemodel=args.bpemodel,
+                non_linguistic_symbols=args.non_linguistic_symbols,
+                text_cleaner=args.cleaner,
+                g2p_type=args.g2p,
+                split_with_space=args.split_with_space if hasattr(args, "split_with_space") else False,
+                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+                rir_apply_prob=args.rir_apply_prob
+                if hasattr(args, "rir_apply_prob")
+                else 1.0,
+                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+                noise_apply_prob=args.noise_apply_prob
+                if hasattr(args, "noise_apply_prob")
+                else 1.0,
+                noise_db_range=args.noise_db_range
+                if hasattr(args, "noise_db_range")
+                else "13_15",
+                # NOTE(review): guard checks "rir_scp" instead of
+                # "speech_volume_normalize" — looks like a copy-paste slip.
+                speech_volume_normalize=args.speech_volume_normalize
+                if hasattr(args, "rir_scp")
+                else None,
+            )
+        else:
+            retval = None
+
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def required_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        """Required data depending on task mode.
+        Args:
+            cls: ASRTransducerTask object.
+            train: Training mode.
+            inference: Inference mode.
+        Return:
+            retval: Required task data.
+        """
+        # Transcripts are only required outside inference.
+        if not inference:
+            retval = ("speech", "text")
+        else:
+            retval = ("speech",)
+
+        return retval
+
+    @classmethod
+    def optional_data_names(
+        cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        """Optional data depending on task mode.
+        Args:
+            cls: ASRTransducerTask object.
+            train: Training mode.
+            inference: Inference mode.
+        Return:
+            retval: Optional task data.
+        """
+        retval = ()
+        assert check_return_type(retval)
+
+        return retval
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace) -> TransducerModel:
+        """Required data depending on task mode.
+        Args:
+            cls: ASRTransducerTask object.
+            args: Task arguments.
+        Return:
+            model: ASR Transducer model.
+        """
+        assert check_argument_types()
+
+        if isinstance(args.token_list, str):
+            with open(args.token_list, encoding="utf-8") as f:
+                token_list = [line.rstrip() for line in f]
+
+            # Overwriting token_list to keep it as "portable".
+            args.token_list = list(token_list)
+        elif isinstance(args.token_list, (tuple, list)):
+            token_list = list(args.token_list)
+        else:
+            raise RuntimeError("token_list must be str or list")
+        vocab_size = len(token_list)
+        logging.info(f"Vocabulary size: {vocab_size }")
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Encoder
+
+        if getattr(args, "encoder", None) is not None:
+            encoder_class = encoder_choices.get_class(args.encoder)
+            encoder = encoder_class(input_size, **args.encoder_conf)
+        else:
+            # NOTE(review): `Encoder` does not appear to be imported in this
+            # module — this fallback would raise NameError; confirm.
+            encoder = Encoder(input_size, **args.encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        # 5. Decoder
+        rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
+        decoder = rnnt_decoder_class(
+            vocab_size,
+            **args.rnnt_decoder_conf,
+        )
+        # output_size is an attribute here (no call), unlike encoder.output_size().
+        decoder_output_size = decoder.output_size
+
+        # Optional attention decoder (for joint transducer/attention training).
+        # NOTE(review): the condition checks args.decoder but the class is
+        # looked up via args.att_decoder — one of the two names is likely wrong.
+        if getattr(args, "decoder", None) is not None:
+            att_decoder_class = decoder_choices.get_class(args.att_decoder)
+
+            att_decoder = att_decoder_class(
+                vocab_size=vocab_size,
+                encoder_output_size=encoder_output_size,
+                **args.decoder_conf,
+            )
+        else:
+            att_decoder = None
+        # 6. Joint Network
+        joint_network = JointNetwork(
+            vocab_size,
+            encoder_output_size,
+            decoder_output_size,
+            **args.joint_network_conf,
+        )
+
+        # 7. Build model
+
+        # Encoder flag selects the unified (streaming + offline) model.
+        # Assumes the encoder class defines `unified_model_training` — confirm.
+        if encoder.unified_model_training:
+            model = UnifiedTransducerModel(
+                vocab_size=vocab_size,
+                token_list=token_list,
+                frontend=frontend,
+                specaug=specaug,
+                normalize=normalize,
+                encoder=encoder,
+                decoder=decoder,
+                att_decoder=att_decoder,
+                joint_network=joint_network,
+                **args.model_conf,
+            )
+
+        else:
+            model = TransducerModel(
+                vocab_size=vocab_size,
+                token_list=token_list,
+                frontend=frontend,
+                specaug=specaug,
+                normalize=normalize,
+                encoder=encoder,
+                decoder=decoder,
+                att_decoder=att_decoder,
+                joint_network=joint_network,
+                **args.model_conf,
+            )
+
+        # 8. Initialize model
+        if args.init is not None:
+            raise NotImplementedError(
+                "Currently not supported.",
+                "Initialization part will be reworked in a short future.",
+            )
+
+        #assert check_return_type(model)
+
+        return model
--
Gitblit v1.9.1