Merge pull request #351 from alibaba-damo-academy/dev_aky
Dev aky
| New file |
| | |
| | | |
| | | # Streaming RNN-T Result |
| | | |
| | | ## Training Config |
| | | - 8 GPUs (Tesla V100) |
| | | - Feature info: 80-dim fbank, global CMVN, speed perturbation (0.9, 1.0, 1.1), SpecAugment |
| | | - Train config: conf/train_conformer_rnnt_unified.yaml |
| | | - chunk config: chunk size 16, full left chunk |
| | | - LM config: LM was not used |
| | | - Model size: 90M |
| | | |
| | | ## Results (CER) |
| | | - Decode config: conf/decode_rnnt_conformer_streaming.yaml |
| | | |
| | | | testset | CER(%) | |
| | | |:-----------:|:-------:| |
| | | | dev | 5.53 | |
| | | | test | 6.24 | |
| New file |
| | |
| | | # The conformer transducer decoding configuration from @jeon30c |
| | | beam_size: 10 |
| | | simu_streaming: false |
| | | streaming: true |
| | | chunk_size: 16 |
| | | left_context: 16 |
| | | right_context: 0 |
| | | |
| New file |
| | |
| | | # The conformer transducer decoding configuration from @jeon30c |
| | | beam_size: 10 |
| | | simu_streaming: true |
| | | streaming: false |
| | | chunk_size: 16 |
| New file |
| | |
| | | encoder: chunk_conformer |
| | | encoder_conf: |
| | | activation_type: swish |
| | | positional_dropout_rate: 0.5 |
| | | time_reduction_factor: 2 |
| | | unified_model_training: true |
| | | default_chunk_size: 16 |
| | | jitter_range: 4 |
| | | left_chunk_size: 0 |
| | | embed_vgg_like: false |
| | | subsampling_factor: 4 |
| | | linear_units: 2048 |
| | | output_size: 512 |
| | | attention_heads: 8 |
| | | dropout_rate: 0.5 |
| | | positional_dropout_rate: 0.5 |
| | | attention_dropout_rate: 0.5 |
| | | cnn_module_kernel: 15 |
| | | num_blocks: 12 |
| | | |
| | | # decoder related |
| | | rnnt_decoder: rnnt |
| | | rnnt_decoder_conf: |
| | | embed_size: 512 |
| | | hidden_size: 512 |
| | | embed_dropout_rate: 0.5 |
| | | dropout_rate: 0.5 |
| | | |
| | | joint_network_conf: |
| | | joint_space_size: 512 |
| | | |
| | | # Auxiliary CTC |
| | | model_conf: |
| | | auxiliary_ctc_weight: 0.0 |
| | | |
| | | # minibatch related |
| | | use_amp: true |
| | | batch_type: unsorted |
| | | batch_size: 16 |
| | | num_workers: 16 |
| | | |
| | | # optimization related |
| | | accum_grad: 1 |
| | | grad_clip: 5 |
| | | max_epoch: 200 |
| | | val_scheduler_criterion: |
| | | - valid |
| | | - loss |
| | | best_model_criterion: |
| | | - - valid |
| | | - cer_transducer_chunk |
| | | - min |
| | | keep_nbest_models: 10 |
| | | |
| | | optim: adam |
| | | optim_conf: |
| | | lr: 0.001 |
| | | scheduler: warmuplr |
| | | scheduler_conf: |
| | | warmup_steps: 25000 |
| | | |
| | | normalize: None |
| | | |
| | | specaug: specaug |
| | | specaug_conf: |
| | | apply_time_warp: true |
| | | time_warp_window: 5 |
| | | time_warp_mode: bicubic |
| | | apply_freq_mask: true |
| | | freq_mask_width_range: |
| | | - 0 |
| | | - 40 |
| | | num_freq_mask: 2 |
| | | apply_time_mask: true |
| | | time_mask_width_range: |
| | | - 0 |
| | | - 50 |
| | | num_time_mask: 5 |
| | | |
| | | log_interval: 50 |
| New file |
| | |
#!/bin/bash

# Copyright 2017 Xingyu Na
# Apache 2.0
#
# Prepare Kaldi-style wav.scp and text files for the AISHELL-1 train/dev/test sets.
#
# Usage: aishell_data_prep.sh <audio-path> <text-path> <output-path>
#   <audio-path>   directory searched recursively for *.wav files
#   <text-path>    directory that contains aishell_transcript_v0.8.txt
#   <output-path>  root under which data/local/* and data/{train,dev,test} are written

#. ./path.sh || exit 1;

if [ $# != 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi

aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3

train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp

# Validate the inputs before creating anything under <output-path>.
# The old message claimed "two directory arguments", but the script takes three
# arguments and checks one directory plus one transcript file.
if [ ! -d "$aishell_audio_dir" ] || [ ! -f "$aishell_text" ]; then
  echo "Error: $0 expects an existing audio directory and transcript file," \
       "got '$aishell_audio_dir' and '$aishell_text'" >&2
  exit 1;
fi

mkdir -p "$train_dir" "$dev_dir" "$test_dir" "$tmp_dir"

# find wav audio files, then split them into train, dev and test by path component
find "$aishell_audio_dir" -iname "*.wav" > "$tmp_dir/wav.flist"
n=$(wc -l < "$tmp_dir/wav.flist")
if [ "$n" -ne 141925 ]; then
  echo "Warning: expected 141925 data files, found $n"
fi

grep -i "wav/train" "$tmp_dir/wav.flist" > "$train_dir/wav.flist" || exit 1;
grep -i "wav/dev" "$tmp_dir/wav.flist" > "$dev_dir/wav.flist" || exit 1;
grep -i "wav/test" "$tmp_dir/wav.flist" > "$test_dir/wav.flist" || exit 1;

rm -r "$tmp_dir"

# Transcriptions preparation
for dir in "$train_dir" "$dev_dir" "$test_dir"; do
  echo "Preparing $dir transcriptions"
  # utterance id = wav file basename with ".wav" stripped
  sed -e 's/\.wav//' "$dir/wav.flist" | awk -F '/' '{print $NF}' > "$dir/utt.list"
  paste -d' ' "$dir/utt.list" "$dir/wav.flist" > "$dir/wav.scp_all"
  # select transcript lines by first-field id (filter_scp.pl -f 1), then rebuild
  # utt.list from the transcripts so wav.scp only keeps transcribed utterances
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$aishell_text" > "$dir/transcripts.txt"
  awk '{print $1}' "$dir/transcripts.txt" > "$dir/utt.list"
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$dir/wav.scp_all" | sort -u > "$dir/wav.scp"
  sort -u "$dir/transcripts.txt" > "$dir/text"
done

mkdir -p "$output_dir/data/train" "$output_dir/data/dev" "$output_dir/data/test"

for f in wav.scp text; do
  cp "$train_dir/$f" "$output_dir/data/train/$f" || exit 1;
  cp "$dev_dir/$f" "$output_dir/data/dev/$f" || exit 1;
  cp "$test_dir/$f" "$output_dir/data/test/$f" || exit 1;
done

echo "$0: AISHELL data preparation succeeded"
exit 0;
| New file |
| | |
# Environment for the recipe; run.sh sources this from the recipe directory.

# FunASR checkout root: three levels above the current working directory.
FUNASR_DIR=$PWD/../../..
export FUNASR_DIR

# Put funasr/bin first on PATH so its entry-point scripts resolve by bare name.
PATH=$FUNASR_DIR/funasr/bin:$PATH
export PATH

# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
PYTHONIOENCODING=UTF-8
export PYTHONIOENCODING
| New file |
| | |
#!/usr/bin/env bash

. ./path.sh || exit 1;

# machines configuration
# NOTE: stage 3 starts gpu_num training processes and gives process i the
# (i+1)-th id of CUDA_VISIBLE_DEVICES, so gpu_num must not exceed that list.
CUDA_VISIBLE_DEVICES="0,1,2,3"
gpu_num=4
count=1
gpu_inference=true # Whether to perform gpu decoding, set false for cpu decoding
# for gpu decoding, inference_nj=ngpu*njob; for cpu decoding, inference_nj=njob
njob=5
train_cmd=utils/run.pl
infer_cmd=utils/run.pl

# general configuration
feats_dir= # feature output directory
exp_dir=   # experiment output root
lang=zh
dumpdir=dump/fbank
feats_type=fbank
token_type=char
scp=feats.scp
type=kaldi_ark
stage=0
stop_stage=4

# feature configuration
feats_dim=80
sample_frequency=16000
nj=32
speed_perturb="0.9,1.0,1.1"

# data
data_aishell=

# exp tag
tag="exp1"

. utils/parse_options.sh || exit 1;

# Strict mode: exit on any failing command (-e), on use of an undefined
# variable (-u), and when any command inside a pipeline fails (-o pipefail).
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev
test_sets="dev test"

asr_config=conf/train_conformer_rnnt_unified.yaml
model_dir="baseline_$(basename "${asr_config}" .yaml)_${feats_type}_${lang}_${token_type}_${tag}"

inference_config=conf/decode_rnnt_conformer_streaming.yaml
inference_asr_model=valid.cer_transducer_chunk.ave_5best.pth

# you can set gpu num for decoding here
gpuid_list=$CUDA_VISIBLE_DEVICES # set gpus for decoding, the same as training stage by default
ngpu=$(echo $gpuid_list | awk -F "," '{print NF}')

if ${gpu_inference}; then
    # $(( )) replaces the deprecated $[ ] arithmetic form
    inference_nj=$((ngpu * njob))
    _ngpu=1
else
    inference_nj=$njob
    _ngpu=0
fi
| | | |
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
    echo "stage 0: Data preparation"
    # Build data/{train,dev,test}/{wav.scp,text} from the raw AISHELL-1 release.
    local/aishell_data_prep.sh ${data_aishell}/data_aishell/wav ${data_aishell}/data_aishell/transcript ${feats_dir}
    for subset in train dev test; do
        subset_dir=${feats_dir}/data/${subset}
        # keep a copy of the raw transcript as text.org
        cp ${subset_dir}/text ${subset_dir}/text.org
        # "<utt-id> <words ...>" -> "<utt-id> <transcript with all spaces removed>"
        paste -d " " \
            <(cut -f 1 -d" " ${subset_dir}/text.org) \
            <(cut -f 2- -d" " ${subset_dir}/text.org | tr -d " ") \
            > ${subset_dir}/text
        # re-tokenize with text2token.py (-n 1 -s 1), staging through text.org
        utils/text2token.py -n 1 -s 1 ${subset_dir}/text > ${subset_dir}/text.org
        mv ${subset_dir}/text.org ${subset_dir}/text
    done
fi
| | | |
# Per-set dump directories for normalized features; later stages read these too.
feat_train_dir=${feats_dir}/${dumpdir}/train
mkdir -p ${feat_train_dir}
feat_dev_dir=${feats_dir}/${dumpdir}/dev
mkdir -p ${feat_dev_dir}
feat_test_dir=${feats_dir}/${dumpdir}/test
mkdir -p ${feat_test_dir}
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
    echo "stage 1: Feature Generation"
    fbankdir=${feats_dir}/fbank

    # fbank extraction; speed perturbation is applied to the training set only
    utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} --speed_perturb ${speed_perturb} \
        ${feats_dir}/data/train ${exp_dir}/exp/make_fbank/train ${fbankdir}/train
    utils/fix_data_feat.sh ${fbankdir}/train
    for subset in dev test; do
        utils/compute_fbank.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} --sample_frequency ${sample_frequency} \
            ${feats_dir}/data/${subset} ${exp_dir}/exp/make_fbank/${subset} ${fbankdir}/${subset}
        utils/fix_data_feat.sh ${fbankdir}/${subset}
    done

    # global cmvn statistics are computed from the training set only ...
    utils/compute_cmvn.sh --cmd "$train_cmd" --nj $nj --feats_dim ${feats_dim} \
        ${fbankdir}/train ${exp_dir}/exp/make_fbank/train

    # ... and that same train/cmvn.json is applied to every set
    for subset in train dev test; do
        utils/apply_cmvn.sh --cmd "$train_cmd" --nj $nj \
            ${fbankdir}/${subset} ${fbankdir}/train/cmvn.json ${exp_dir}/exp/make_fbank/${subset} ${feats_dir}/${dumpdir}/${subset}
    done

    for subset in train dev test; do
        cp ${fbankdir}/${subset}/text ${fbankdir}/${subset}/speech_shape ${fbankdir}/${subset}/text_shape ${feats_dir}/${dumpdir}/${subset}
    done

    for subset in train dev test; do
        utils/fix_data_feat.sh ${feats_dir}/${dumpdir}/${subset}
    done

    # ark lists are generated for train and dev only
    for subset in train dev; do
        utils/gen_ark_list.sh --cmd "$train_cmd" --nj $nj ${feats_dir}/${dumpdir}/${subset} ${fbankdir}/${subset} ${feats_dir}/${dumpdir}/${subset}
    done
fi
| | | |
token_list=${feats_dir}/data/${lang}_token_list/char/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
    echo "stage 2: Dictionary Preparation"
    mkdir -p ${feats_dir}/data/${lang}_token_list/char/

    echo "make a dictionary"
    # special symbols occupy the first three lines of the token list
    echo "<blank>" > ${token_list}
    echo "<s>" >> ${token_list}
    echo "</s>" >> ${token_list}
    # every distinct non-blank token of the training transcripts, one per line
    # (the previous trailing `awk '{print $0}'` was a no-op and has been dropped)
    utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/train/text | cut -f 2- -d" " | tr " " "\n" \
        | sort | uniq | grep -a -v -e '^\s*$' >> ${token_list}
    echo "<unk>" >> ${token_list}
    vocab_size=$(wc -l < ${token_list})
    # append ",<vocab_size>" to every line of text_shape
    awk -v v=,${vocab_size} '{print $0v}' ${feat_train_dir}/text_shape > ${feat_train_dir}/text_shape.char
    awk -v v=,${vocab_size} '{print $0v}' ${feat_dev_dir}/text_shape > ${feat_dev_dir}/text_shape.char
    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/train
    mkdir -p ${feats_dir}/asr_stats_fbank_zh_char/dev
    cp ${feat_train_dir}/speech_shape ${feat_train_dir}/text_shape ${feat_train_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/train
    cp ${feat_dev_dir}/speech_shape ${feat_dev_dir}/text_shape ${feat_dev_dir}/text_shape.char ${feats_dir}/asr_stats_fbank_zh_char/dev
fi
| | | |
# Training Stage
world_size=$gpu_num # run on one machine
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
    echo "stage 3: Training"
    mkdir -p ${exp_dir}/exp/${model_dir}/log
    # remove any rendezvous file left by a previous run before DDP init
    INIT_FILE=${exp_dir}/exp/${model_dir}/ddp_init
    if [ -f $INIT_FILE ]; then
        rm -f $INIT_FILE
    fi
    init_method=file://$(readlink -f $INIT_FILE)
    echo "$0: init method is $init_method"
    # one background process per GPU; process i uses the (i+1)-th id of CUDA_VISIBLE_DEVICES
    train_pids=()
    for ((i = 0; i < gpu_num; ++i)); do
        {
            rank=$i
            local_rank=$i
            gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$((i + 1)))
            asr_train_transducer.py \
                --gpu_id $gpu_id \
                --use_preprocessor true \
                --token_type char \
                --token_list $token_list \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/${scp},speech,${type} \
                --train_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${train_set}/text,text,text \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/speech_shape \
                --train_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${train_set}/text_shape.char \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/${scp},speech,${type} \
                --valid_data_path_and_name_and_type ${feats_dir}/${dumpdir}/${valid_set}/text,text,text \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/speech_shape \
                --valid_shape_file ${feats_dir}/asr_stats_fbank_zh_char/${valid_set}/text_shape.char \
                --resume true \
                --output_dir ${exp_dir}/exp/${model_dir} \
                --config $asr_config \
                --input_size $feats_dim \
                --ngpu $gpu_num \
                --num_worker_count $count \
                --multiprocessing_distributed true \
                --dist_init_method $init_method \
                --dist_world_size $world_size \
                --dist_rank $rank \
                --local_rank $local_rank 1> ${exp_dir}/exp/${model_dir}/log/train.log.$i 2>&1
        } &
        train_pids+=($!)
    done
    # A bare `wait` always returns 0, so a crashed rank used to go unnoticed and
    # the script moved on to inference. Wait on each rank and fail if any failed.
    train_failed=0
    for pid in ${train_pids[@]+"${train_pids[@]}"}; do
        wait "${pid}" || train_failed=1
    done
    if [ ${train_failed} -ne 0 ]; then
        echo "$0: training failed, see ${exp_dir}/exp/${model_dir}/log/train.log.*" >&2
        exit 1
    fi
fi
| | | |
# Testing Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
    echo "stage 4: Inference"
    # loop-invariant: the model directory and the decode tag do not depend on dset
    asr_exp=${exp_dir}/exp/${model_dir}
    inference_tag="$(basename "${inference_config}" .yaml)"
    for dset in ${test_sets}; do
        _dir="${asr_exp}/${inference_tag}/${inference_asr_model}/${dset}"
        _logdir="${_dir}/logdir"
        if [ -d "${_dir}" ]; then
            echo "${_dir} is already exists. if you want to decode again, please delete this dir first."
            # Skip only this set; the previous `exit 0` silently dropped every
            # remaining test set (e.g. "test" whenever "dev" was already decoded).
            continue
        fi
        mkdir -p "${_logdir}"
        _data="${feats_dir}/${dumpdir}/${dset}"
        key_file=${_data}/${scp}
        num_scp_file="$(<${key_file} wc -l)"
        # never start more jobs than there are lines in the key file
        _nj=$([ $inference_nj -le $num_scp_file ] && echo "$inference_nj" || echo "$num_scp_file")
        split_scps=
        for n in $(seq "${_nj}"); do
            split_scps+=" ${_logdir}/keys.${n}.scp"
        done
        # shellcheck disable=SC2086
        utils/split_scp.pl "${key_file}" ${split_scps}
        _opts=
        if [ -n "${inference_config}" ]; then
            _opts+="--config ${inference_config} "
        fi
        ${infer_cmd} --gpu "${_ngpu}" --max-jobs-run "${_nj}" JOB=1:"${_nj}" "${_logdir}"/asr_inference.JOB.log \
            python -m funasr.bin.asr_inference_launch \
                --batch_size 1 \
                --ngpu "${_ngpu}" \
                --njob ${njob} \
                --gpuid_list ${gpuid_list} \
                --data_path_and_name_and_type "${_data}/${scp},speech,${type}" \
                --key_file "${_logdir}"/keys.JOB.scp \
                --asr_train_config "${asr_exp}"/config.yaml \
                --asr_model_file "${asr_exp}"/"${inference_asr_model}" \
                --output_dir "${_logdir}"/output.JOB \
                --mode rnnt \
                ${_opts}

        # merge the per-job 1-best outputs into ${_dir}
        for f in token token_int score text; do
            if [ -f "${_logdir}/output.1/1best_recog/${f}" ]; then
                for i in $(seq "${_nj}"); do
                    cat "${_logdir}/output.${i}/1best_recog/${f}"
                done | sort -k1 >"${_dir}/${f}"
            fi
        done
        # post-process hypothesis and reference text, then score
        python utils/proce_text.py ${_dir}/text ${_dir}/text.proc
        python utils/proce_text.py ${_data}/text ${_data}/text.proc
        python utils/compute_wer.py ${_data}/text.proc ${_dir}/text.proc ${_dir}/text.cer
        tail -n 3 ${_dir}/text.cer > ${_dir}/text.cer.txt
        cat ${_dir}/text.cer.txt
    done
fi
| New file |
| | |
| | | ../transformer/utils |
| | |
| | | help="Pretrained model tag. If specify this option, *_train_config and " |
| | | "*_file will be overwritten", |
| | | ) |
| | | group.add_argument( |
| | | "--beam_search_config", |
| | | default={}, |
| | | help="The keyword arguments for transducer beam search.", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Beam-search related") |
| | | group.add_argument( |
| | |
| | | group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") |
| | | group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") |
| | | group.add_argument("--streaming", type=str2bool, default=False) |
| | | group.add_argument("--simu_streaming", type=str2bool, default=False) |
| | | group.add_argument("--chunk_size", type=int, default=16) |
| | | group.add_argument("--left_context", type=int, default=16) |
| | | group.add_argument("--right_context", type=int, default=0) |
| | | group.add_argument( |
| | | "--display_partial_hypotheses", |
| | | type=bool, |
| | | default=False, |
| | | help="Whether to display partial hypotheses during chunk-by-chunk inference.", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Dynamic quantization related") |
| | | group.add_argument( |
| | | "--quantize_asr_model", |
| | | type=bool, |
| | | default=False, |
| | | help="Apply dynamic quantization to ASR model.", |
| | | ) |
| | | group.add_argument( |
| | | "--quantize_modules", |
| | | nargs="*", |
| | | default=None, |
| | | help="""Module names to apply dynamic quantization on. |
| | | The module names are provided as a list, where each name is separated |
| | | by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]). |
| | | Each specified name should be an attribute of 'torch.nn', e.g.: |
| | | torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""", |
| | | ) |
| | | group.add_argument( |
| | | "--quantize_dtype", |
| | | type=str, |
| | | default="qint8", |
| | | choices=["float16", "qint8"], |
| | | help="Dtype for dynamic quantization.", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Text converter related") |
| | | group.add_argument( |
| | |
| | | elif mode == "mfcca": |
| | | from funasr.bin.asr_inference_mfcca import inference_modelscope |
| | | return inference_modelscope(**kwargs) |
| | | elif mode == "rnnt": |
| | | from funasr.bin.asr_inference_rnnt import inference |
| | | return inference(**kwargs) |
| | | else: |
| | | logging.info("Unknown decoding mode: {}".format(mode)) |
| | | return None |
| | |
| | | #!/usr/bin/env python3 |
| | | |
| | | """ Inference class definition for Transducer models.""" |
| | | |
| | | from __future__ import annotations |
| | | |
| | | import argparse |
| | | import logging |
| | | import math |
| | | import sys |
| | | import time |
| | | import copy |
| | | import os |
| | | import codecs |
| | | import tempfile |
| | | import requests |
| | | from pathlib import Path |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import Dict |
| | | from typing import Any |
| | | from typing import List |
| | | from typing import Any, Dict, List, Optional, Sequence, Tuple, Union |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | from packaging.version import parse as V |
| | | from typeguard import check_argument_types, check_return_type |
| | | |
| | | from funasr.modules.beam_search.beam_search_transducer import ( |
| | | BeamSearchTransducer, |
| | | Hypothesis, |
| | | ) |
| | | from funasr.modules.nets_utils import TooShortUttError |
| | | from funasr.fileio.datadir_writer import DatadirWriter |
| | | from funasr.modules.beam_search.beam_search import BeamSearchPara as BeamSearch |
| | | from funasr.modules.beam_search.beam_search import Hypothesis |
| | | from funasr.modules.scorers.ctc import CTCPrefixScorer |
| | | from funasr.modules.scorers.length_bonus import LengthBonus |
| | | from funasr.modules.subsampling import TooShortUttError |
| | | from funasr.tasks.asr import ASRTaskParaformer as ASRTask |
| | | from funasr.tasks.asr import ASRTransducerTask |
| | | from funasr.tasks.lm import LMTask |
| | | from funasr.text.build_tokenizer import build_tokenizer |
| | | from funasr.text.token_id_converter import TokenIDConverter |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.torch_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.utils import config_argparse |
| | | from funasr.utils.types import str2bool, str2triple_str, str_or_none |
| | | from funasr.utils.cli_utils import get_commandline_args |
| | | from funasr.utils.types import str2bool |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils import asr_utils, wav_utils, postprocess_utils |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer |
| | | from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export |
| | | |
| | | |
| | | class Speech2Text: |
| | | """Speech2Text class |
| | | |
| | | Examples: |
| | | >>> import soundfile |
| | | >>> speech2text = Speech2Text("asr_config.yml", "asr.pb") |
| | | >>> audio, rate = soundfile.read("speech.wav") |
| | | >>> speech2text(audio) |
| | | [(text, token, token_int, hypothesis object), ...] |
| | | |
| | | """Speech2Text class for Transducer models. |
| | | Args: |
| | | asr_train_config: ASR model training config path. |
| | | asr_model_file: ASR model path. |
| | | beam_search_config: Beam search config path. |
| | | lm_train_config: Language Model training config path. |
| | | lm_file: Language Model config path. |
| | | token_type: Type of token units. |
| | | bpemodel: BPE model path. |
| | | device: Device to use for inference. |
| | | beam_size: Size of beam during search. |
| | | dtype: Data type. |
| | | lm_weight: Language model weight. |
| | | quantize_asr_model: Whether to apply dynamic quantization to ASR model. |
| | | quantize_modules: List of module names to apply dynamic quantization on. |
| | | quantize_dtype: Dynamic quantization data type. |
| | | nbest: Number of final hypothesis. |
| | | streaming: Whether to perform chunk-by-chunk inference. |
| | | chunk_size: Number of frames in chunk AFTER subsampling. |
| | | left_context: Number of frames in left context AFTER subsampling. |
| | | right_context: Number of frames in right context AFTER subsampling. |
| | | display_partial_hypotheses: Whether to display partial hypotheses. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | asr_train_config: Union[Path, str] = None, |
| | | asr_model_file: Union[Path, str] = None, |
| | | cmvn_file: Union[Path, str] = None, |
| | | lm_train_config: Union[Path, str] = None, |
| | | lm_file: Union[Path, str] = None, |
| | | token_type: str = None, |
| | | bpemodel: str = None, |
| | | device: str = "cpu", |
| | | maxlenratio: float = 0.0, |
| | | minlenratio: float = 0.0, |
| | | dtype: str = "float32", |
| | | beam_size: int = 20, |
| | | ctc_weight: float = 0.5, |
| | | lm_weight: float = 1.0, |
| | | ngram_weight: float = 0.9, |
| | | penalty: float = 0.0, |
| | | nbest: int = 1, |
| | | frontend_conf: dict = None, |
| | | hotword_list_or_file: str = None, |
| | | **kwargs, |
| | | ): |
| | | assert check_argument_types() |
| | | self, |
| | | asr_train_config: Union[Path, str] = None, |
| | | asr_model_file: Union[Path, str] = None, |
| | | cmvn_file: Union[Path, str] = None, |
| | | beam_search_config: Dict[str, Any] = None, |
| | | lm_train_config: Union[Path, str] = None, |
| | | lm_file: Union[Path, str] = None, |
| | | token_type: str = None, |
| | | bpemodel: str = None, |
| | | device: str = "cpu", |
| | | beam_size: int = 5, |
| | | dtype: str = "float32", |
| | | lm_weight: float = 1.0, |
| | | quantize_asr_model: bool = False, |
| | | quantize_modules: List[str] = None, |
| | | quantize_dtype: str = "qint8", |
| | | nbest: int = 1, |
| | | streaming: bool = False, |
| | | simu_streaming: bool = False, |
| | | chunk_size: int = 16, |
| | | left_context: int = 32, |
| | | right_context: int = 0, |
| | | display_partial_hypotheses: bool = False, |
| | | ) -> None: |
| | | """Construct a Speech2Text object.""" |
| | | super().__init__() |
| | | |
| | | # 1. Build ASR model |
| | | scorers = {} |
| | | asr_model, asr_train_args = ASRTask.build_model_from_file( |
| | | assert check_argument_types() |
| | | asr_model, asr_train_args = ASRTransducerTask.build_model_from_file( |
| | | asr_train_config, asr_model_file, cmvn_file, device |
| | | ) |
| | | |
| | | frontend = None |
| | | if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: |
| | | frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) |
| | | |
| | | logging.info("asr_model: {}".format(asr_model)) |
| | | logging.info("asr_train_args: {}".format(asr_train_args)) |
| | | asr_model.to(dtype=getattr(torch, dtype)).eval() |
| | | if quantize_asr_model: |
| | | if quantize_modules is not None: |
| | | if not all([q in ["LSTM", "Linear"] for q in quantize_modules]): |
| | | raise ValueError( |
| | | "Only 'Linear' and 'LSTM' modules are currently supported" |
| | | " by PyTorch and in --quantize_modules" |
| | | ) |
| | | |
| | | if asr_model.ctc != None: |
| | | ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos) |
| | | scorers.update( |
| | | ctc=ctc |
| | | ) |
| | | token_list = asr_model.token_list |
| | | scorers.update( |
| | | length_bonus=LengthBonus(len(token_list)), |
| | | ) |
| | | q_config = set([getattr(torch.nn, q) for q in quantize_modules]) |
| | | else: |
| | | q_config = {torch.nn.Linear} |
| | | |
| | | # 2. Build Language model |
| | | if quantize_dtype == "float16" and (V(torch.__version__) < V("1.5.0")): |
| | | raise ValueError( |
| | | "float16 dtype for dynamic quantization is not supported with torch" |
| | | " version < 1.5.0. Switching to qint8 dtype instead." |
| | | ) |
| | | q_dtype = getattr(torch, quantize_dtype) |
| | | |
| | | asr_model = torch.quantization.quantize_dynamic( |
| | | asr_model, q_config, dtype=q_dtype |
| | | ).eval() |
| | | else: |
| | | asr_model.to(dtype=getattr(torch, dtype)).eval() |
| | | |
| | | if lm_train_config is not None: |
| | | lm, lm_train_args = LMTask.build_model_from_file( |
| | | lm_train_config, lm_file, device |
| | | ) |
| | | scorers["lm"] = lm.lm |
| | | |
| | | # 3. Build ngram model |
| | | # ngram is not supported now |
| | | ngram = None |
| | | scorers["ngram"] = ngram |
| | | lm_scorer = lm.lm |
| | | else: |
| | | lm_scorer = None |
| | | |
| | | # 4. Build BeamSearch object |
| | | # transducer is not supported now |
| | | beam_search_transducer = None |
| | | if beam_search_config is None: |
| | | beam_search_config = {} |
| | | |
| | | weights = dict( |
| | | decoder=1.0 - ctc_weight, |
| | | ctc=ctc_weight, |
| | | lm=lm_weight, |
| | | ngram=ngram_weight, |
| | | length_bonus=penalty, |
| | | beam_search = BeamSearchTransducer( |
| | | asr_model.decoder, |
| | | asr_model.joint_network, |
| | | beam_size, |
| | | lm=lm_scorer, |
| | | lm_weight=lm_weight, |
| | | nbest=nbest, |
| | | **beam_search_config, |
| | | ) |
| | | beam_search = BeamSearch( |
| | | beam_size=beam_size, |
| | | weights=weights, |
| | | scorers=scorers, |
| | | sos=asr_model.sos, |
| | | eos=asr_model.eos, |
| | | vocab_size=len(token_list), |
| | | token_list=token_list, |
| | | pre_beam_score_key=None if ctc_weight == 1.0 else "full", |
| | | ) |
| | | |
| | | beam_search.to(device=device, dtype=getattr(torch, dtype)).eval() |
| | | for scorer in scorers.values(): |
| | | if isinstance(scorer, torch.nn.Module): |
| | | scorer.to(device=device, dtype=getattr(torch, dtype)).eval() |
| | | |
| | | logging.info(f"Decoding device={device}, dtype={dtype}") |
| | | |
| | | # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text |
| | | if token_type is None: |
| | | token_type = asr_train_args.token_type |
| | | if bpemodel is None: |
| | | bpemodel = asr_train_args.bpemodel |
| | | |
| | | if token_type is None: |
| | | tokenizer = None |
| | | elif token_type == "bpe": |
| | | if bpemodel is not None: |
| | | tokenizer = build_tokenizer(token_type=token_type, bpemodel=bpemodel) |
| | | else: |
| | | tokenizer = None |
| | | else: |
| | | tokenizer = build_tokenizer(token_type=token_type) |
| | | converter = TokenIDConverter(token_list=token_list) |
| | | logging.info(f"Text tokenizer: {tokenizer}") |
| | | |
| | | self.asr_model = asr_model |
| | | self.asr_train_args = asr_train_args |
| | | self.converter = converter |
| | | self.tokenizer = tokenizer |
| | | |
| | | # 6. [Optional] Build hotword list from str, local file or url |
| | | self.hotword_list = None |
| | | self.hotword_list = self.generate_hotwords_list(hotword_list_or_file) |
| | | |
| | | is_use_lm = lm_weight != 0.0 and lm_file is not None |
| | | if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm: |
| | | beam_search = None |
| | | self.beam_search = beam_search |
| | | logging.info(f"Beam_search: {self.beam_search}") |
| | | self.beam_search_transducer = beam_search_transducer |
| | | self.maxlenratio = maxlenratio |
| | | self.minlenratio = minlenratio |
| | | self.device = device |
| | | self.dtype = dtype |
| | | self.nbest = nbest |
| | | self.frontend = frontend |
| | | self.encoder_downsampling_factor = 1 |
| | | if asr_train_args.encoder == "data2vec_encoder" or asr_train_args.encoder_conf["input_layer"] == "conv2d": |
| | | self.encoder_downsampling_factor = 4 |
| | | |
| | | @torch.no_grad() |
| | | def __call__( |
| | | self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None |
| | | ): |
| | | """Inference |
| | | |
| | | Args: |
| | | speech: Input speech data |
| | | Returns: |
| | | text, token, token_int, hyp |
| | | |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | # Input as audio signal |
| | | if isinstance(speech, np.ndarray): |
| | | speech = torch.tensor(speech) |
| | | |
| | | if self.frontend is not None: |
| | | feats, feats_len = self.frontend.forward(speech, speech_lengths) |
| | | feats = to_device(feats, device=self.device) |
| | | feats_len = feats_len.int() |
| | | self.asr_model.frontend = None |
| | | else: |
| | | feats = speech |
| | | feats_len = speech_lengths |
| | | lfr_factor = max(1, (feats.size()[-1] // 80) - 1) |
| | | batch = {"speech": feats, "speech_lengths": feats_len} |
| | | |
| | | # a. To device |
| | | batch = to_device(batch, device=self.device) |
| | | |
| | | # b. Forward Encoder |
| | | enc, enc_len = self.asr_model.encode(**batch) |
| | | if isinstance(enc, tuple): |
| | | enc = enc[0] |
| | | # assert len(enc) == 1, len(enc) |
| | | enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor |
| | | |
| | | predictor_outs = self.asr_model.calc_predictor(enc, enc_len) |
| | | pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ |
| | | predictor_outs[2], predictor_outs[3] |
| | | pre_token_length = pre_token_length.round().long() |
| | | if torch.max(pre_token_length) < 1: |
| | | return [] |
| | | if not isinstance(self.asr_model, ContextualParaformer): |
| | | if self.hotword_list: |
| | | logging.warning("Hotword is given but asr model is not a ContextualParaformer.") |
| | | decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length) |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
| | | else: |
| | | decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list) |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
| | | |
| | | results = [] |
| | | b, n, d = decoder_out.size() |
| | | for i in range(b): |
| | | x = enc[i, :enc_len[i], :] |
| | | am_scores = decoder_out[i, :pre_token_length[i], :] |
| | | if self.beam_search is not None: |
| | | nbest_hyps = self.beam_search( |
| | | x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio |
| | | ) |
| | | |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | else: |
| | | yseq = am_scores.argmax(dim=-1) |
| | | score = am_scores.max(dim=-1)[0] |
| | | score = torch.sum(score, dim=-1) |
| | | # pad with mask tokens to ensure compatibility with sos/eos tokens |
| | | yseq = torch.tensor( |
| | | [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device |
| | | ) |
| | | nbest_hyps = [Hypothesis(yseq=yseq, score=score)] |
| | | |
| | | for hyp in nbest_hyps: |
| | | assert isinstance(hyp, (Hypothesis)), type(hyp) |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | |
| | | # remove blank symbol id, which is assumed to be 0 |
| | | token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) |
| | | |
| | | # Change integer-ids to tokens |
| | | token = self.converter.ids2tokens(token_int) |
| | | |
| | | if self.tokenizer is not None: |
| | | text = self.tokenizer.tokens2text(token) |
| | | else: |
| | | text = None |
| | | |
| | | results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) |
| | | |
| | | # assert check_return_type(results) |
| | | return results |
| | | |
| | | def generate_hotwords_list(self, hotword_list_or_file): |
| | | # for None |
| | | if hotword_list_or_file is None: |
| | | hotword_list = None |
| | | # for local txt inputs |
| | | elif os.path.exists(hotword_list_or_file) and hotword_list_or_file.endswith('.txt'): |
| | | logging.info("Attempting to parse hotwords from local txt...") |
| | | hotword_list = [] |
| | | hotword_str_list = [] |
| | | with codecs.open(hotword_list_or_file, 'r') as fin: |
| | | for line in fin.readlines(): |
| | | hw = line.strip() |
| | | hotword_str_list.append(hw) |
| | | hotword_list.append(self.converter.tokens2ids([i for i in hw])) |
| | | hotword_list.append([self.asr_model.sos]) |
| | | hotword_str_list.append('<s>') |
| | | logging.info("Initialized hotword list from file: {}, hotword list: {}." |
| | | .format(hotword_list_or_file, hotword_str_list)) |
| | | # for url, download and generate txt |
| | | elif hotword_list_or_file.startswith('http'): |
| | | logging.info("Attempting to parse hotwords from url...") |
| | | work_dir = tempfile.TemporaryDirectory().name |
| | | if not os.path.exists(work_dir): |
| | | os.makedirs(work_dir) |
| | | text_file_path = os.path.join(work_dir, os.path.basename(hotword_list_or_file)) |
| | | local_file = requests.get(hotword_list_or_file) |
| | | open(text_file_path, "wb").write(local_file.content) |
| | | hotword_list_or_file = text_file_path |
| | | hotword_list = [] |
| | | hotword_str_list = [] |
| | | with codecs.open(hotword_list_or_file, 'r') as fin: |
| | | for line in fin.readlines(): |
| | | hw = line.strip() |
| | | hotword_str_list.append(hw) |
| | | hotword_list.append(self.converter.tokens2ids([i for i in hw])) |
| | | hotword_list.append([self.asr_model.sos]) |
| | | hotword_str_list.append('<s>') |
| | | logging.info("Initialized hotword list from file: {}, hotword list: {}." |
| | | .format(hotword_list_or_file, hotword_str_list)) |
| | | # for text str input |
| | | elif not hotword_list_or_file.endswith('.txt'): |
| | | logging.info("Attempting to parse hotwords as str...") |
| | | hotword_list = [] |
| | | hotword_str_list = [] |
| | | for hw in hotword_list_or_file.strip().split(): |
| | | hotword_str_list.append(hw) |
| | | hotword_list.append(self.converter.tokens2ids([i for i in hw])) |
| | | hotword_list.append([self.asr_model.sos]) |
| | | hotword_str_list.append('<s>') |
| | | logging.info("Hotword list: {}.".format(hotword_str_list)) |
| | | else: |
| | | hotword_list = None |
| | | return hotword_list |
| | | |
| | | class Speech2TextExport: |
| | | """Speech2TextExport class |
| | | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | asr_train_config: Union[Path, str] = None, |
| | | asr_model_file: Union[Path, str] = None, |
| | | cmvn_file: Union[Path, str] = None, |
| | | lm_train_config: Union[Path, str] = None, |
| | | lm_file: Union[Path, str] = None, |
| | | token_type: str = None, |
| | | bpemodel: str = None, |
| | | device: str = "cpu", |
| | | maxlenratio: float = 0.0, |
| | | minlenratio: float = 0.0, |
| | | dtype: str = "float32", |
| | | beam_size: int = 20, |
| | | ctc_weight: float = 0.5, |
| | | lm_weight: float = 1.0, |
| | | ngram_weight: float = 0.9, |
| | | penalty: float = 0.0, |
| | | nbest: int = 1, |
| | | frontend_conf: dict = None, |
| | | hotword_list_or_file: str = None, |
| | | **kwargs, |
| | | ): |
| | | |
| | | # 1. Build ASR model |
| | | asr_model, asr_train_args = ASRTask.build_model_from_file( |
| | | asr_train_config, asr_model_file, cmvn_file, device |
| | | ) |
| | | frontend = None |
| | | if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None: |
| | | frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf) |
| | | |
| | | logging.info("asr_model: {}".format(asr_model)) |
| | | logging.info("asr_train_args: {}".format(asr_train_args)) |
| | | asr_model.to(dtype=getattr(torch, dtype)).eval() |
| | | |
| | | token_list = asr_model.token_list |
| | | |
| | | |
| | | |
| | | logging.info(f"Decoding device={device}, dtype={dtype}") |
| | | |
| | | # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text |
| | | if token_type is None: |
| | | token_type = asr_train_args.token_type |
| | | if bpemodel is None: |
| | |
| | | tokenizer = build_tokenizer(token_type=token_type) |
| | | converter = TokenIDConverter(token_list=token_list) |
| | | logging.info(f"Text tokenizer: {tokenizer}") |
| | | |
| | | # self.asr_model = asr_model |
| | | |
| | | self.asr_model = asr_model |
| | | self.asr_train_args = asr_train_args |
| | | self.converter = converter |
| | | self.tokenizer = tokenizer |
| | | |
| | | self.device = device |
| | | self.dtype = dtype |
| | | self.nbest = nbest |
| | | self.frontend = frontend |
| | | |
| | | model = Paraformer_export(asr_model, onnx=False) |
| | | self.asr_model = model |
| | | self.converter = converter |
| | | self.tokenizer = tokenizer |
| | | |
| | | self.beam_search = beam_search |
| | | self.streaming = streaming |
| | | self.simu_streaming = simu_streaming |
| | | self.chunk_size = max(chunk_size, 0) |
| | | self.left_context = max(left_context, 0) |
| | | self.right_context = max(right_context, 0) |
| | | |
| | | if not streaming or chunk_size == 0: |
| | | self.streaming = False |
| | | self.asr_model.encoder.dynamic_chunk_training = False |
| | | |
| | | if not simu_streaming or chunk_size == 0: |
| | | self.simu_streaming = False |
| | | self.asr_model.encoder.dynamic_chunk_training = False |
| | | |
| | | self.frontend = frontend |
| | | self.window_size = self.chunk_size + self.right_context |
| | | |
| | | self._ctx = self.asr_model.encoder.get_encoder_input_size( |
| | | self.window_size |
| | | ) |
| | | |
| | | #self.last_chunk_length = ( |
| | | # self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 |
| | | #) * self.hop_length |
| | | |
| | | self.last_chunk_length = ( |
| | | self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 |
| | | ) |
| | | self.reset_inference_cache() |
| | | |
| | | def reset_inference_cache(self) -> None: |
| | | """Reset Speech2Text parameters.""" |
| | | self.frontend_cache = None |
| | | |
| | | self.asr_model.encoder.reset_streaming_cache( |
| | | self.left_context, device=self.device |
| | | ) |
| | | self.beam_search.reset_inference_cache() |
| | | |
| | | self.num_processed_frames = torch.tensor([[0]], device=self.device) |
| | | |
| | | @torch.no_grad() |
| | | def __call__( |
| | | self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None |
| | | ): |
| | | """Inference |
| | | |
| | | def streaming_decode( |
| | | self, |
| | | speech: Union[torch.Tensor, np.ndarray], |
| | | is_final: bool = True, |
| | | ) -> List[Hypothesis]: |
| | | """Speech2Text streaming call. |
| | | Args: |
| | | speech: Input speech data |
| | | speech: Chunk of speech data. (S) |
| | | is_final: Whether speech corresponds to the final chunk of data. |
| | | Returns: |
| | | text, token, token_int, hyp |
| | | nbest_hypothesis: N-best hypothesis. |
| | | """ |
| | | if isinstance(speech, np.ndarray): |
| | | speech = torch.tensor(speech) |
| | | if is_final: |
| | | if self.streaming and speech.size(0) < self.last_chunk_length: |
| | | pad = torch.zeros( |
| | | self.last_chunk_length - speech.size(0), speech.size(1), dtype=speech.dtype |
| | | ) |
| | | speech = torch.cat([speech, pad], dim=0) #feats, feats_length = self.apply_frontend(speech, is_final=is_final) |
| | | |
| | | feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) |
| | | feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) |
| | | |
| | | if self.asr_model.normalize is not None: |
| | | feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) |
| | | |
| | | feats = to_device(feats, device=self.device) |
| | | feats_lengths = to_device(feats_lengths, device=self.device) |
| | | enc_out = self.asr_model.encoder.chunk_forward( |
| | | feats, |
| | | feats_lengths, |
| | | self.num_processed_frames, |
| | | chunk_size=self.chunk_size, |
| | | left_context=self.left_context, |
| | | right_context=self.right_context, |
| | | ) |
| | | nbest_hyps = self.beam_search(enc_out[0], is_final=is_final) |
| | | |
| | | self.num_processed_frames += self.chunk_size |
| | | |
| | | if is_final: |
| | | self.reset_inference_cache() |
| | | |
| | | return nbest_hyps |
| | | |
| | | @torch.no_grad() |
| | | def simu_streaming_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]: |
| | | """Speech2Text call. |
| | | Args: |
| | | speech: Speech data. (S) |
| | | Returns: |
| | | nbest_hypothesis: N-best hypothesis. |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | # Input as audio signal |
| | | if isinstance(speech, np.ndarray): |
| | | speech = torch.tensor(speech) |
| | | |
| | | if self.frontend is not None: |
| | | feats, feats_len = self.frontend.forward(speech, speech_lengths) |
| | | feats = to_device(feats, device=self.device) |
| | | feats_len = feats_len.int() |
| | | self.asr_model.frontend = None |
| | | else: |
| | | feats = speech |
| | | feats_len = speech_lengths |
| | | |
| | | enc_len_batch_total = feats_len.sum() |
| | | lfr_factor = max(1, (feats.size()[-1] // 80) - 1) |
| | | batch = {"speech": feats, "speech_lengths": feats_len} |
| | | |
| | | # a. To device |
| | | batch = to_device(batch, device=self.device) |
| | | |
| | | decoder_outs = self.asr_model(**batch) |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
| | | |
| | | feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) |
| | | feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) |
| | | |
| | | if self.asr_model.normalize is not None: |
| | | feats, feats_lengths = self.asr_model.normalize(feats, feats_lengths) |
| | | |
| | | feats = to_device(feats, device=self.device) |
| | | feats_lengths = to_device(feats_lengths, device=self.device) |
| | | enc_out = self.asr_model.encoder.simu_chunk_forward(feats, feats_lengths, self.chunk_size, self.left_context, self.right_context) |
| | | nbest_hyps = self.beam_search(enc_out[0]) |
| | | |
| | | return nbest_hyps |
| | | |
| | | @torch.no_grad() |
| | | def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[Hypothesis]: |
| | | """Speech2Text call. |
| | | Args: |
| | | speech: Speech data. (S) |
| | | Returns: |
| | | nbest_hypothesis: N-best hypothesis. |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | if isinstance(speech, np.ndarray): |
| | | speech = torch.tensor(speech) |
| | | |
| | | feats = speech.unsqueeze(0).to(getattr(torch, self.dtype)) |
| | | feats_lengths = feats.new_full([1], dtype=torch.long, fill_value=feats.size(1)) |
| | | |
| | | feats = to_device(feats, device=self.device) |
| | | feats_lengths = to_device(feats_lengths, device=self.device) |
| | | |
| | | enc_out, _ = self.asr_model.encoder(feats, feats_lengths) |
| | | |
| | | nbest_hyps = self.beam_search(enc_out[0]) |
| | | |
| | | return nbest_hyps |
| | | |
| | | def hypotheses_to_results(self, nbest_hyps: List[Hypothesis]) -> List[Any]: |
| | | """Build partial or final results from the hypotheses. |
| | | Args: |
| | | nbest_hyps: N-best hypothesis. |
| | | Returns: |
| | | results: Results containing different representation for the hypothesis. |
| | | """ |
| | | results = [] |
| | | b, n, d = decoder_out.size() |
| | | for i in range(b): |
| | | am_scores = decoder_out[i, :ys_pad_lens[i], :] |
| | | |
| | | yseq = am_scores.argmax(dim=-1) |
| | | score = am_scores.max(dim=-1)[0] |
| | | score = torch.sum(score, dim=-1) |
| | | # pad with mask tokens to ensure compatibility with sos/eos tokens |
| | | yseq = torch.tensor( |
| | | yseq.tolist(), device=yseq.device |
| | | ) |
| | | nbest_hyps = [Hypothesis(yseq=yseq, score=score)] |
| | | for hyp in nbest_hyps: |
| | | token_int = list(filter(lambda x: x != 0, hyp.yseq)) |
| | | |
| | | for hyp in nbest_hyps: |
| | | assert isinstance(hyp, (Hypothesis)), type(hyp) |
| | | token = self.converter.ids2tokens(token_int) |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | if self.tokenizer is not None: |
| | | text = self.tokenizer.tokens2text(token) |
| | | else: |
| | | text = None |
| | | results.append((text, token, token_int, hyp)) |
| | | |
| | | # remove blank symbol id, which is assumed to be 0 |
| | | token_int = list(filter(lambda x: x != 0 and x != 2, token_int)) |
| | | |
| | | # Change integer-ids to tokens |
| | | token = self.converter.ids2tokens(token_int) |
| | | |
| | | if self.tokenizer is not None: |
| | | text = self.tokenizer.tokens2text(token) |
| | | else: |
| | | text = None |
| | | |
| | | results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) |
| | | assert check_return_type(results) |
| | | |
| | | return results |
| | | |
| | | @staticmethod |
| | | def from_pretrained( |
| | | model_tag: Optional[str] = None, |
| | | **kwargs: Optional[Any], |
| | | ) -> Speech2Text: |
| | | """Build Speech2Text instance from the pretrained model. |
| | | Args: |
| | | model_tag: Model tag of the pretrained models. |
| | | Return: |
| | | : Speech2Text instance. |
| | | """ |
| | | if model_tag is not None: |
| | | try: |
| | | from espnet_model_zoo.downloader import ModelDownloader |
| | | |
| | | except ImportError: |
| | | logging.error( |
| | | "`espnet_model_zoo` is not installed. " |
| | | "Please install via `pip install -U espnet_model_zoo`." |
| | | ) |
| | | raise |
| | | d = ModelDownloader() |
| | | kwargs.update(**d.download_and_unpack(model_tag)) |
| | | |
| | | return Speech2Text(**kwargs) |
| | | |
| | | |
| | | def inference( |
| | | maxlenratio: float, |
| | | minlenratio: float, |
| | | batch_size: int, |
| | | beam_size: int, |
| | | ngpu: int, |
| | | ctc_weight: float, |
| | | lm_weight: float, |
| | | penalty: float, |
| | | log_level: Union[int, str], |
| | | data_path_and_name_and_type, |
| | | asr_train_config: Optional[str], |
| | | asr_model_file: Optional[str], |
| | | cmvn_file: Optional[str] = None, |
| | | raw_inputs: Union[np.ndarray, torch.Tensor] = None, |
| | | lm_train_config: Optional[str] = None, |
| | | lm_file: Optional[str] = None, |
| | | token_type: Optional[str] = None, |
| | | key_file: Optional[str] = None, |
| | | word_lm_train_config: Optional[str] = None, |
| | | bpemodel: Optional[str] = None, |
| | | allow_variable_data_keys: bool = False, |
| | | streaming: bool = False, |
| | | output_dir: Optional[str] = None, |
| | | dtype: str = "float32", |
| | | seed: int = 0, |
| | | ngram_weight: float = 0.9, |
| | | nbest: int = 1, |
| | | num_workers: int = 1, |
| | | |
| | | **kwargs, |
| | | ): |
| | | inference_pipeline = inference_modelscope( |
| | | maxlenratio=maxlenratio, |
| | | minlenratio=minlenratio, |
| | | batch_size=batch_size, |
| | | beam_size=beam_size, |
| | | ngpu=ngpu, |
| | | ctc_weight=ctc_weight, |
| | | lm_weight=lm_weight, |
| | | penalty=penalty, |
| | | log_level=log_level, |
| | | asr_train_config=asr_train_config, |
| | | asr_model_file=asr_model_file, |
| | | cmvn_file=cmvn_file, |
| | | raw_inputs=raw_inputs, |
| | | lm_train_config=lm_train_config, |
| | | lm_file=lm_file, |
| | | token_type=token_type, |
| | | key_file=key_file, |
| | | word_lm_train_config=word_lm_train_config, |
| | | bpemodel=bpemodel, |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | streaming=streaming, |
| | | output_dir=output_dir, |
| | | dtype=dtype, |
| | | seed=seed, |
| | | ngram_weight=ngram_weight, |
| | | nbest=nbest, |
| | | num_workers=num_workers, |
| | | |
| | | **kwargs, |
| | | ) |
| | | return inference_pipeline(data_path_and_name_and_type, raw_inputs) |
| | | |
| | | |
| | | def inference_modelscope( |
| | | maxlenratio: float, |
| | | minlenratio: float, |
| | | batch_size: int, |
| | | beam_size: int, |
| | | ngpu: int, |
| | | ctc_weight: float, |
| | | lm_weight: float, |
| | | penalty: float, |
| | | log_level: Union[int, str], |
| | | # data_path_and_name_and_type, |
| | | asr_train_config: Optional[str], |
| | | asr_model_file: Optional[str], |
| | | cmvn_file: Optional[str] = None, |
| | | lm_train_config: Optional[str] = None, |
| | | lm_file: Optional[str] = None, |
| | | token_type: Optional[str] = None, |
| | | key_file: Optional[str] = None, |
| | | word_lm_train_config: Optional[str] = None, |
| | | bpemodel: Optional[str] = None, |
| | | allow_variable_data_keys: bool = False, |
| | | dtype: str = "float32", |
| | | seed: int = 0, |
| | | ngram_weight: float = 0.9, |
| | | nbest: int = 1, |
| | | num_workers: int = 1, |
| | | output_dir: Optional[str] = None, |
| | | param_dict: dict = None, |
| | | **kwargs, |
| | | ): |
| | | output_dir: str, |
| | | batch_size: int, |
| | | dtype: str, |
| | | beam_size: int, |
| | | ngpu: int, |
| | | seed: int, |
| | | lm_weight: float, |
| | | nbest: int, |
| | | num_workers: int, |
| | | log_level: Union[int, str], |
| | | data_path_and_name_and_type: Sequence[Tuple[str, str, str]], |
| | | asr_train_config: Optional[str], |
| | | asr_model_file: Optional[str], |
| | | cmvn_file: Optional[str], |
| | | beam_search_config: Optional[dict], |
| | | lm_train_config: Optional[str], |
| | | lm_file: Optional[str], |
| | | model_tag: Optional[str], |
| | | token_type: Optional[str], |
| | | bpemodel: Optional[str], |
| | | key_file: Optional[str], |
| | | allow_variable_data_keys: bool, |
| | | quantize_asr_model: Optional[bool], |
| | | quantize_modules: Optional[List[str]], |
| | | quantize_dtype: Optional[str], |
| | | streaming: Optional[bool], |
| | | simu_streaming: Optional[bool], |
| | | chunk_size: Optional[int], |
| | | left_context: Optional[int], |
| | | right_context: Optional[int], |
| | | display_partial_hypotheses: bool, |
| | | **kwargs, |
| | | ) -> None: |
| | | """Transducer model inference. |
| | | Args: |
| | | output_dir: Output directory path. |
| | | batch_size: Batch decoding size. |
| | | dtype: Data type. |
| | | beam_size: Beam size. |
| | | ngpu: Number of GPUs. |
| | | seed: Random number generator seed. |
| | | lm_weight: Weight of language model. |
| | | nbest: Number of final hypothesis. |
| | | num_workers: Number of workers. |
| | | log_level: Level of verbose for logs. |
| | | data_path_and_name_and_type: |
| | | asr_train_config: ASR model training config path. |
| | | asr_model_file: ASR model path. |
| | | beam_search_config: Beam search config path. |
| | | lm_train_config: Language Model training config path. |
| | | lm_file: Language Model path. |
| | | model_tag: Model tag. |
| | | token_type: Type of token units. |
| | | bpemodel: BPE model path. |
| | | key_file: File key. |
| | | allow_variable_data_keys: Whether to allow variable data keys. |
| | | quantize_asr_model: Whether to apply dynamic quantization to ASR model. |
| | | quantize_modules: List of module names to apply dynamic quantization on. |
| | | quantize_dtype: Dynamic quantization data type. |
| | | streaming: Whether to perform chunk-by-chunk inference. |
| | | chunk_size: Number of frames in chunk AFTER subsampling. |
| | | left_context: Number of frames in left context AFTER subsampling. |
| | | right_context: Number of frames in right context AFTER subsampling. |
| | | display_partial_hypotheses: Whether to display partial hypotheses. |
| | | """ |
| | | assert check_argument_types() |
| | | |
| | | if word_lm_train_config is not None: |
| | | raise NotImplementedError("Word LM is not implemented") |
| | | if batch_size > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | if ngpu > 1: |
| | | raise NotImplementedError("only single GPU decoding is supported") |
| | | |
| | |
| | | level=log_level, |
| | | format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s", |
| | | ) |
| | | |
| | | export_mode = False |
| | | if param_dict is not None: |
| | | hotword_list_or_file = param_dict.get('hotword') |
| | | export_mode = param_dict.get("export_mode", False) |
| | | else: |
| | | hotword_list_or_file = None |
| | | |
| | | if ngpu >= 1 and torch.cuda.is_available(): |
| | | if ngpu >= 1: |
| | | device = "cuda" |
| | | else: |
| | | device = "cpu" |
| | | batch_size = 1 |
| | | |
| | | # 1. Set random-seed |
| | | set_all_random_seed(seed) |
| | | |
| | |
| | | asr_train_config=asr_train_config, |
| | | asr_model_file=asr_model_file, |
| | | cmvn_file=cmvn_file, |
| | | beam_search_config=beam_search_config, |
| | | lm_train_config=lm_train_config, |
| | | lm_file=lm_file, |
| | | token_type=token_type, |
| | | bpemodel=bpemodel, |
| | | device=device, |
| | | maxlenratio=maxlenratio, |
| | | minlenratio=minlenratio, |
| | | dtype=dtype, |
| | | beam_size=beam_size, |
| | | ctc_weight=ctc_weight, |
| | | lm_weight=lm_weight, |
| | | ngram_weight=ngram_weight, |
| | | penalty=penalty, |
| | | nbest=nbest, |
| | | hotword_list_or_file=hotword_list_or_file, |
| | | quantize_asr_model=quantize_asr_model, |
| | | quantize_modules=quantize_modules, |
| | | quantize_dtype=quantize_dtype, |
| | | streaming=streaming, |
| | | simu_streaming=simu_streaming, |
| | | chunk_size=chunk_size, |
| | | left_context=left_context, |
| | | right_context=right_context, |
| | | ) |
| | | if export_mode: |
| | | speech2text = Speech2TextExport(**speech2text_kwargs) |
| | | else: |
| | | speech2text = Speech2Text(**speech2text_kwargs) |
| | | speech2text = Speech2Text.from_pretrained( |
| | | model_tag=model_tag, |
| | | **speech2text_kwargs, |
| | | ) |
| | | |
| | | def _forward( |
| | | data_path_and_name_and_type, |
| | | raw_inputs: Union[np.ndarray, torch.Tensor] = None, |
| | | output_dir_v2: Optional[str] = None, |
| | | fs: dict = None, |
| | | param_dict: dict = None, |
| | | **kwargs, |
| | | ): |
| | | # 3. Build data-iterator |
| | | loader = ASRTransducerTask.build_streaming_iterator( |
| | | data_path_and_name_and_type, |
| | | dtype=dtype, |
| | | batch_size=batch_size, |
| | | key_file=key_file, |
| | | num_workers=num_workers, |
| | | preprocess_fn=ASRTransducerTask.build_preprocess_fn( |
| | | speech2text.asr_train_args, False |
| | | ), |
| | | collate_fn=ASRTransducerTask.build_collate_fn( |
| | | speech2text.asr_train_args, False |
| | | ), |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | ) |
| | | |
| | | hotword_list_or_file = None |
| | | if param_dict is not None: |
| | | hotword_list_or_file = param_dict.get('hotword') |
| | | if 'hotword' in kwargs: |
| | | hotword_list_or_file = kwargs['hotword'] |
| | | if hotword_list_or_file is not None or 'hotword' in kwargs: |
| | | speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file) |
| | | cache = None |
| | | if 'cache' in param_dict: |
| | | cache = param_dict['cache'] |
| | | # 3. Build data-iterator |
| | | if data_path_and_name_and_type is None and raw_inputs is not None: |
| | | if isinstance(raw_inputs, torch.Tensor): |
| | | raw_inputs = raw_inputs.numpy() |
| | | data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] |
| | | loader = ASRTask.build_streaming_iterator( |
| | | data_path_and_name_and_type, |
| | | dtype=dtype, |
| | | fs=fs, |
| | | batch_size=batch_size, |
| | | key_file=key_file, |
| | | num_workers=num_workers, |
| | | preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False), |
| | | collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False), |
| | | allow_variable_data_keys=allow_variable_data_keys, |
| | | inference=True, |
| | | ) |
| | | |
| | | forward_time_total = 0.0 |
| | | length_total = 0.0 |
| | | finish_count = 0 |
| | | file_count = 1 |
| | | # 7 .Start for-loop |
| | | # FIXME(kamo): The output format should be discussed about |
| | | asr_result_list = [] |
| | | output_path = output_dir_v2 if output_dir_v2 is not None else output_dir |
| | | if output_path is not None: |
| | | writer = DatadirWriter(output_path) |
| | | else: |
| | | writer = None |
| | | |
| | | # 4 .Start for-loop |
| | | with DatadirWriter(output_dir) as writer: |
| | | for keys, batch in loader: |
| | | assert isinstance(batch, dict), type(batch) |
| | | assert all(isinstance(s, str) for s in keys), keys |
| | | |
| | | _bs = len(next(iter(batch.values()))) |
| | | assert len(keys) == _bs, f"{len(keys)} != {_bs}" |
| | | # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")} |
| | | batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} |
| | | assert len(batch.keys()) == 1 |
| | | |
| | | logging.info("decoding, utt_id: {}".format(keys)) |
| | | # N-best list of (text, token, token_int, hyp_object) |
| | | try: |
| | | if speech2text.streaming: |
| | | speech = batch["speech"] |
| | | |
| | | time_beg = time.time() |
| | | results = speech2text(cache=cache, **batch) |
| | | if len(results) < 1: |
| | | hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[]) |
| | | results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest |
| | | time_end = time.time() |
| | | forward_time = time_end - time_beg |
| | | lfr_factor = results[0][-1] |
| | | length = results[0][-2] |
| | | forward_time_total += forward_time |
| | | length_total += length |
| | | rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor)) |
| | | logging.info(rtf_cur) |
| | | _steps = len(speech) // speech2text._ctx |
| | | _end = 0 |
| | | for i in range(_steps): |
| | | _end = (i + 1) * speech2text._ctx |
| | | |
| | | for batch_id in range(_bs): |
| | | result = [results[batch_id][:-2]] |
| | | speech2text.streaming_decode( |
| | | speech[i * speech2text._ctx : _end], is_final=False |
| | | ) |
| | | |
| | | key = keys[batch_id] |
| | | for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result): |
| | | # Create a directory: outdir/{n}best_recog |
| | | if writer is not None: |
| | | ibest_writer = writer[f"{n}best_recog"] |
| | | final_hyps = speech2text.streaming_decode( |
| | | speech[_end : len(speech)], is_final=True |
| | | ) |
| | | elif speech2text.simu_streaming: |
| | | final_hyps = speech2text.simu_streaming_decode(**batch) |
| | | else: |
| | | final_hyps = speech2text(**batch) |
| | | |
| | | # Write the result to each file |
| | | ibest_writer["token"][key] = " ".join(token) |
| | | # ibest_writer["token_int"][key] = " ".join(map(str, token_int)) |
| | | ibest_writer["score"][key] = str(hyp.score) |
| | | ibest_writer["rtf"][key] = rtf_cur |
| | | results = speech2text.hypotheses_to_results(final_hyps) |
| | | except TooShortUttError as e: |
| | | logging.warning(f"Utterance {keys} {e}") |
| | | hyp = Hypothesis(score=0.0, yseq=[], dec_state=None) |
| | | results = [[" ", ["<space>"], [2], hyp]] * nbest |
| | | |
| | | if text is not None: |
| | | text_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token) |
| | | item = {'key': key, 'value': text_postprocessed} |
| | | asr_result_list.append(item) |
| | | finish_count += 1 |
| | | # asr_utils.print_progress(finish_count / file_count) |
| | | if writer is not None: |
| | | ibest_writer["text"][key] = " ".join(word_lists) |
| | | key = keys[0] |
| | | for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results): |
| | | ibest_writer = writer[f"{n}best_recog"] |
| | | |
| | | logging.info("decoding, utt: {}, predictions: {}".format(key, text)) |
| | | rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor)) |
| | | logging.info(rtf_avg) |
| | | if writer is not None: |
| | | ibest_writer["rtf"]["rtf_avf"] = rtf_avg |
| | | return asr_result_list |
| | | ibest_writer["token"][key] = " ".join(token) |
| | | ibest_writer["token_int"][key] = " ".join(map(str, token_int)) |
| | | ibest_writer["score"][key] = str(hyp.score) |
| | | |
| | | return _forward |
| | | if text is not None: |
| | | ibest_writer["text"][key] = text |
| | | |
| | | |
| | | def get_parser(): |
| | | """Get Transducer model inference parser.""" |
| | | |
| | | parser = config_argparse.ArgumentParser( |
| | | description="ASR Decoding", |
| | | description="ASR Transducer Decoding", |
| | | formatter_class=argparse.ArgumentDefaultsHelpFormatter, |
| | | ) |
| | | |
| | | # Note(kamo): Use '_' instead of '-' as separator. |
| | | # '-' is confusing if written in yaml. |
| | | parser.add_argument( |
| | | "--log_level", |
| | | type=lambda x: x.upper(), |
| | |
| | | default=1, |
| | | help="The number of workers used for DataLoader", |
| | | ) |
| | | parser.add_argument( |
| | | "--hotword", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="hotword file path or hotwords seperated by space" |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Input data related") |
| | | group.add_argument( |
| | | "--data_path_and_name_and_type", |
| | | type=str2triple_str, |
| | | required=False, |
| | | required=True, |
| | | action="append", |
| | | ) |
| | | group.add_argument("--key_file", type=str_or_none) |
| | |
| | | help="LM parameter file", |
| | | ) |
| | | group.add_argument( |
| | | "--word_lm_train_config", |
| | | type=str, |
| | | help="Word LM training configuration", |
| | | ) |
| | | group.add_argument( |
| | | "--word_lm_file", |
| | | type=str, |
| | | help="Word LM parameter file", |
| | | ) |
| | | group.add_argument( |
| | | "--ngram_file", |
| | | type=str, |
| | | help="N-gram parameter file", |
| | | ) |
| | | group.add_argument( |
| | | "--model_tag", |
| | | type=str, |
| | | help="Pretrained model tag. If specify this option, *_train_config and " |
| | | "*_file will be overwritten", |
| | | "*_file will be overwritten", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Beam-search related") |
| | |
| | | help="The batch size for inference", |
| | | ) |
| | | group.add_argument("--nbest", type=int, default=1, help="Output N-best hypotheses") |
| | | group.add_argument("--beam_size", type=int, default=20, help="Beam size") |
| | | group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty") |
| | | group.add_argument( |
| | | "--maxlenratio", |
| | | type=float, |
| | | default=0.0, |
| | | help="Input length ratio to obtain max output length. " |
| | | "If maxlenratio=0.0 (default), it uses a end-detect " |
| | | "function " |
| | | "to automatically find maximum hypothesis lengths." |
| | | "If maxlenratio<0.0, its absolute value is interpreted" |
| | | "as a constant max output length", |
| | | ) |
| | | group.add_argument( |
| | | "--minlenratio", |
| | | type=float, |
| | | default=0.0, |
| | | help="Input length ratio to obtain min output length", |
| | | ) |
| | | group.add_argument( |
| | | "--ctc_weight", |
| | | type=float, |
| | | default=0.5, |
| | | help="CTC weight in joint decoding", |
| | | ) |
| | | group.add_argument("--beam_size", type=int, default=5, help="Beam size") |
| | | group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight") |
| | | group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight") |
| | | group.add_argument("--streaming", type=str2bool, default=False) |
| | | |
| | | group.add_argument( |
| | | "--frontend_conf", |
| | | default=None, |
| | | help="", |
| | | "--beam_search_config", |
| | | default={}, |
| | | help="The keyword arguments for transducer beam search.", |
| | | ) |
| | | group.add_argument("--raw_inputs", type=list, default=None) |
| | | # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}]) |
| | | |
| | | group = parser.add_argument_group("Text converter related") |
| | | group.add_argument( |
| | |
| | | default=None, |
| | | choices=["char", "bpe", None], |
| | | help="The token type for ASR model. " |
| | | "If not given, refers from the training args", |
| | | "If not given, refers from the training args", |
| | | ) |
| | | group.add_argument( |
| | | "--bpemodel", |
| | | type=str_or_none, |
| | | default=None, |
| | | help="The model path of sentencepiece. " |
| | | "If not given, refers from the training args", |
| | | "If not given, refers from the training args", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Dynamic quantization related") |
| | | parser.add_argument( |
| | | "--quantize_asr_model", |
| | | type=bool, |
| | | default=False, |
| | | help="Apply dynamic quantization to ASR model.", |
| | | ) |
| | | parser.add_argument( |
| | | "--quantize_modules", |
| | | nargs="*", |
| | | default=None, |
| | | help="""Module names to apply dynamic quantization on. |
| | | The module names are provided as a list, where each name is separated |
| | | by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]). |
| | | Each specified name should be an attribute of 'torch.nn', e.g.: |
| | | torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""", |
| | | ) |
| | | parser.add_argument( |
| | | "--quantize_dtype", |
| | | type=str, |
| | | default="qint8", |
| | | choices=["float16", "qint8"], |
| | | help="Dtype for dynamic quantization.", |
| | | ) |
| | | |
| | | group = parser.add_argument_group("Streaming related") |
| | | parser.add_argument( |
| | | "--streaming", |
| | | type=bool, |
| | | default=False, |
| | | help="Whether to perform chunk-by-chunk inference.", |
| | | ) |
| | | parser.add_argument( |
| | | "--simu_streaming", |
| | | type=bool, |
| | | default=False, |
| | | help="Whether to simulate chunk-by-chunk inference.", |
| | | ) |
| | | parser.add_argument( |
| | | "--chunk_size", |
| | | type=int, |
| | | default=16, |
| | | help="Number of frames in chunk AFTER subsampling.", |
| | | ) |
| | | parser.add_argument( |
| | | "--left_context", |
| | | type=int, |
| | | default=32, |
| | | help="Number of frames in left context of the chunk AFTER subsampling.", |
| | | ) |
| | | parser.add_argument( |
| | | "--right_context", |
| | | type=int, |
| | | default=0, |
| | | help="Number of frames in right context of the chunk AFTER subsampling.", |
| | | ) |
| | | parser.add_argument( |
| | | "--display_partial_hypotheses", |
| | | type=bool, |
| | | default=False, |
| | | help="Whether to display partial hypotheses during chunk-by-chunk inference.", |
| | | ) |
| | | |
| | | return parser |
| | |
| | | |
def main(cmd=None):
    """Parse command-line arguments and run Transducer decoding.

    Args:
        cmd: Optional list of argument strings. When None, argparse reads
            ``sys.argv``.
    """
    print(get_commandline_args(), file=sys.stderr)

    args = get_parser().parse_args(cmd)

    # vars() returns the namespace's own __dict__, so edits below also touch args.
    kwargs = vars(args)
    # Drop the "config" entry before forwarding (presumably not an inference()
    # argument -- verify against inference()'s signature).
    kwargs.pop("config", None)
    # Hotwords are handed to inference() inside param_dict.
    kwargs["param_dict"] = {"hotword": args.hotword}

    inference(**kwargs)


if __name__ == "__main__":
    main()
| | | |
| | | # from modelscope.pipelines import pipeline |
| | | # from modelscope.utils.constant import Tasks |
| | | # |
| | | # inference_16k_pipline = pipeline( |
| | | # task=Tasks.auto_speech_recognition, |
| | | # model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch') |
| | | # |
| | | # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav') |
| | | # print(rec_result) |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | |
| | | import os |
| | | |
| | | from funasr.tasks.asr import ASRTransducerTask |
| | | |
| | | |
| | | # for ASR Training |
def parse_args():
    """Parse training arguments for the ASR Transducer task.

    Extends the task's own parser with a ``--gpu_id`` option selecting the
    local GPU.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    task_parser = ASRTransducerTask.get_parser()
    task_parser.add_argument("--gpu_id", type=int, default=0, help="local gpu id.")
    parsed = task_parser.parse_args()
    return parsed
| | | |
| | | |
def main(args=None, cmd=None):
    """Run ASR Transducer training.

    Args:
        args: Pre-parsed argument namespace, forwarded unchanged.
        cmd: Raw command-line arguments, forwarded unchanged.
    """
    ASRTransducerTask.main(cmd=cmd, args=args)
| | | |
| | | |
if __name__ == '__main__':
    args = parse_args()

    # Expose only the requested local GPU to this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)

    # DDP settings: distributed mode is used only when more than one GPU is requested.
    args.distributed = args.ngpu > 1
    if not args.distributed and args.num_worker_count != 1:
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        raise ValueError(
            "num_worker_count must be 1 for non-distributed training, "
            f"got {args.num_worker_count}"
        )

    # re-compute batch size: when dataset type is small, the configured values
    # are per device, so scale them by the device count. CPU training
    # (ngpu == 0) counts as a single device; scaling by 0 would make the
    # batch empty.
    if args.dataset_type == "small":
        num_devices = max(args.ngpu, 1)
        if args.batch_size is not None:
            args.batch_size = args.batch_size * num_devices
        if args.batch_bins is not None:
            args.batch_bins = args.batch_bins * num_devices

    main(args=args)
| New file |
| | |
| | | """RNN decoder definition for Transducer models.""" |
| | | |
| | | from typing import List, Optional, Tuple |
| | | |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.beam_search.beam_search_transducer import Hypothesis |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | |
class RNNTDecoder(torch.nn.Module):
    """RNN decoder (prediction network) module for Transducer models.

    Args:
        vocab_size: Vocabulary size.
        embed_size: Embedding size.
        hidden_size: Hidden size.
        rnn_type: Decoder layers type ("lstm" or "gru").
        num_layers: Number of decoder layers.
        dropout_rate: Dropout rate for decoder layers.
        embed_dropout_rate: Dropout rate for embedding layer.
        embed_pad: Embedding padding symbol ID.

    Raises:
        ValueError: If ``rnn_type`` is neither "lstm" nor "gru".

    """

    def __init__(
        self,
        vocab_size: int,
        embed_size: int = 256,
        hidden_size: int = 256,
        rnn_type: str = "lstm",
        num_layers: int = 1,
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
    ) -> None:
        """Construct a RNNDecoder object."""
        super().__init__()

        assert check_argument_types()

        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")

        self.embed = torch.nn.Embedding(vocab_size, embed_size, padding_idx=embed_pad)
        self.dropout_embed = torch.nn.Dropout(p=embed_dropout_rate)

        rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU

        # One single-layer RNN per decoder layer, so that rnn_forward can apply
        # dropout after every layer and address each layer's state slice.
        self.rnn = torch.nn.ModuleList(
            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
        )

        for _ in range(1, num_layers):
            self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]

        self.dropout_rnn = torch.nn.ModuleList(
            [torch.nn.Dropout(p=dropout_rate) for _ in range(num_layers)]
        )

        self.dlayers = num_layers
        self.dtype = rnn_type

        self.output_size = hidden_size
        self.vocab_size = vocab_size

        # Device used for newly created state/label tensors; updated by set_device().
        self.device = next(self.parameters()).device
        # Memoizes score() results keyed by the label sequence. It is not
        # cleared inside this class (presumably reset by the beam search -- verify).
        self.score_cache = {}

    def forward(
        self,
        labels: torch.Tensor,
        label_lens: torch.Tensor,
        states: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None,
    ) -> torch.Tensor:
        """Encode source label sequences.

        Args:
            labels: Label ID sequences. (B, L)
            label_lens: Label ID sequences lengths. (B,) Not used here.
            states: Decoder hidden states.
                      ((N, B, D_dec), (N, B, D_dec) or None) or None

        Returns:
            dec_out: Decoder output sequences. (B, L, D_dec)

        """
        if states is None:
            states = self.init_state(labels.size(0))

        dec_embed = self.dropout_embed(self.embed(labels))
        dec_out, _ = self.rnn_forward(dec_embed, states)
        return dec_out

    def rnn_forward(
        self,
        x: torch.Tensor,
        state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """Run the stacked RNN layers over embedded label sequences.

        Args:
            x: RNN input sequences, batch-first. (B, U, D_emb)
            state: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        Returns:
            x: RNN output sequences. (B, U, D_dec)
            (h_next, c_next): Decoder hidden states.
                                ((N, B, D_dec), (N, B, D_dec) or None)

        """
        h_prev, c_prev = state
        h_next, c_next = self.init_state(x.size(0))

        for layer in range(self.dlayers):
            # Each layer writes its new state into its own slice of h_next/c_next.
            if self.dtype == "lstm":
                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
                    layer
                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
            else:
                x, h_next[layer : layer + 1] = self.rnn[layer](
                    x, hx=h_prev[layer : layer + 1]
                )

            x = self.dropout_rnn[layer](x)

        return x, (h_next, c_next)

    def score(
        self,
        label: torch.Tensor,
        label_sequence: List[int],
        dec_state: Tuple[torch.Tensor, Optional[torch.Tensor]],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypothesis.

        Results are cached per label sequence, so repeated prefixes reuse the
        stored output and state instead of recomputing them.

        Args:
            label: Previous label. (1, 1)
            label_sequence: Current label sequence.
            dec_state: Previous decoder hidden states.
                         ((N, 1, D_dec), (N, 1, D_dec) or None)

        Returns:
            dec_out: Decoder output sequence. (1, D_dec)
            dec_state: Decoder hidden states.
                         ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        str_labels = "_".join(map(str, label_sequence))

        if str_labels in self.score_cache:
            dec_out, dec_state = self.score_cache[str_labels]
        else:
            dec_embed = self.embed(label)
            dec_out, dec_state = self.rnn_forward(dec_embed, dec_state)

            self.score_cache[str_labels] = (dec_out, dec_state)

        return dec_out[0], dec_state

    def batch_score(
        self,
        hyps: List[Hypothesis],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """One-step forward hypotheses.

        Args:
            hyps: Hypotheses.

        Returns:
            dec_out: Decoder output sequences. (B, D_dec)
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        # torch.tensor (not the legacy torch.LongTensor constructor) is needed
        # here: the legacy constructor rejects non-CPU devices such as CUDA.
        labels = torch.tensor(
            [[h.yseq[-1]] for h in hyps], dtype=torch.long, device=self.device
        )
        dec_embed = self.embed(labels)

        states = self.create_batch_states([h.dec_state for h in hyps])
        dec_out, states = self.rnn_forward(dec_embed, states)

        return dec_out.squeeze(1), states

    def set_device(self, device: torch.device) -> None:
        """Set the device used for newly created decoder tensors.

        Args:
            device: Device ID.

        """
        self.device = device

    def init_state(
        self, batch_size: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Initialize decoder states with zeros.

        Args:
            batch_size: Batch size.

        Returns:
            : Initial decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
              The cell state is None for GRU layers.

        """
        h_n = torch.zeros(
            self.dlayers,
            batch_size,
            self.output_size,
            device=self.device,
        )

        if self.dtype == "lstm":
            c_n = torch.zeros(
                self.dlayers,
                batch_size,
                self.output_size,
                device=self.device,
            )

            return (h_n, c_n)

        return (h_n, None)

    def select_state(
        self, states: Tuple[torch.Tensor, Optional[torch.Tensor]], idx: int
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Get specified ID state from decoder hidden states.

        Args:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)
            idx: State ID to extract.

        Returns:
            : Decoder hidden state for given ID. ((N, 1, D_dec), (N, 1, D_dec) or None)

        """
        return (
            states[0][:, idx : idx + 1, :],
            states[1][:, idx : idx + 1, :] if self.dtype == "lstm" else None,
        )

    def create_batch_states(
        self,
        new_states: List[Tuple[torch.Tensor, Optional[torch.Tensor]]],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Concatenate per-hypothesis states along the batch dimension.

        Args:
            new_states: Decoder hidden states. [B x ((N, 1, D_dec), (N, 1, D_dec) or None)]

        Returns:
            states: Decoder hidden states. ((N, B, D_dec), (N, B, D_dec) or None)

        """
        return (
            torch.cat([s[0] for s in new_states], dim=1),
            torch.cat([s[1] for s in new_states], dim=1)
            if self.dtype == "lstm"
            else None,
        )
| New file |
| | |
| | | """ESPnet2 ASR Transducer model.""" |
| | | |
| | | import logging |
| | | from contextlib import contextmanager |
| | | from typing import Dict, List, Optional, Tuple, Union |
| | | |
| | | import torch |
| | | from packaging.version import parse as V |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.decoder.rnnt_decoder import RNNTDecoder |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | from funasr.modules.nets_utils import get_transducer_task_io |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | |
if V(torch.__version__) < V("1.6.0"):

    @contextmanager
    def autocast(enabled=True):
        """No-op stand-in for ``torch.cuda.amp.autocast`` on torch < 1.6."""
        yield

else:
    from torch.cuda.amp import autocast
| | | |
| | | |
class TransducerModel(AbsESPnetModel):
    """ESPnet2ASRTransducerModel module definition.

    Args:
        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
        token_list: List of token
        frontend: Frontend module.
        specaug: SpecAugment module.
        normalize: Normalization module.
        encoder: Encoder module.
        decoder: Decoder module.
        joint_network: Joint Network module.
        att_decoder: Attention decoder. Accepted but not used by this model.
        transducer_weight: Weight of the Transducer loss.
        fastemit_lambda: FastEmit lambda value.
        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
        ignore_id: Initial padding ID.
        sym_space: Space symbol.
        sym_blank: Blank Symbol
        report_cer: Whether to report Character Error Rate during validation.
        report_wer: Whether to report Word Error Rate during validation.
        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.

    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: Encoder,
        decoder: RNNTDecoder,
        joint_network: JointNetwork,
        att_decoder: Optional[AbsAttDecoder] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        ignore_id: int = -1,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        extract_feats_in_collect_stats: bool = True,
    ) -> None:
        """Construct an ESPnetASRTransducerModel object."""
        super().__init__()

        assert check_argument_types()

        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
        self.blank_id = 0
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        # list() copies both accepted input types; tuples have no .copy() method.
        self.token_list = list(token_list)

        self.sym_space = sym_space
        self.sym_blank = sym_blank

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize

        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network

        # Both are created lazily on first use (see _calc_transducer_loss).
        self.criterion_transducer = None
        self.error_calculator = None

        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0

        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(encoder.output_size, vocab_size)
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate

        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(decoder.output_size, vocab_size)
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing

        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda

        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight

        self.report_cer = report_cer
        self.report_wer = report_wer

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Forward architecture and compute loss(es).

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            loss: Main loss value.
            stats: Task statistics.
            weight: Task weights.

        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        batch_size = speech.shape[0]
        # Trim padding beyond the longest transcript in the batch.
        text = text[:, : text_lengths.max()]

        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )

        # 3. Decoder
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)

        # 4. Joint Network: broadcast encoder (B, T, 1, D) against decoder (B, 1, U, D).
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )

        # 5. Losses
        loss_trans, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )

        loss_ctc, loss_lm = 0.0, 0.0

        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )

        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)

        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc
            + self.auxiliary_lm_loss_weight * loss_lm
        )

        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans.detach(),
            aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None,
            aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None,
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)

        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Collect features sequences and features lengths sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".

        Return:
            {}: "feats": Features sequences. (B, T, D_feats),
                "feats_lengths": Features sequences lengths. (B,)

        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )

            feats, feats_lengths = speech, speech_lengths

        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encoder speech sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)

        Return:
            encoder_out: Encoder outputs. (B, T, D_enc)
            encoder_out_lens: Encoder outputs lengths. (B,)

        """
        # Feature extraction, augmentation and normalization run with autocast disabled.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation (training only)
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # 4. Forward encoder
        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        return encoder_out, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features sequences and features sequences lengths.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)

        Return:
            feats: Features sequences. (B, T, D_feats)
            feats_lengths: Features sequences lengths. (B,)

        """
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel: drop padding beyond the longest utterance in this shard
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats, feats_lengths = speech, speech_lengths

        return feats, feats_lengths

    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer
                (None in training mode or when reporting is disabled).
            wer_transducer: Word Error Rate for Transducer
                (None in training mode or when reporting is disabled).

        Raises:
            SystemExit: If the ``warp_rnnt`` package is not installed.

        """
        if self.criterion_transducer is None:
            try:
                from warp_rnnt import rnnt_loss as RNNTLoss
            except ImportError as err:
                logging.error(
                    "warp-rnnt was not installed. "
                    "Please consult the installation documentation."
                )
                # Same exit status as the builtin exit(1), without depending on
                # the `site` module providing `exit`.
                raise SystemExit(1) from err

            self.criterion_transducer = RNNTLoss

        # warp_rnnt's rnnt_loss is given log-probabilities here.
        log_probs = torch.log_softmax(joint_out, dim=-1)

        loss_transducer = self.criterion_transducer(
            log_probs,
            target,
            t_len,
            u_len,
            reduction="mean",
            blank=self.blank_id,
            fastemit_lambda=self.fastemit_lambda,
            gather=True,
        )

        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                # NOTE(review): imported from espnet2, not funasr -- confirm this
                # dependency is intended.
                from espnet2.asr_transducer.error_calculator import ErrorCalculator

                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )

            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)

            return loss_transducer, cer_transducer, wer_transducer

        return loss_transducer, None, None

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)

        Return:
            loss_ctc: CTC loss value, summed over the batch and divided by B.

        """
        # Functional dropout defaults to training=True; tie it to the module
        # mode so evaluation is deterministic.
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(
                encoder_out, p=self.ctc_dropout_rate, training=self.training
            )
        )
        # ctc_loss expects time-major log-probabilities: (T, B, V).
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)

        # Concatenated (1-D) targets without blank padding.
        target_mask = target != self.blank_id
        ctc_target = target[target_mask].cpu()

        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
                ctc_target,
                t_len,
                u_len,
                zero_infinity=True,
                reduction="sum",
            )
        loss_ctc /= target.size(0)

        return loss_ctc

    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the label-smoothed auxiliary LM loss on decoder outputs.

        Args:
            decoder_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)

        Return:
            loss_lm: LM loss value, summed over positions and divided by B.

        """
        # The last decoder step has no next label to predict.
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).reshape(-1, self.vocab_size)
        lm_target = target.reshape(-1).type(torch.int64)

        with torch.no_grad():
            # Smoothed target distribution: uniform mass on all labels, then the
            # true label receives (1 - smoothing).
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))

            # Blank (padding) positions are excluded from the loss below.
            ignore = lm_target == self.blank_id

            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))

        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
            reduction="none",
        )
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(0)

        return loss_lm
| | | |
class UnifiedTransducerModel(AbsESPnetModel):
    """Transducer ASR model trained on full-context and chunk encoder outputs.

    The encoder returns two output sequences for the same input (a
    full-context one and a chunk-based one) that share one lengths tensor.
    The Transducer loss and the optional auxiliary CTC and attention-decoder
    losses are computed on both outputs and summed. The optional auxiliary
    LM loss is computed once on the decoder (prediction network) outputs.

    Args:
        vocab_size: Size of complete vocabulary (w/ EOS and blank included).
        token_list: List (or tuple) of tokens.
        frontend: Frontend module.
        specaug: SpecAugment module.
        normalize: Normalization module.
        encoder: Encoder module. Must return (full_out, chunk_out, out_lens).
        decoder: Decoder module.
        joint_network: Joint Network module.
        att_decoder: Attention decoder used by the auxiliary attention loss.
        transducer_weight: Weight of the Transducer loss.
        fastemit_lambda: FastEmit lambda value.
        auxiliary_ctc_weight: Weight of auxiliary CTC loss.
        auxiliary_att_weight: Weight of auxiliary attention-decoder loss.
        auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs.
        auxiliary_lm_loss_weight: Weight of auxiliary LM loss.
        auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing.
        ignore_id: Initial padding ID.
        sym_space: Space symbol.
        sym_blank: Blank Symbol.
        report_cer: Whether to report Character Error Rate during validation.
        report_wer: Whether to report Word Error Rate during validation.
        sym_sos: Start-of-sentence symbol (falls back to vocab_size - 1).
        sym_eos: End-of-sentence symbol (falls back to vocab_size - 1).
        extract_feats_in_collect_stats: Whether to use extract_feats stats collection.
        lsm_weight: Label smoothing weight of the attention-decoder loss.
        length_normalized_loss: Whether to length-normalize the attention loss.
    """

    def __init__(
        self,
        vocab_size: int,
        token_list: Union[Tuple[str, ...], List[str]],
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: Encoder,
        decoder: RNNTDecoder,
        joint_network: JointNetwork,
        att_decoder: Optional[AbsAttDecoder] = None,
        transducer_weight: float = 1.0,
        fastemit_lambda: float = 0.0,
        auxiliary_ctc_weight: float = 0.0,
        auxiliary_att_weight: float = 0.0,
        auxiliary_ctc_dropout_rate: float = 0.0,
        auxiliary_lm_loss_weight: float = 0.0,
        auxiliary_lm_loss_smoothing: float = 0.0,
        ignore_id: int = -1,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        report_cer: bool = True,
        report_wer: bool = True,
        sym_sos: str = "<sos/eos>",
        sym_eos: str = "<sos/eos>",
        extract_feats_in_collect_stats: bool = True,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
    ) -> None:
        """Construct an UnifiedTransducerModel object."""
        super().__init__()

        assert check_argument_types()

        # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos)
        self.blank_id = 0

        if sym_sos in token_list:
            self.sos = token_list.index(sym_sos)
        else:
            self.sos = vocab_size - 1
        if sym_eos in token_list:
            self.eos = token_list.index(sym_eos)
        else:
            self.eos = vocab_size - 1

        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        # list() works for both accepted input types; tuples have no .copy().
        self.token_list = list(token_list)

        self.sym_space = sym_space
        self.sym_blank = sym_blank

        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize

        self.encoder = encoder
        self.decoder = decoder
        self.joint_network = joint_network

        # Both are created lazily on first use.
        self.criterion_transducer = None
        self.error_calculator = None

        self.use_auxiliary_ctc = auxiliary_ctc_weight > 0
        self.use_auxiliary_att = auxiliary_att_weight > 0
        self.use_auxiliary_lm_loss = auxiliary_lm_loss_weight > 0

        if self.use_auxiliary_ctc:
            self.ctc_lin = torch.nn.Linear(
                self._module_output_size(encoder), vocab_size
            )
            self.ctc_dropout_rate = auxiliary_ctc_dropout_rate

        if self.use_auxiliary_att:
            self.att_decoder = att_decoder

            self.criterion_att = LabelSmoothingLoss(
                size=vocab_size,
                padding_idx=ignore_id,
                smoothing=lsm_weight,
                normalize_length=length_normalized_loss,
            )

        if self.use_auxiliary_lm_loss:
            self.lm_lin = torch.nn.Linear(
                self._module_output_size(decoder), vocab_size
            )
            self.lm_loss_smoothing = auxiliary_lm_loss_smoothing

        self.transducer_weight = transducer_weight
        self.fastemit_lambda = fastemit_lambda

        self.auxiliary_ctc_weight = auxiliary_ctc_weight
        self.auxiliary_att_weight = auxiliary_att_weight
        self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight

        self.report_cer = report_cer
        self.report_wer = report_wer

        self.extract_feats_in_collect_stats = extract_feats_in_collect_stats

    @staticmethod
    def _module_output_size(module) -> int:
        """Return ``module.output_size`` whether it is an attribute or a method.

        ConformerChunkEncoder exposes ``output_size()`` as a method, while other
        modules may store it as a plain integer attribute.
        """
        size = module.output_size
        return size() if callable(size) else size

    @staticmethod
    def _detach_if_positive(value):
        """Return ``value.detach()`` for a positive loss, else ``None`` (for stats)."""
        return value.detach() if value > 0.0 else None

    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Forward architecture and compute loss(es).

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".
        Return:
            loss: Main loss value.
            stats: Task statistics.
            weight: Task weights.
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        assert (
            speech.shape[0]
            == speech_lengths.shape[0]
            == text.shape[0]
            == text_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)

        batch_size = speech.shape[0]
        text = text[:, : text_lengths.max()]

        # 1. Encoder: full-context and chunk outputs share one lengths tensor.
        encoder_out, encoder_out_chunk, encoder_out_lens = self.encode(
            speech, speech_lengths
        )

        loss_att, loss_att_chunk = 0.0, 0.0

        if self.use_auxiliary_att:
            loss_att, _ = self._calc_att_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            loss_att_chunk, _ = self._calc_att_loss(
                encoder_out_chunk, encoder_out_lens, text, text_lengths
            )

        # 2. Transducer-related I/O preparation
        decoder_in, target, t_len, u_len = get_transducer_task_io(
            text,
            encoder_out_lens,
            ignore_id=self.ignore_id,
        )

        # 3. Decoder (shared by both encoder outputs)
        self.decoder.set_device(encoder_out.device)
        decoder_out = self.decoder(decoder_in, u_len)

        # 4. Joint Network, once per encoder output
        joint_out = self.joint_network(
            encoder_out.unsqueeze(2), decoder_out.unsqueeze(1)
        )
        joint_out_chunk = self.joint_network(
            encoder_out_chunk.unsqueeze(2), decoder_out.unsqueeze(1)
        )

        # 5. Losses
        loss_trans_utt, cer_trans, wer_trans = self._calc_transducer_loss(
            encoder_out,
            joint_out,
            target,
            t_len,
            u_len,
        )
        loss_trans_chunk, cer_trans_chunk, wer_trans_chunk = self._calc_transducer_loss(
            encoder_out_chunk,
            joint_out_chunk,
            target,
            t_len,
            u_len,
        )

        loss_ctc, loss_ctc_chunk, loss_lm = 0.0, 0.0, 0.0

        if self.use_auxiliary_ctc:
            loss_ctc = self._calc_ctc_loss(
                encoder_out,
                target,
                t_len,
                u_len,
            )
            loss_ctc_chunk = self._calc_ctc_loss(
                encoder_out_chunk,
                target,
                t_len,
                u_len,
            )

        if self.use_auxiliary_lm_loss:
            loss_lm = self._calc_lm_loss(decoder_out, target)

        # Sum each loss over both encoder outputs. The per-output values are
        # kept untouched so they can be reported separately below.
        # (Previously the CTC sum was overwritten by the attention sum.)
        loss_trans = loss_trans_utt + loss_trans_chunk
        loss_ctc_total = loss_ctc + loss_ctc_chunk
        loss_att_total = loss_att + loss_att_chunk

        loss = (
            self.transducer_weight * loss_trans
            + self.auxiliary_ctc_weight * loss_ctc_total
            + self.auxiliary_att_weight * loss_att_total
            + self.auxiliary_lm_loss_weight * loss_lm
        )

        stats = dict(
            loss=loss.detach(),
            loss_transducer=loss_trans_utt.detach(),
            loss_transducer_chunk=loss_trans_chunk.detach(),
            aux_ctc_loss=self._detach_if_positive(loss_ctc),
            aux_ctc_loss_chunk=self._detach_if_positive(loss_ctc_chunk),
            aux_att_loss=self._detach_if_positive(loss_att),
            aux_att_loss_chunk=self._detach_if_positive(loss_att_chunk),
            aux_lm_loss=self._detach_if_positive(loss_lm),
            cer_transducer=cer_trans,
            wer_transducer=wer_trans,
            cer_transducer_chunk=cer_trans_chunk,
            wer_transducer_chunk=wer_trans_chunk,
        )

        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight

    def collect_feats(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Dict[str, torch.Tensor]:
        """Collect features sequences and features lengths sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
            text: Label ID sequences. (B, L)
            text_lengths: Label ID sequences lengths. (B,)
            kwargs: Contains "utts_id".
        Return:
            {}: "feats": Features sequences. (B, T, D_feats),
                "feats_lengths": Features sequences lengths. (B,)
        """
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )

            feats, feats_lengths = speech, speech_lengths

        return {"feats": feats, "feats_lengths": feats_lengths}

    def encode(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encoder speech sequences.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
        Return:
            encoder_out: Full-context encoder outputs. (B, T, D_enc)
            encoder_out_chunk: Chunk-based encoder outputs (same lengths).
            encoder_out_lens: Encoder outputs lengths. (B,)
        """
        # Feature extraction is kept out of mixed precision.
        with autocast(False):
            # 1. Extract feats
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)

            # 2. Data augmentation
            if self.specaug is not None and self.training:
                feats, feats_lengths = self.specaug(feats, feats_lengths)

            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)

        # 4. Forward encoder
        encoder_out, encoder_out_chunk, encoder_out_lens = self.encoder(
            feats, feats_lengths
        )

        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
            speech.size(0),
        )
        assert encoder_out.size(1) <= encoder_out_lens.max(), (
            encoder_out.size(),
            encoder_out_lens.max(),
        )

        return encoder_out, encoder_out_chunk, encoder_out_lens

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features sequences and features sequences lengths.

        Args:
            speech: Speech sequences. (B, S)
            speech_lengths: Speech sequences lengths. (B,)
        Return:
            feats: Features sequences. (B, T, D_feats)
            feats_lengths: Features sequences lengths. (B,)
        """
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            feats, feats_lengths = speech, speech_lengths

        return feats, feats_lengths

    def _calc_transducer_loss(
        self,
        encoder_out: torch.Tensor,
        joint_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]:
        """Compute Transducer loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            joint_out: Joint Network output sequences (B, T, U, D_joint)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)
        Return:
            loss_transducer: Transducer loss value.
            cer_transducer: Character error rate for Transducer (eval only).
            wer_transducer: Word Error Rate for Transducer (eval only).
        """
        if self.criterion_transducer is None:
            try:
                from warp_rnnt import rnnt_loss as RNNTLoss

                self.criterion_transducer = RNNTLoss
            except ImportError:
                logging.error(
                    "warp-rnnt was not installed. "
                    "Please consult the installation documentation."
                )
                raise SystemExit(1)

        # warp_rnnt expects log-probabilities, not raw joint logits.
        log_probs = torch.log_softmax(joint_out, dim=-1)

        loss_transducer = self.criterion_transducer(
            log_probs,
            target,
            t_len,
            u_len,
            reduction="mean",
            blank=self.blank_id,
            fastemit_lambda=self.fastemit_lambda,
            gather=True,
        )

        if not self.training and (self.report_cer or self.report_wer):
            if self.error_calculator is None:
                self.error_calculator = ErrorCalculator(
                    self.decoder,
                    self.joint_network,
                    self.token_list,
                    self.sym_space,
                    self.sym_blank,
                    report_cer=self.report_cer,
                    report_wer=self.report_wer,
                )

            cer_transducer, wer_transducer = self.error_calculator(encoder_out, target)
            return loss_transducer, cer_transducer, wer_transducer

        return loss_transducer, None, None

    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        target: torch.Tensor,
        t_len: torch.Tensor,
        u_len: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)
            t_len: Encoder output sequences lengths. (B,)
            u_len: Target label ID sequences lengths. (B,)
        Return:
            loss_ctc: CTC loss value (summed, divided by batch size).
        """
        ctc_in = self.ctc_lin(
            torch.nn.functional.dropout(encoder_out, p=self.ctc_dropout_rate)
        )
        # ctc_loss expects (T, B, C) log-probabilities.
        ctc_in = torch.log_softmax(ctc_in.transpose(0, 1), dim=-1)

        # Drop blank padding: ctc_loss takes concatenated targets here.
        target_mask = target != 0
        ctc_target = target[target_mask].cpu()

        with torch.backends.cudnn.flags(deterministic=True):
            loss_ctc = torch.nn.functional.ctc_loss(
                ctc_in,
                ctc_target,
                t_len,
                u_len,
                zero_infinity=True,
                reduction="sum",
            )
        loss_ctc /= target.size(0)

        return loss_ctc

    def _calc_lm_loss(
        self,
        decoder_out: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Compute LM loss.

        Args:
            decoder_out: Decoder output sequences. (B, U, D_dec)
            target: Target label ID sequences. (B, L)
        Return:
            loss_lm: LM loss value (summed over non-blank labels, / batch size).
        """
        lm_loss_in = self.lm_lin(decoder_out[:, :-1, :]).view(-1, self.vocab_size)
        lm_target = target.view(-1).type(torch.int64)

        with torch.no_grad():
            true_dist = lm_loss_in.clone()
            true_dist.fill_(self.lm_loss_smoothing / (self.vocab_size - 1))

            # Ignore blank ID (0)
            ignore = lm_target == 0

            true_dist.scatter_(1, lm_target.unsqueeze(1), (1 - self.lm_loss_smoothing))

        loss_lm = torch.nn.functional.kl_div(
            torch.log_softmax(lm_loss_in, dim=1),
            true_dist,
            reduction="none",
        )
        loss_lm = loss_lm.masked_fill(ignore.unsqueeze(1), 0).sum() / decoder_out.size(0)

        return loss_lm

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        """Compute the auxiliary attention-decoder loss and accuracy.

        ``ys_pad_lens`` is not modified: forward() calls this method twice
        with the same lengths tensor.
        """
        if hasattr(self, "lang_token_id") and self.lang_token_id is not None:
            ys_pad = torch.cat(
                [
                    self.lang_token_id.repeat(ys_pad.size(0), 1).to(ys_pad.device),
                    ys_pad,
                ],
                dim=1,
            )
            # Out-of-place: an in-place += would corrupt the caller's tensor.
            ys_pad_lens = ys_pad_lens + 1

        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1

        # 1. Forward decoder
        decoder_out, _ = self.att_decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )

        return loss_att, acc_att
| | |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import Dict |
| | | |
| | | import torch |
| | | from torch import nn |
| | |
| | | from funasr.modules.attention import ( |
| | | MultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttentionChunk, |
| | | LegacyRelPositionMultiHeadedAttention, # noqa: H301 |
| | | ) |
| | | from funasr.modules.embedding import ( |
| | |
| | | ScaledPositionalEncoding, # noqa: H301 |
| | | RelPositionalEncoding, # noqa: H301 |
| | | LegacyRelPositionalEncoding, # noqa: H301 |
| | | StreamingRelPositionalEncoding, |
| | | ) |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.multi_layer_conv import Conv1dLinear |
| | | from funasr.modules.multi_layer_conv import MultiLayeredConv1d |
| | | from funasr.modules.nets_utils import get_activation |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.modules.nets_utils import ( |
| | | TooShortUttError, |
| | | check_short_utt, |
| | | make_chunk_mask, |
| | | make_source_mask, |
| | | ) |
| | | from funasr.modules.positionwise_feed_forward import ( |
| | | PositionwiseFeedForward, # noqa: H301 |
| | | ) |
| | | from funasr.modules.repeat import repeat |
| | | from funasr.modules.repeat import repeat, MultiBlocks |
| | | from funasr.modules.subsampling import Conv2dSubsampling |
| | | from funasr.modules.subsampling import Conv2dSubsampling2 |
| | | from funasr.modules.subsampling import Conv2dSubsampling6 |
| | |
| | | from funasr.modules.subsampling import TooShortUttError |
| | | from funasr.modules.subsampling import check_short_utt |
| | | from funasr.modules.subsampling import Conv2dSubsamplingPad |
| | | from funasr.modules.subsampling import StreamingConvInput |
| | | |
| | | class ConvolutionModule(nn.Module): |
| | | """ConvolutionModule in Conformer model. |
| | | |
| | |
| | | return (x, pos_emb), mask |
| | | |
| | | return x, mask |
| | | |
class ChunkEncoderLayer(torch.nn.Module):
    """Chunk Conformer module definition.

    Macaron feed-forward -> self-attention -> convolution -> feed-forward,
    each with a pre-norm and residual connection, followed by a final norm.

    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance.
        norm_class: Normalization module class.
        norm_args: Normalization module arguments. ``None`` (default) means no
            extra arguments.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_args: Optional[Dict] = None,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a Conformer object."""
        super().__init__()

        # Avoid a shared mutable default dict.
        if norm_args is None:
            norm_args = {}

        self.self_att = self_att

        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        self.feed_forward_scale = 0.5

        self.conv_mod = conv_mod

        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)

        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)

        self.dropout = torch.nn.Dropout(dropout_rate)

        self.block_size = block_size
        # [attention key/value history, convolution history]; see reset_streaming_cache.
        self.cache = None

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        self.cache = [
            # Self-attention history: (1, left_context, block_size).
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            # Convolution history: (1, block_size, kernel_size - 1).
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)
        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(
            self.feed_forward_macaron(x)
        )

        residual = x
        x = self.norm_self_att(x)
        x = residual + self.dropout(
            self.self_att(
                x,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )

        residual = x
        x = self.norm_conv(x)
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)

        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)
        return x, mask, pos_enc

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk of input sequence, updating ``self.cache``.

        If the streaming cache was never initialized, it is initialized with
        zeros first (same as calling ``reset_streaming_cache``).

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            chunk_size: Unused here; kept for interface compatibility.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.
        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        if self.cache is None:
            self.reset_streaming_cache(left_context, x.device)

        residual = x
        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)

        residual = x
        x = self.norm_self_att(x)
        # Keys/values are the cached left context followed by the current chunk.
        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key

        # Keep the last `left_context` frames before the right context for the
        # next chunk. With left_context == 0 nothing is kept (a `-0:` slice
        # would select the whole sequence instead).
        if right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        elif left_context > 0:
            att_cache = key[:, -left_context:, :]
        else:
            att_cache = key[:, :0, :]

        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )

        residual = x
        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(
            x, cache=self.cache[1], right_context=right_context
        )
        x = residual + x

        residual = x
        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)

        x = self.norm_final(x)
        self.cache = [att_cache, conv_cache]

        return x, pos_enc
| | | |
| | | |
| | | class ConformerEncoder(AbsEncoder): |
| | |
| | | if len(intermediate_outs) > 0: |
| | | return (xs_pad, intermediate_outs), olens, None |
| | | return xs_pad, olens, None |
| | | |
| | | |
class CausalConvolution(torch.nn.Module):
    """ConformerConvolution module definition.

    pointwise conv (C -> 2C) -> GLU -> depthwise conv -> BatchNorm ->
    activation -> pointwise conv (C -> C).

    Args:
        channels: The number of channels.
        kernel_size: Size of the depthwise convolving kernel (must be odd).
        activation: Activation module. ``None`` (default) creates a new
            ``torch.nn.ReLU`` for this instance.
        norm_args: Keyword arguments for ``torch.nn.BatchNorm1d``. ``None``
            (default) means no extra arguments.
        causal: Whether to use causal convolution (set to True if streaming).
            When True, the depthwise convolution is left-padded with
            ``kernel_size - 1`` frames (zeros or the given cache) instead of
            being padded symmetrically.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        activation: Optional[torch.nn.Module] = None,
        norm_args: Optional[Dict] = None,
        causal: bool = False,
    ) -> None:
        """Construct an ConformerConvolution object."""
        super().__init__()

        assert (kernel_size - 1) % 2 == 0

        # Fresh objects per instance instead of shared mutable defaults.
        if activation is None:
            activation = torch.nn.ReLU()
        if norm_args is None:
            norm_args = {}

        self.kernel_size = kernel_size

        self.pointwise_conv1 = torch.nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        if causal:
            # Left context only: padding is added manually in forward().
            self.lorder = kernel_size - 1
            padding = 0
        else:
            self.lorder = 0
            padding = (kernel_size - 1) // 2

        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
        )
        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute convolution module.

        Args:
            x: ConformerConvolution input sequences. (B, T, D_hidden)
            cache: Left-context history for causal mode. (B, D_hidden, kernel_size - 1)
            right_context: Number of frames in right context.
        Returns:
            x: ConformerConvolution output sequences. (B, T, D_hidden)
            cache: Updated history (B, D_hidden, kernel_size - 1) in causal
                mode; the input ``cache`` unchanged otherwise.
        """
        x = self.pointwise_conv1(x.transpose(1, 2))
        x = torch.nn.functional.glu(x, dim=1)

        if self.lorder > 0:
            if cache is None:
                # No history yet: zero left-padding keeps the output length T.
                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                x = torch.cat([cache, x], dim=2)

            # History for the next chunk: the last `lorder` frames that
            # precede the right-context frames.
            if right_context > 0:
                cache = x[:, :, -(self.lorder + right_context) : -right_context]
            else:
                cache = x[:, :, -self.lorder :]

        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x).transpose(1, 2)

        return x, cache
| | | |
| | | class ConformerChunkEncoder(AbsEncoder): |
| | | """Encoder module definition. |
| | | Args: |
| | | input_size: Input size. |
| | | body_conf: Encoder body configuration. |
| | | input_conf: Encoder input configuration. |
| | | main_conf: Encoder main configuration. |
| | | """ |
| | | |
def __init__(
    self,
    input_size: int,
    output_size: int = 256,
    attention_heads: int = 4,
    linear_units: int = 2048,
    num_blocks: int = 6,
    dropout_rate: float = 0.1,
    positional_dropout_rate: float = 0.1,
    attention_dropout_rate: float = 0.0,
    embed_vgg_like: bool = False,
    normalize_before: bool = True,
    concat_after: bool = False,
    positionwise_layer_type: str = "linear",
    positionwise_conv_kernel_size: int = 3,
    macaron_style: bool = False,
    rel_pos_type: str = "legacy",
    pos_enc_layer_type: str = "rel_pos",
    selfattention_layer_type: str = "rel_selfattn",
    activation_type: str = "swish",
    use_cnn_module: bool = True,
    zero_triu: bool = False,
    norm_type: str = "layer_norm",
    cnn_module_kernel: int = 31,
    conv_mod_norm_eps: float = 0.00001,
    conv_mod_norm_momentum: float = 0.1,
    simplified_att_score: bool = False,
    dynamic_chunk_training: bool = False,
    short_chunk_threshold: float = 0.75,
    short_chunk_size: int = 25,
    left_chunk_size: int = 0,
    time_reduction_factor: int = 1,
    unified_model_training: bool = False,
    default_chunk_size: int = 16,
    jitter_range: int = 4,
    subsampling_factor: int = 1,
) -> None:
    """Construct an Encoder object.

    Builds, in this order:
      1. ``self.embed``: a StreamingConvInput input layer (subsampling by
         ``subsampling_factor``).
      2. ``self.pos_enc``: a StreamingRelPositionalEncoding.
      3. ``self.encoders``: ``num_blocks`` ChunkEncoderLayer blocks (each with
         chunk self-attention, two feed-forward modules and a
         CausalConvolution) wrapped in MultiBlocks.
    Chunking hyper-parameters are stored as attributes.

    The following arguments are accepted but not read by this constructor:
    ``normalize_before``, ``concat_after``, ``positionwise_layer_type``,
    ``positionwise_conv_kernel_size``, ``macaron_style``, ``rel_pos_type``,
    ``pos_enc_layer_type``, ``selfattention_layer_type``, ``use_cnn_module``,
    ``zero_triu``, ``norm_type`` (presumably kept for config compatibility
    with other encoders - TODO confirm).
    """
    super().__init__()

    assert check_argument_types()

    # NOTE(review): output_size is passed both as the 2nd positional argument
    # and as the `output_size` keyword; this assumes StreamingConvInput's 2nd
    # parameter has a different name (e.g. a conv size) - confirm.
    self.embed = StreamingConvInput(
        input_size,
        output_size,
        subsampling_factor,
        vgg_like=embed_vgg_like,
        output_size=output_size,
    )

    self.pos_enc = StreamingRelPositionalEncoding(
        output_size,
        positional_dropout_rate,
    )

    # A single activation instance is shared by every feed-forward and
    # convolution module built below.
    activation = get_activation(
        activation_type
    )

    # NOTE(review): the feed-forward modules receive positional_dropout_rate,
    # not dropout_rate, as their dropout - confirm this is intended.
    pos_wise_args = (
        output_size,
        linear_units,
        positional_dropout_rate,
        activation,
    )

    # BatchNorm1d arguments for the convolution modules.
    conv_mod_norm_args = {
        "eps": conv_mod_norm_eps,
        "momentum": conv_mod_norm_momentum,
    }

    # Positional args of CausalConvolution: channels, kernel_size, activation,
    # norm_args, causal. Convolution is causal whenever chunk-based training
    # (dynamic or unified) is enabled.
    conv_mod_args = (
        output_size,
        cnn_module_kernel,
        activation,
        conv_mod_norm_args,
        dynamic_chunk_training or unified_model_training,
    )

    mult_att_args = (
        attention_heads,
        output_size,
        attention_dropout_rate,
        simplified_att_score,
    )


    # One factory per block; each call builds fresh sub-modules. The lambda
    # does not reference the loop variable, so late binding is not an issue.
    # Modules are instantiated in block order by the comprehension below.
    fn_modules = []
    for _ in range(num_blocks):
        module = lambda: ChunkEncoderLayer(
            output_size,
            RelPositionMultiHeadedAttentionChunk(*mult_att_args),
            PositionwiseFeedForward(*pos_wise_args),
            PositionwiseFeedForward(*pos_wise_args),
            CausalConvolution(*conv_mod_args),
            dropout_rate=dropout_rate,
        )
        fn_modules.append(module)

    self.encoders = MultiBlocks(
        [fn() for fn in fn_modules],
        output_size,
    )

    self._output_size = output_size

    # Chunking configuration (read elsewhere, e.g. by forward()).
    self.dynamic_chunk_training = dynamic_chunk_training
    self.short_chunk_threshold = short_chunk_threshold
    self.short_chunk_size = short_chunk_size
    self.left_chunk_size = left_chunk_size

    self.unified_model_training = unified_model_training
    self.default_chunk_size = default_chunk_size
    self.jitter_range = jitter_range

    self.time_reduction_factor = time_reduction_factor
| | | |
| | | def output_size(self) -> int: |
| | | return self._output_size |
| | | |
| | | def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: |
| | | """Return the corresponding number of sample for a given chunk size, in frames. |
| | | Where size is the number of features frames after applying subsampling. |
| | | Args: |
| | | size: Number of frames after subsampling. |
| | | hop_length: Frontend's hop length |
| | | Returns: |
| | | : Number of raw samples |
| | | """ |
| | | return self.embed.get_size_before_subsampling(size) * hop_length |
| | | |
| | | def get_encoder_input_size(self, size: int) -> int: |
| | | """Return the corresponding number of sample for a given chunk size, in frames. |
| | | Where size is the number of features frames after applying subsampling. |
| | | Args: |
| | | size: Number of frames after subsampling. |
| | | Returns: |
| | | : Number of raw samples |
| | | """ |
| | | return self.embed.get_size_before_subsampling(size) |
| | | |
| | | |
| | | def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: |
| | | """Initialize/Reset encoder streaming cache. |
| | | Args: |
| | | left_context: Number of frames in left context. |
| | | device: Device ID. |
| | | """ |
| | | return self.encoders.reset_streaming_cache(left_context, device) |
| | | |
    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        The behaviour depends on the flags stored on the encoder:

        * ``unified_model_training``: the input is embedded once and then
          encoded twice, once with full context (no chunk mask) and once with
          a chunk mask whose chunk size is ``default_chunk_size`` shifted by a
          random integer in ``[-jitter_range, jitter_range]``.
        * ``dynamic_chunk_training``: a random chunk size is drawn on every call
          and the input is encoded once with the matching chunk mask.
        * otherwise: the input is encoded once without a chunk mask.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
            x_len: Encoder outputs lengths. (B,)
            In unified training mode three values are returned instead:
            ``(x_utt, x_chunk, olens)`` -- full-context outputs, chunked
            outputs and output lengths. The return annotation does not
            reflect this case.

        Raises:
            TooShortUttError: If ``check_short_utt`` reports that the input has
                too few frames for the embedding's subsampling.

        """
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

        mask = make_source_mask(x_len)

        if self.unified_model_training:
            # Jittered chunk size: default_chunk_size + randint in [-jitter, jitter].
            chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
            # Same embedded input encoded twice: full context, then chunked.
            x_utt = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=None,
            )
            x_chunk = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )

            # Count zero (unmasked) positions per sequence.
            # NOTE(review): assumes make_source_mask marks padding as True/1 -- confirm.
            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                # Keep every time_reduction_factor-th frame; (n - 1) // r + 1 equals
                # ceil(n / r), the length of the strided slice.
                x_utt = x_utt[:,::self.time_reduction_factor,:]
                x_chunk = x_chunk[:,::self.time_reduction_factor,:]
                olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1

            return x_utt, x_chunk, olens

        elif self.dynamic_chunk_training:
            max_len = x.size(1)
            # Draw a chunk size in [1, max_len).
            chunk_size = torch.randint(1, max_len, (1,)).item()

            # Large draws use the whole utterance as a single chunk; small draws
            # are folded into the range [1, short_chunk_size].
            if chunk_size > (max_len * self.short_chunk_threshold):
                chunk_size = max_len
            else:
                chunk_size = (chunk_size % self.short_chunk_size) + 1

            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)

            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
        else:
            # Full-context encoding: no chunk size and no chunk mask.
            x, mask = self.embed(x, mask, None)
            pos_enc = self.pos_enc(x)
            chunk_mask = None
        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            # Same frame decimation and length update as in the unified branch.
            x = x[:,::self.time_reduction_factor,:]
            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1

        return x, olens
| | | |
| | | def simu_chunk_forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | x_len: torch.Tensor, |
| | | chunk_size: int = 16, |
| | | left_context: int = 32, |
| | | right_context: int = 0, |
| | | ) -> torch.Tensor: |
| | | short_status, limit_size = check_short_utt( |
| | | self.embed.subsampling_factor, x.size(1) |
| | | ) |
| | | |
| | | if short_status: |
| | | raise TooShortUttError( |
| | | f"has {x.size(1)} frames and is too short for subsampling " |
| | | + f"(it needs more than {limit_size} frames), return empty results", |
| | | x.size(1), |
| | | limit_size, |
| | | ) |
| | | |
| | | mask = make_source_mask(x_len) |
| | | |
| | | x, mask = self.embed(x, mask, chunk_size) |
| | | pos_enc = self.pos_enc(x) |
| | | chunk_mask = make_chunk_mask( |
| | | x.size(1), |
| | | chunk_size, |
| | | left_chunk_size=self.left_chunk_size, |
| | | device=x.device, |
| | | ) |
| | | |
| | | x = self.encoders( |
| | | x, |
| | | pos_enc, |
| | | mask, |
| | | chunk_mask=chunk_mask, |
| | | ) |
| | | olens = mask.eq(0).sum(1) |
| | | if self.time_reduction_factor > 1: |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | |
| | | return x |
| | | |
| | | def chunk_forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | x_len: torch.Tensor, |
| | | processed_frames: torch.tensor, |
| | | chunk_size: int = 16, |
| | | left_context: int = 32, |
| | | right_context: int = 0, |
| | | ) -> torch.Tensor: |
| | | """Encode input sequences as chunks. |
| | | Args: |
| | | x: Encoder input features. (1, T_in, F) |
| | | x_len: Encoder input features lengths. (1,) |
| | | processed_frames: Number of frames already seen. |
| | | left_context: Number of frames in left context. |
| | | right_context: Number of frames in right context. |
| | | Returns: |
| | | x: Encoder outputs. (B, T_out, D_enc) |
| | | """ |
| | | mask = make_source_mask(x_len) |
| | | x, mask = self.embed(x, mask, None) |
| | | |
| | | if left_context > 0: |
| | | processed_mask = ( |
| | | torch.arange(left_context, device=x.device) |
| | | .view(1, left_context) |
| | | .flip(1) |
| | | ) |
| | | processed_mask = processed_mask >= processed_frames |
| | | mask = torch.cat([processed_mask, mask], dim=1) |
| | | pos_enc = self.pos_enc(x, left_context=left_context) |
| | | x = self.encoders.chunk_forward( |
| | | x, |
| | | pos_enc, |
| | | mask, |
| | | chunk_size=chunk_size, |
| | | left_context=left_context, |
| | | right_context=right_context, |
| | | ) |
| | | |
| | | if right_context > 0: |
| | | x = x[:, 0:-right_context, :] |
| | | |
| | | if self.time_reduction_factor > 1: |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | return x |
| New file |
| | |
| | | """Transducer joint network implementation.""" |
| | | |
| | | import torch |
| | | |
| | | from funasr.modules.nets_utils import get_activation |
| | | |
| | | |
class JointNetwork(torch.nn.Module):
    """Transducer joint network module.

    Computes ``lin_out(act(lin_enc(enc_out) + lin_dec(dec_out)))``.

    Args:
        output_size: Output size.
        encoder_size: Encoder output size.
        decoder_size: Decoder output size.
        joint_space_size: Joint space size.
        joint_activation_type: Name of the activation function, resolved with
            ``get_activation``.

    """

    def __init__(
        self,
        output_size: int,
        encoder_size: int,
        decoder_size: int,
        joint_space_size: int = 256,
        joint_activation_type: str = "tanh",
    ) -> None:
        """Construct a JointNetwork object."""
        super().__init__()

        self.lin_enc = torch.nn.Linear(encoder_size, joint_space_size)
        # The encoder projection already carries a bias; the decoder one does not.
        self.lin_dec = torch.nn.Linear(decoder_size, joint_space_size, bias=False)

        self.lin_out = torch.nn.Linear(joint_space_size, output_size)

        self.joint_activation = get_activation(
            joint_activation_type
        )

    def forward(
        self,
        enc_out: torch.Tensor,
        dec_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        """Joint computation of encoder and decoder hidden state sequences.

        Args:
            enc_out: Encoder output states, e.g. (B, T, 1, D_enc). Must
                broadcast against ``dec_out`` after projection.
            dec_out: Decoder output states, e.g. (B, 1, U, D_dec).
            project_input: If True, project both inputs into the joint space
                with ``lin_enc``/``lin_dec`` before adding them. If False, both
                inputs must already have the last dimension
                ``joint_space_size`` and are added directly.

        Returns:
            joint_out: Joint output state sequences, e.g. (B, T, U, D_out).

        """
        if project_input:
            joint_out = self.joint_activation(self.lin_enc(enc_out) + self.lin_dec(dec_out))
        else:
            joint_out = self.joint_activation(enc_out + dec_out)
        return self.lin_out(joint_out)
| | |
| | | import numpy |
| | | import torch |
| | | from torch import nn |
| | | |
| | | from typing import Optional, Tuple |
| | | |
| | | class MultiHeadedAttention(nn.Module): |
| | | """Multi-Head Attention layer. |
| | |
| | | scores = torch.matmul(q_h, k_h.transpose(-2, -1)) |
| | | att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) |
| | | return att_outs |
| | | |
class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
    """Multi-head attention with relative positional encoding and chunk masks.

    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size. Must be divisible by ``num_heads``.
        dropout_rate: Dropout rate applied to the attention weights.
        simplified_attention_score: If True, use the simplified score from
            https://github.com/k2-fsa/icefall/pull/458 (positional projection
            to one value per head, no ``pos_bias_u``/``pos_bias_v``).
    """

    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct a RelPositionMultiHeadedAttentionChunk object."""
        super().__init__()

        self.d_k = embed_size // num_heads
        self.num_heads = num_heads

        # %-format the message so the failure reports the actual sizes
        # instead of a raw (format, args) tuple.
        assert self.d_k * num_heads == embed_size, (
            "embed_size (%d) must be divisible by num_heads (%d)"
            % (embed_size, num_heads)
        )

        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)

        self.linear_out = torch.nn.Linear(embed_size, embed_size)

        if simplified_attention_score:
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)

            self.compute_att_score = self.compute_simplified_attention_score
        else:
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)

            # Uninitialized storage, filled immediately by xavier_uniform_.
            self.pos_bias_u = torch.nn.Parameter(torch.empty(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.empty(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)

            self.compute_att_score = self.compute_attention_score

        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # Last attention weights computed by forward_attention.
        self.attn = None

    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Compute relative positional encoding.

        Re-indexes the last axis through a strided view (no copy) so that each
        query row lines up with its relative positions.

        Args:
            x: Input sequence. (B, H, T_1, 2 * T_1 - 1) when left_context is 0.
                NOTE(review): with left context the last axis presumably grows
                by left_context -- confirm against the positional encoding.
            left_context: Number of frames in left context.

        Returns:
            x: Output sequence. (B, H, T_1, T_2) with T_2 = T_1 + left_context.
        """
        batch_size, n_heads, time1, n = x.shape
        time2 = time1 + left_context

        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()

        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=(n_stride * (time1 - 1)),
        )

    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.

        Reference: https://github.com/k2-fsa/icefall/pull/458

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        pos_enc = self.linear_pos(pos_enc)

        matrix_ac = torch.matmul(query, key.transpose(2, 3))

        # One positional value per head, repeated over the query rows.
        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation with positional biases u and v.

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            left_context: Number of frames in left context.

        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)

        # (B, H, T_1, d_k) -> (B, T_1, H, d_k) so the (H, d_k) biases broadcast.
        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)

        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))

        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Project query, key and value and split them into heads.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)

        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)
        """
        n_batch = query.size(0)

        q = (
            self.linear_q(query)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        k = (
            self.linear_k(key)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )
        v = (
            self.linear_v(value)
            .view(n_batch, -1, self.num_heads, self.d_k)
            .transpose(1, 2)
        )

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask, True marks positions to ignore. (B, T_2)
            chunk_mask: Chunk mask, True marks positions to ignore. (T_1, T_1)

        Returns:
            attn_output: Transformed value weighted by attention score.
                (B, T_1, H * d_k)
        """
        batch_size = scores.size(0)
        mask = mask.unsqueeze(1).unsqueeze(2)
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        scores = scores.masked_fill(mask, float("-inf"))
        # Fully masked rows yield NaN after softmax; zero them out again.
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)

        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)

        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )

        return attn_output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
            mask: Source mask, True marks positions to ignore. (B, T_2)
            chunk_mask: Chunk mask, True marks positions to ignore. (T_1, T_1)
            left_context: Number of frames in left context.

        Returns:
            : Output tensor. (B, T_1, H * d_k)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
| New file |
| | |
| | | """Search algorithms for Transducer models.""" |
| | | |
| | | from dataclasses import dataclass |
| | | from typing import Any, Dict, List, Optional, Tuple, Union |
| | | |
| | | import numpy as np |
| | | import torch |
| | | |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | |
| | | |
@dataclass
class Hypothesis:
    """Default hypothesis definition for Transducer search algorithms.

    Args:
        score: Accumulated search score (log-probability, plus the weighted
            LM score when an LM is used).
        yseq: Label sequence as integer ID sequence.
        dec_state: RNNDecoder or StatelessDecoder state.
            ((N, 1, D_dec), (N, 1, D_dec) or None) or None
        lm_state: RNNLM state. ((N, D_lm), (N, D_lm)) or None

    """

    score: float
    yseq: List[int]
    dec_state: Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]] = None
    lm_state: Optional[Union[Dict[str, Any], List[Any]]] = None


@dataclass
class ExtendedHypothesis(Hypothesis):
    """Extended hypothesis definition for NSC beam search and mAES.

    Args:
        : Hypothesis dataclass arguments.
        dec_out: Decoder output sequence. (B, D_dec) or None.
        lm_score: Log-probabilities of the LM for given label. (vocab_size)
            or None.

    """

    # Both default to None, so the annotations must allow it.
    dec_out: Optional[torch.Tensor] = None
    lm_score: Optional[torch.Tensor] = None
| | | |
| | | |
| | | class BeamSearchTransducer: |
| | | """Beam search implementation for Transducer. |
| | | |
| | | Args: |
| | | decoder: Decoder module. |
| | | joint_network: Joint network module. |
| | | beam_size: Size of the beam. |
| | | lm: LM class. |
| | | lm_weight: LM weight for soft fusion. |
| | | search_type: Search algorithm to use during inference. |
| | | max_sym_exp: Number of maximum symbol expansions at each time step. (TSD) |
| | | u_max: Maximum expected target sequence length. (ALSD) |
| | | nstep: Number of maximum expansion steps at each time step. (mAES) |
| | | expansion_gamma: Allowed logp difference for prune-by-value method. (mAES) |
| | | expansion_beta: |
| | | Number of additional candidates for expanded hypotheses selection. (mAES) |
| | | score_norm: Normalize final scores by length. |
| | | nbest: Number of final hypothesis. |
| | | streaming: Whether to perform chunk-by-chunk beam search. |
| | | |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | decoder, |
| | | joint_network: JointNetwork, |
| | | beam_size: int, |
| | | lm: Optional[torch.nn.Module] = None, |
| | | lm_weight: float = 0.1, |
| | | search_type: str = "default", |
| | | max_sym_exp: int = 3, |
| | | u_max: int = 50, |
| | | nstep: int = 2, |
| | | expansion_gamma: float = 2.3, |
| | | expansion_beta: int = 2, |
| | | score_norm: bool = False, |
| | | nbest: int = 1, |
| | | streaming: bool = False, |
| | | ) -> None: |
| | | """Construct a BeamSearchTransducer object.""" |
| | | super().__init__() |
| | | |
| | | self.decoder = decoder |
| | | self.joint_network = joint_network |
| | | |
| | | self.vocab_size = decoder.vocab_size |
| | | |
| | | assert beam_size <= self.vocab_size, ( |
| | | "beam_size (%d) should be smaller than or equal to vocabulary size (%d)." |
| | | % ( |
| | | beam_size, |
| | | self.vocab_size, |
| | | ) |
| | | ) |
| | | self.beam_size = beam_size |
| | | |
| | | if search_type == "default": |
| | | self.search_algorithm = self.default_beam_search |
| | | elif search_type == "tsd": |
| | | assert max_sym_exp > 1, "max_sym_exp (%d) should be greater than one." % ( |
| | | max_sym_exp |
| | | ) |
| | | self.max_sym_exp = max_sym_exp |
| | | |
| | | self.search_algorithm = self.time_sync_decoding |
| | | elif search_type == "alsd": |
| | | assert not streaming, "ALSD is not available in streaming mode." |
| | | |
| | | assert u_max >= 0, "u_max should be a positive integer, a portion of max_T." |
| | | self.u_max = u_max |
| | | |
| | | self.search_algorithm = self.align_length_sync_decoding |
| | | elif search_type == "maes": |
| | | assert self.vocab_size >= beam_size + expansion_beta, ( |
| | | "beam_size (%d) + expansion_beta (%d) " |
| | | " should be smaller than or equal to vocab size (%d)." |
| | | % (beam_size, expansion_beta, self.vocab_size) |
| | | ) |
| | | self.max_candidates = beam_size + expansion_beta |
| | | |
| | | self.nstep = nstep |
| | | self.expansion_gamma = expansion_gamma |
| | | |
| | | self.search_algorithm = self.modified_adaptive_expansion_search |
| | | else: |
| | | raise NotImplementedError( |
| | | "Specified search type (%s) is not supported." % search_type |
| | | ) |
| | | |
| | | self.use_lm = lm is not None |
| | | |
| | | if self.use_lm: |
| | | assert hasattr(lm, "rnn_type"), "Transformer LM is currently not supported." |
| | | |
| | | self.sos = self.vocab_size - 1 |
| | | |
| | | self.lm = lm |
| | | self.lm_weight = lm_weight |
| | | |
| | | self.score_norm = score_norm |
| | | self.nbest = nbest |
| | | |
| | | self.reset_inference_cache() |
| | | |
| | | def __call__( |
| | | self, |
| | | enc_out: torch.Tensor, |
| | | is_final: bool = True, |
| | | ) -> List[Hypothesis]: |
| | | """Perform beam search. |
| | | |
| | | Args: |
| | | enc_out: Encoder output sequence. (T, D_enc) |
| | | is_final: Whether enc_out is the final chunk of data. |
| | | |
| | | Returns: |
| | | nbest_hyps: N-best decoding results |
| | | |
| | | """ |
| | | self.decoder.set_device(enc_out.device) |
| | | |
| | | hyps = self.search_algorithm(enc_out) |
| | | |
| | | if is_final: |
| | | self.reset_inference_cache() |
| | | |
| | | return self.sort_nbest(hyps) |
| | | |
| | | self.search_cache = hyps |
| | | |
| | | return hyps |
| | | |
| | | def reset_inference_cache(self) -> None: |
| | | """Reset cache for decoder scoring and streaming.""" |
| | | self.decoder.score_cache = {} |
| | | self.search_cache = None |
| | | |
| | | def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]: |
| | | """Sort in-place hypotheses by score or score given sequence length. |
| | | |
| | | Args: |
| | | hyps: Hypothesis. |
| | | |
| | | Return: |
| | | hyps: Sorted hypothesis. |
| | | |
| | | """ |
| | | if self.score_norm: |
| | | hyps.sort(key=lambda x: x.score / len(x.yseq), reverse=True) |
| | | else: |
| | | hyps.sort(key=lambda x: x.score, reverse=True) |
| | | |
| | | return hyps[: self.nbest] |
| | | |
| | | def recombine_hyps(self, hyps: List[Hypothesis]) -> List[Hypothesis]: |
| | | """Recombine hypotheses with same label ID sequence. |
| | | |
| | | Args: |
| | | hyps: Hypotheses. |
| | | |
| | | Returns: |
| | | final: Recombined hypotheses. |
| | | |
| | | """ |
| | | final = {} |
| | | |
| | | for hyp in hyps: |
| | | str_yseq = "_".join(map(str, hyp.yseq)) |
| | | |
| | | if str_yseq in final: |
| | | final[str_yseq].score = np.logaddexp(final[str_yseq].score, hyp.score) |
| | | else: |
| | | final[str_yseq] = hyp |
| | | |
| | | return [*final.values()] |
| | | |
| | | def select_k_expansions( |
| | | self, |
| | | hyps: List[ExtendedHypothesis], |
| | | topk_idx: torch.Tensor, |
| | | topk_logp: torch.Tensor, |
| | | ) -> List[ExtendedHypothesis]: |
| | | """Return K hypotheses candidates for expansion from a list of hypothesis. |
| | | |
| | | K candidates are selected according to the extended hypotheses probabilities |
| | | and a prune-by-value method. Where K is equal to beam_size + beta. |
| | | |
| | | Args: |
| | | hyps: Hypotheses. |
| | | topk_idx: Indices of candidates hypothesis. |
| | | topk_logp: Log-probabilities of candidates hypothesis. |
| | | |
| | | Returns: |
| | | k_expansions: Best K expansion hypotheses candidates. |
| | | |
| | | """ |
| | | k_expansions = [] |
| | | |
| | | for i, hyp in enumerate(hyps): |
| | | hyp_i = [ |
| | | (int(k), hyp.score + float(v)) |
| | | for k, v in zip(topk_idx[i], topk_logp[i]) |
| | | ] |
| | | k_best_exp = max(hyp_i, key=lambda x: x[1])[1] |
| | | |
| | | k_expansions.append( |
| | | sorted( |
| | | filter( |
| | | lambda x: (k_best_exp - self.expansion_gamma) <= x[1], hyp_i |
| | | ), |
| | | key=lambda x: x[1], |
| | | reverse=True, |
| | | ) |
| | | ) |
| | | |
| | | return k_expansions |
| | | |
| | | def create_lm_batch_inputs(self, hyps_seq: List[List[int]]) -> torch.Tensor: |
| | | """Make batch of inputs with left padding for LM scoring. |
| | | |
| | | Args: |
| | | hyps_seq: Hypothesis sequences. |
| | | |
| | | Returns: |
| | | : Padded batch of sequences. |
| | | |
| | | """ |
| | | max_len = max([len(h) for h in hyps_seq]) |
| | | |
| | | return torch.LongTensor( |
| | | [[self.sos] + ([0] * (max_len - len(h))) + h[1:] for h in hyps_seq], |
| | | device=self.decoder.device, |
| | | ) |
| | | |
| | | def default_beam_search(self, enc_out: torch.Tensor) -> List[Hypothesis]: |
| | | """Beam search implementation without prefix search. |
| | | |
| | | Modified from https://arxiv.org/pdf/1211.3711.pdf |
| | | |
| | | Args: |
| | | enc_out: Encoder output sequence. (T, D) |
| | | |
| | | Returns: |
| | | nbest_hyps: N-best hypothesis. |
| | | |
| | | """ |
| | | beam_k = min(self.beam_size, (self.vocab_size - 1)) |
| | | max_t = len(enc_out) |
| | | |
| | | if self.search_cache is not None: |
| | | kept_hyps = self.search_cache |
| | | else: |
| | | kept_hyps = [ |
| | | Hypothesis( |
| | | score=0.0, |
| | | yseq=[0], |
| | | dec_state=self.decoder.init_state(1), |
| | | ) |
| | | ] |
| | | |
| | | for t in range(max_t): |
| | | hyps = kept_hyps |
| | | kept_hyps = [] |
| | | |
| | | while True: |
| | | max_hyp = max(hyps, key=lambda x: x.score) |
| | | hyps.remove(max_hyp) |
| | | |
| | | label = torch.full( |
| | | (1, 1), |
| | | max_hyp.yseq[-1], |
| | | dtype=torch.long, |
| | | device=self.decoder.device, |
| | | ) |
| | | dec_out, state = self.decoder.score( |
| | | label, |
| | | max_hyp.yseq, |
| | | max_hyp.dec_state, |
| | | ) |
| | | |
| | | logp = torch.log_softmax( |
| | | self.joint_network(enc_out[t : t + 1, :], dec_out), |
| | | dim=-1, |
| | | ).squeeze(0) |
| | | top_k = logp[1:].topk(beam_k, dim=-1) |
| | | |
| | | kept_hyps.append( |
| | | Hypothesis( |
| | | score=(max_hyp.score + float(logp[0:1])), |
| | | yseq=max_hyp.yseq, |
| | | dec_state=max_hyp.dec_state, |
| | | lm_state=max_hyp.lm_state, |
| | | ) |
| | | ) |
| | | |
| | | if self.use_lm: |
| | | lm_scores, lm_state = self.lm.score( |
| | | torch.LongTensor( |
| | | [self.sos] + max_hyp.yseq[1:], device=self.decoder.device |
| | | ), |
| | | max_hyp.lm_state, |
| | | None, |
| | | ) |
| | | else: |
| | | lm_state = max_hyp.lm_state |
| | | |
| | | for logp, k in zip(*top_k): |
| | | score = max_hyp.score + float(logp) |
| | | |
| | | if self.use_lm: |
| | | score += self.lm_weight * lm_scores[k + 1] |
| | | |
| | | hyps.append( |
| | | Hypothesis( |
| | | score=score, |
| | | yseq=max_hyp.yseq + [int(k + 1)], |
| | | dec_state=state, |
| | | lm_state=lm_state, |
| | | ) |
| | | ) |
| | | |
| | | hyps_max = float(max(hyps, key=lambda x: x.score).score) |
| | | kept_most_prob = sorted( |
| | | [hyp for hyp in kept_hyps if hyp.score > hyps_max], |
| | | key=lambda x: x.score, |
| | | ) |
| | | if len(kept_most_prob) >= self.beam_size: |
| | | kept_hyps = kept_most_prob |
| | | break |
| | | |
| | | return kept_hyps |
| | | |
| | | def align_length_sync_decoding( |
| | | self, |
| | | enc_out: torch.Tensor, |
| | | ) -> List[Hypothesis]: |
| | | """Alignment-length synchronous beam search implementation. |
| | | |
| | | Based on https://ieeexplore.ieee.org/document/9053040 |
| | | |
| | | Args: |
| | | h: Encoder output sequences. (T, D) |
| | | |
| | | Returns: |
| | | nbest_hyps: N-best hypothesis. |
| | | |
| | | """ |
| | | t_max = int(enc_out.size(0)) |
| | | u_max = min(self.u_max, (t_max - 1)) |
| | | |
| | | B = [Hypothesis(yseq=[0], score=0.0, dec_state=self.decoder.init_state(1))] |
| | | final = [] |
| | | |
| | | if self.use_lm: |
| | | B[0].lm_state = self.lm.zero_state() |
| | | |
| | | for i in range(t_max + u_max): |
| | | A = [] |
| | | |
| | | B_ = [] |
| | | B_enc_out = [] |
| | | for hyp in B: |
| | | u = len(hyp.yseq) - 1 |
| | | t = i - u |
| | | |
| | | if t > (t_max - 1): |
| | | continue |
| | | |
| | | B_.append(hyp) |
| | | B_enc_out.append((t, enc_out[t])) |
| | | |
| | | if B_: |
| | | beam_enc_out = torch.stack([b[1] for b in B_enc_out]) |
| | | beam_dec_out, beam_state = self.decoder.batch_score(B_) |
| | | |
| | | beam_logp = torch.log_softmax( |
| | | self.joint_network(beam_enc_out, beam_dec_out), |
| | | dim=-1, |
| | | ) |
| | | beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1) |
| | | |
| | | if self.use_lm: |
| | | beam_lm_scores, beam_lm_states = self.lm.batch_score( |
| | | self.create_lm_batch_inputs([b.yseq for b in B_]), |
| | | [b.lm_state for b in B_], |
| | | None, |
| | | ) |
| | | |
| | | for i, hyp in enumerate(B_): |
| | | new_hyp = Hypothesis( |
| | | score=(hyp.score + float(beam_logp[i, 0])), |
| | | yseq=hyp.yseq[:], |
| | | dec_state=hyp.dec_state, |
| | | lm_state=hyp.lm_state, |
| | | ) |
| | | |
| | | A.append(new_hyp) |
| | | |
| | | if B_enc_out[i][0] == (t_max - 1): |
| | | final.append(new_hyp) |
| | | |
| | | for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1): |
| | | new_hyp = Hypothesis( |
| | | score=(hyp.score + float(logp)), |
| | | yseq=(hyp.yseq[:] + [int(k)]), |
| | | dec_state=self.decoder.select_state(beam_state, i), |
| | | lm_state=hyp.lm_state, |
| | | ) |
| | | |
| | | if self.use_lm: |
| | | new_hyp.score += self.lm_weight * beam_lm_scores[i, k] |
| | | new_hyp.lm_state = beam_lm_states[i] |
| | | |
| | | A.append(new_hyp) |
| | | |
| | | B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size] |
| | | B = self.recombine_hyps(B) |
| | | |
| | | if final: |
| | | return final |
| | | |
| | | return B |
| | | |
def time_sync_decoding(self, enc_out: torch.Tensor) -> List[Hypothesis]:
    """Time synchronous beam search implementation.

    Based on https://ieeexplore.ieee.org/document/9053040

    For every encoder frame, the current beam is expanded with non-blank
    symbols at most ``self.max_sym_exp - 1`` times. Every hypothesis visited
    during these expansion steps is also closed with the column-0 (blank)
    score and collected in ``A``; hypotheses of ``A`` sharing the same
    ``yseq`` have their scores merged with ``np.logaddexp``. The
    ``self.beam_size`` best entries of ``A`` become the beam for the next
    frame.

    Args:
        enc_out: Encoder output sequence. (T, D)

    Returns:
        nbest_hyps: N-best hypothesis, sorted by decreasing score and
            truncated to ``self.beam_size``. If ``enc_out`` has no frames,
            the initial (or cached) beam is returned unchanged.

    """
    # Resume from a beam kept by a previous call when one is available.
    if self.search_cache is not None:
        B = self.search_cache
    else:
        B = [
            Hypothesis(
                yseq=[0],
                score=0.0,
                dec_state=self.decoder.init_state(1),
            )
        ]

        if self.use_lm:
            B[0].lm_state = self.lm.zero_state()

    for enc_out_t in enc_out:
        # A: blank-closed hypotheses for this frame; C: hypotheses being expanded.
        A = []
        C = B

        enc_out_t = enc_out_t.unsqueeze(0)

        for v in range(self.max_sym_exp):
            # D: non-blank expansions of C produced at this step.
            D = []

            beam_dec_out, beam_state = self.decoder.batch_score(C)

            beam_logp = torch.log_softmax(
                self.joint_network(enc_out_t, beam_dec_out),
                dim=-1,
            )
            # Column 0 is used as the blank score; the top-k is taken over the
            # remaining columns, hence the ``+ 1`` when indices are read below.
            beam_topk = beam_logp[:, 1:].topk(self.beam_size, dim=-1)

            # NOTE(review): seq_A is a snapshot taken before the loop below, so
            # two hypotheses of C with the same yseq would both be appended.
            seq_A = [h.yseq for h in A]

            for i, hyp in enumerate(C):
                if hyp.yseq not in seq_A:
                    A.append(
                        Hypothesis(
                            score=(hyp.score + float(beam_logp[i, 0])),
                            yseq=hyp.yseq[:],
                            dec_state=hyp.dec_state,
                            lm_state=hyp.lm_state,
                        )
                    )
                else:
                    # Same label sequence reached through another path:
                    # add the probabilities in log space.
                    dict_pos = seq_A.index(hyp.yseq)

                    A[dict_pos].score = np.logaddexp(
                        A[dict_pos].score, (hyp.score + float(beam_logp[i, 0]))
                    )

            # A is only filled from C at the top of a step, so expansions made
            # on the last step could never be closed; they are skipped.
            if v < (self.max_sym_exp - 1):
                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        self.create_lm_batch_inputs([c.yseq for c in C]),
                        [c.lm_state for c in C],
                        None,
                    )

                for i, hyp in enumerate(C):
                    for logp, k in zip(beam_topk[0][i], beam_topk[1][i] + 1):
                        new_hyp = Hypothesis(
                            score=(hyp.score + float(logp)),
                            yseq=(hyp.yseq + [int(k)]),
                            dec_state=self.decoder.select_state(beam_state, i),
                            lm_state=hyp.lm_state,
                        )

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * beam_lm_scores[i, k]
                            new_hyp.lm_state = beam_lm_states[i]

                        D.append(new_hyp)

            # Keep only the best expansions for the next step.
            C = sorted(D, key=lambda x: x.score, reverse=True)[: self.beam_size]

        B = sorted(A, key=lambda x: x.score, reverse=True)[: self.beam_size]

    return B
| | | |
def modified_adaptive_expansion_search(
    self,
    enc_out: torch.Tensor,
) -> List[ExtendedHypothesis]:
    """Modified version of Adaptive Expansion Search (mAES).

    Based on AES (https://ieeexplore.ieee.org/document/9250505) and
    NSC (https://arxiv.org/abs/2201.05420).

    For every encoder frame, up to ``self.nstep`` expansion steps are run.
    At each step, the top ``self.max_candidates`` joint-network outputs of
    every hypothesis are filtered by ``self.select_k_expansions``. Candidates
    with ``k == 0`` end the frame (they go to ``list_b``); the others extend
    the label sequence and are batch-scored by the decoder (and the LM when
    enabled). On the last step, the column-0 score of the extended
    hypotheses is added before merging them with ``list_b``.

    Args:
        enc_out: Encoder output sequence. (T, D_enc)

    Returns:
        nbest_hyps: N-best hypothesis.

    """
    # Resume from a beam kept by a previous call when one is available.
    if self.search_cache is not None:
        kept_hyps = self.search_cache
    else:
        init_tokens = [
            ExtendedHypothesis(
                yseq=[0],
                score=0.0,
                dec_state=self.decoder.init_state(1),
            )
        ]

        # Pre-compute decoder (and LM) outputs for the initial token so that
        # every hypothesis entering the frame loop carries ``dec_out``.
        beam_dec_out, beam_state = self.decoder.batch_score(
            init_tokens,
        )

        if self.use_lm:
            beam_lm_scores, beam_lm_states = self.lm.batch_score(
                self.create_lm_batch_inputs([h.yseq for h in init_tokens]),
                [h.lm_state for h in init_tokens],
                None,
            )

            lm_state = beam_lm_states[0]
            lm_score = beam_lm_scores[0]
        else:
            lm_state = None
            lm_score = None

        kept_hyps = [
            ExtendedHypothesis(
                yseq=[0],
                score=0.0,
                dec_state=self.decoder.select_state(beam_state, 0),
                dec_out=beam_dec_out[0],
                lm_state=lm_state,
                lm_score=lm_score,
            )
        ]

    for enc_out_t in enc_out:
        hyps = kept_hyps
        kept_hyps = []

        beam_enc_out = enc_out_t.unsqueeze(0)

        # list_b accumulates hypotheses that selected k == 0 at any step of this frame.
        list_b = []
        for n in range(self.nstep):
            beam_dec_out = torch.stack([h.dec_out for h in hyps])

            beam_logp, beam_idx = torch.log_softmax(
                self.joint_network(beam_enc_out, beam_dec_out),
                dim=-1,
            ).topk(self.max_candidates, dim=-1)

            # One list of (token id, new score) pairs per hypothesis.
            k_expansions = self.select_k_expansions(hyps, beam_idx, beam_logp)

            list_exp = []
            for i, hyp in enumerate(hyps):
                for k, new_score in k_expansions[i]:
                    new_hyp = ExtendedHypothesis(
                        yseq=hyp.yseq[:],
                        score=new_score,
                        dec_out=hyp.dec_out,
                        dec_state=hyp.dec_state,
                        lm_state=hyp.lm_state,
                        lm_score=hyp.lm_score,
                    )

                    if k == 0:
                        list_b.append(new_hyp)
                    else:
                        new_hyp.yseq.append(int(k))

                        if self.use_lm:
                            new_hyp.score += self.lm_weight * float(hyp.lm_score[k])

                        list_exp.append(new_hyp)

            if not list_exp:
                # Nothing left to expand: finalize this frame from list_b only.
                kept_hyps = sorted(
                    self.recombine_hyps(list_b), key=lambda x: x.score, reverse=True
                )[: self.beam_size]

                break
            else:
                beam_dec_out, beam_state = self.decoder.batch_score(
                    list_exp,
                )

                if self.use_lm:
                    beam_lm_scores, beam_lm_states = self.lm.batch_score(
                        self.create_lm_batch_inputs([h.yseq for h in list_exp]),
                        [h.lm_state for h in list_exp],
                        None,
                    )

                if n < (self.nstep - 1):
                    # More steps remain: refresh states and expand again.
                    for i, hyp in enumerate(list_exp):
                        hyp.dec_out = beam_dec_out[i]
                        hyp.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            hyp.lm_state = beam_lm_states[i]
                            hyp.lm_score = beam_lm_scores[i]

                    hyps = list_exp[:]
                else:
                    # Last step: close the expanded hypotheses with the
                    # column-0 score, then merge them with list_b.
                    beam_logp = torch.log_softmax(
                        self.joint_network(beam_enc_out, beam_dec_out),
                        dim=-1,
                    )

                    for i, hyp in enumerate(list_exp):
                        hyp.score += float(beam_logp[i, 0])

                        hyp.dec_out = beam_dec_out[i]
                        hyp.dec_state = self.decoder.select_state(beam_state, i)

                        if self.use_lm:
                            hyp.lm_state = beam_lm_states[i]
                            hyp.lm_score = beam_lm_scores[i]

                    kept_hyps = sorted(
                        self.recombine_hyps(list_b + list_exp),
                        key=lambda x: x.score,
                        reverse=True,
                    )[: self.beam_size]

    return kept_hyps
| | |
| | | |
| | | """Common functions for ASR.""" |
| | | |
| | | from typing import List, Optional, Tuple |
| | | |
| | | import json |
| | | import logging |
| | | import sys |
| | |
| | | from itertools import groupby |
| | | import numpy as np |
| | | import six |
| | | import torch |
| | | |
| | | from funasr.modules.beam_search.beam_search_transducer import BeamSearchTransducer |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | |
| | | def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): |
| | | """End detection. |
| | |
| | | word_eds.append(editdistance.eval(hyp_words, ref_words)) |
| | | word_ref_lens.append(len(ref_words)) |
| | | return float(sum(word_eds)) / sum(word_ref_lens) |
| | | |
class ErrorCalculatorTransducer:
    """Calculate CER and WER for transducer models.

    Hypotheses are decoded with a ``BeamSearchTransducer`` configured with
    ``beam_size=1`` and ``search_type="default"``, converted to text with
    ``token_list`` and compared to the references with ``editdistance``.

    Args:
        decoder: Decoder module.
        joint_network: Joint Network module.
        token_list: List of token units, indexed by token ID.
        sym_space: Space symbol.
        sym_blank: Blank symbol.
        report_cer: Whether to compute CER.
        report_wer: Whether to compute WER.
    """

    def __init__(
        self,
        decoder,
        joint_network: JointNetwork,
        token_list: List[str],
        sym_space: str,
        sym_blank: str,
        report_cer: bool = False,
        report_wer: bool = False,
    ) -> None:
        """Construct an ErrorCalculatorTransducer object."""
        super().__init__()

        self.beam_search = BeamSearchTransducer(
            decoder=decoder,
            joint_network=joint_network,
            beam_size=1,
            search_type="default",
            score_norm=False,
        )

        self.decoder = decoder

        self.token_list = token_list
        self.space = sym_space
        self.blank = sym_blank

        self.report_cer = report_cer
        self.report_wer = report_wer

    def __call__(
        self, encoder_out: torch.Tensor, target: torch.Tensor
    ) -> Tuple[Optional[float], Optional[float]]:
        """Calculate batch-level CER and/or WER for a Transducer model.

        Args:
            encoder_out: Encoder output sequences. (B, T, D_enc)
            target: Target label ID sequences. (B, L)

        Returns:
            : CER score (total edit distance / total reference length),
              or None when ``report_cer`` is False.
            : WER score (same definition over words), or None when
              ``report_wer`` is False.
        """
        cer, wer = None, None

        batchsize = int(encoder_out.size(0))

        # Run the search on the same device as the decoder parameters.
        encoder_out = encoder_out.to(next(self.decoder.parameters()).device)

        batch_nbest = [self.beam_search(encoder_out[b]) for b in range(batchsize)]
        # Best hypothesis per utterance, without the first yseq element
        # (presumably the start symbol the search is seeded with).
        pred = [nbest_hyp[0].yseq[1:] for nbest_hyp in batch_nbest]

        char_pred, char_target = self.convert_to_char(pred, target)

        if self.report_cer:
            cer = self.calculate_cer(char_pred, char_target)

        if self.report_wer:
            wer = self.calculate_wer(char_pred, char_target)

        return cer, wer

    def convert_to_char(
        self, pred: List[List[int]], target: torch.Tensor
    ) -> Tuple[List[str], List[str]]:
        """Convert label ID sequences to character sequences.

        The space symbol is turned into ``" "`` and every occurrence of the
        blank symbol is removed.

        Args:
            pred: Prediction label ID sequences. [B x (U)]
            target: Target label ID sequences. (B, L)

        Returns:
            char_pred: Prediction character sequences. [B]
            char_target: Target character sequences. [B]
        """
        char_pred, char_target = [], []

        for i, pred_i in enumerate(pred):
            char_pred_i = [self.token_list[int(h)] for h in pred_i]
            char_target_i = [self.token_list[int(r)] for r in target[i]]

            char_pred_i = "".join(char_pred_i).replace(self.space, " ")
            char_pred_i = char_pred_i.replace(self.blank, "")

            char_target_i = "".join(char_target_i).replace(self.space, " ")
            char_target_i = char_target_i.replace(self.blank, "")

            char_pred.append(char_pred_i)
            char_target.append(char_target_i)

        return char_pred, char_target

    def calculate_cer(self, char_pred: List[str], char_target: List[str]) -> float:
        """Calculate the batch-level CER score.

        Spaces are ignored. The score is the summed character edit distance
        divided by the summed reference length, so it raises
        ZeroDivisionError when every reference is empty.

        Args:
            char_pred: Prediction character sequences. [B]
            char_target: Target character sequences. [B]

        Returns:
            : Batch-level CER score.
        """
        import editdistance

        distances, lens = [], []

        for i, char_pred_i in enumerate(char_pred):
            pred = char_pred_i.replace(" ", "")
            target = char_target[i].replace(" ", "")
            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))

        return float(sum(distances)) / sum(lens)

    def calculate_wer(self, char_pred: List[str], char_target: List[str]) -> float:
        """Calculate the batch-level WER score.

        Words are obtained by mapping "▁" to a space and splitting on
        whitespace. The score is the summed word edit distance divided by the
        summed reference word count.

        Args:
            char_pred: Prediction character sequences. [B]
            char_target: Target character sequences. [B]

        Returns:
            : Batch-level WER score.
        """
        import editdistance

        distances, lens = [], []

        for i, char_pred_i in enumerate(char_pred):
            pred = char_pred_i.replace("▁", " ").split()
            target = char_target[i].replace("▁", " ").split()

            distances.append(editdistance.eval(pred, target))
            lens.append(len(target))

        return float(sum(distances)) / sum(lens)
| | |
| | | outputs = F.pad(outputs, (pad_left, pad_right)) |
| | | outputs = outputs.transpose(1, 2) |
| | | return outputs |
| | | |
| | | |
class StreamingRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding for streaming (chunk-wise) attention.

    A sinusoidal table covering relative positions ``+(L-1) .. -(L-1)`` is
    cached in ``self.pe`` (shape ``(1, 2L - 1, size)``), with position 0 in
    the middle row. It is rebuilt only when a longer context is requested.

    Args:
        size: Module size.
        max_len: Maximum input length.
        dropout_rate: Dropout rate.
    """

    def __init__(
        self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
    ) -> None:
        """Construct a StreamingRelPositionalEncoding object."""
        super().__init__()

        self.size = size

        self.pe = None
        self.dropout = torch.nn.Dropout(p=dropout_rate)

        # Pre-build a table long enough for ``max_len`` frames.
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
        self._register_load_state_dict_pre_hook(_pre_hook)

    def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
        """Make sure the cached table covers ``x`` plus its left context.

        Args:
            x: Input sequences. (B, T, ?)
            left_context: Number of frames in left context.
        """
        span = x.size(1) + left_context

        cached = self.pe
        if cached is not None and cached.size(1) >= span * 2 - 1:
            # Large enough already: only align dtype/device with the input.
            if cached.dtype != x.dtype or cached.device != x.device:
                self.pe = cached.to(device=x.device, dtype=x.dtype)
            return

        offsets = torch.arange(0, span, dtype=torch.float32).unsqueeze(1)
        freqs = torch.exp(
            torch.arange(0, self.size, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.size)
        )
        angles = offsets * freqs

        forward_table = torch.zeros(span, self.size)
        forward_table[:, 0::2] = torch.sin(angles)
        forward_table[:, 1::2] = torch.cos(angles)
        # Reverse so that rows go from the largest positive offset down to 0.
        forward_table = torch.flip(forward_table, [0]).unsqueeze(0)

        backward_table = torch.zeros(span, self.size)
        backward_table[:, 0::2] = torch.sin(-1 * angles)
        backward_table[:, 1::2] = torch.cos(-1 * angles)
        # Offset 0 already lives in forward_table; keep only -1, -2, ...
        backward_table = backward_table[1:].unsqueeze(0)

        self.pe = torch.cat([forward_table, backward_table], dim=1).to(
            dtype=x.dtype, device=x.device
        )

    def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Return the relative positional embeddings for ``x``.

        Args:
            x: Input sequences. (B, T, ?)
            left_context: Number of frames in left context.

        Returns:
            pos_enc: Positional embedding sequences.
                (1, T + left_context + T - 1, size)
        """
        self.extend_pe(x, left_context=left_context)

        span = x.size(1) + left_context
        center = self.pe.size(1) // 2

        return self.dropout(self.pe[:, center - span + 1 : center + x.size(1)])
| | |
| | | """Network related utility tools.""" |
| | | |
| | | import logging |
| | | from typing import Dict |
| | | from typing import Dict, List, Tuple |
| | | |
| | | import numpy as np |
| | | import torch |
| | |
| | | } |
| | | |
| | | return activation_funcs[act]() |
| | | |
class TooShortUttError(Exception):
    """Error signalling that an utterance cannot go through subsampling.

    Args:
        message: Error message to display.
        actual_size: Size of the input that could not be subsampled.
        limit: Size limit required by the subsampling.

    """

    def __init__(self, message: str, actual_size: int, limit: int) -> None:
        """Keep the offending size and the limit alongside the message."""
        super().__init__(message)

        self.actual_size, self.limit = actual_size, limit
| | | |
| | | |
def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
    """Check if the input is too short for subsampling.

    The limits follow the Conv2dSubsampling output-length formulas (the same
    ones used by ``sub_factor_to_params``), applied to the time axis:
      - factor 2: (T - 1) // 2 - 2 >= 1        <=> T >= 7
      - factor 4: ((T - 1) // 2 - 1) // 2 >= 1 <=> T >= 7
      - factor 6: ((T - 1) // 2 - 2) // 3 >= 1 <=> T >= 11

    Args:
        sub_factor: Subsampling factor for Conv2DSubsampling.
        size: Input size.

    Returns:
        : Whether an error should be sent.
        : Size limit for specified subsampling factor (-1 when no error).

    """
    # Factor 2 previously checked ``size < 3`` while reporting a limit of 7;
    # inputs of 3 to 6 frames produce an empty output, so the check is 7.
    if sub_factor in (2, 4) and size < 7:
        return True, 7
    if sub_factor == 6 and size < 11:
        return True, 11

    return False, -1
| | | |
| | | |
def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
    """Return the second Conv2D layer settings for a subsampling factor.

    Args:
        sub_factor: Subsampling factor (1/X).
        input_size: Input size.

    Returns:
        : Kernel size for second convolution.
        : Stride for second convolution.
        : Conv2DSubsampling output size.

    Raises:
        ValueError: If ``sub_factor`` is not 2, 4 or 6.

    """
    # Size after the first 3x3, stride-2 convolution.
    half = (input_size - 1) // 2

    candidates = (
        (2, (3, 1, half - 2)),
        (4, (3, 2, (half - 1) // 2)),
        (6, (5, 3, (half - 2) // 3)),
    )
    for factor, params in candidates:
        if sub_factor == factor:
            return params

    raise ValueError(
        "subsampling_factor parameter should be set to either 2, 4 or 6."
    )
| | | |
| | | |
def make_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = 0,
    device: torch.device = None,
) -> torch.Tensor:
    """Create chunk mask for the subsequent steps (size, size).

    Frame ``i`` belongs to chunk ``i // chunk_size``. It may attend to every
    frame of its own chunk plus the ``left_chunk_size`` previous chunks (all
    previous chunks when ``left_chunk_size <= 0``).

    The mask is built with broadcasting instead of a per-row Python loop, so
    it takes a constant number of tensor operations whatever ``size`` is.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        size: Size of the source mask.
        chunk_size: Number of frames in chunk.
        left_chunk_size: Size of the left context in chunks (0 means full context).
        device: Device for the mask tensor.

    Returns:
        mask: Chunk mask. (size, size) ``True`` marks positions outside the
            window row ``i`` is allowed to attend to.

    Raises:
        ValueError: If ``chunk_size`` is not positive.

    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size should be a positive integer, got {chunk_size}.")

    frames = torch.arange(size, device=device)
    chunk_ids = torch.div(frames, chunk_size, rounding_mode="floor")

    # Exclusive end of the visible window: end of the row's own chunk.
    end = torch.clamp((chunk_ids + 1) * chunk_size, max=size)

    if left_chunk_size <= 0:
        start = torch.zeros_like(frames)
    else:
        start = torch.clamp((chunk_ids - left_chunk_size) * chunk_size, min=0)

    cols = frames.unsqueeze(0)
    visible = (cols >= start.unsqueeze(1)) & (cols < end.unsqueeze(1))

    return ~visible
| | | |
def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Build a padding mask from sequence lengths.

    Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py

    Args:
        lengths: Sequence lengths. (B,)

    Returns:
        : Mask for the sequence lengths, ``True`` on padded positions. (B, max_len)

    """
    positions = torch.arange(lengths.max()).to(lengths)

    # (1, max_len) vs (B, 1) broadcasts to (B, max_len): a position is padding
    # as soon as it reaches the sequence length.
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)
| | | |
| | | |
def get_transducer_task_io(
    labels: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ignore_id: int = -1,
    blank_id: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get Transducer loss I/O.

    Padding (``ignore_id``) is stripped from every label sequence. The decoder
    input is the blank symbol followed by the labels; both outputs are
    re-padded with ``blank_id``.

    Args:
        labels: Label ID sequences. (B, L)
        encoder_out_lens: Encoder output lengths. (B,)
        ignore_id: Padding symbol ID.
        blank_id: Blank symbol ID.

    Returns:
        decoder_in: Decoder inputs. (B, U)
        target: Target label ID sequences. (B, U)
        t_len: Time lengths. (B,)
        u_len: Label lengths. (B,)

    """
    device = labels.device

    def _pad(seqs: List[torch.Tensor], value: int = 0) -> torch.Tensor:
        """Stack variable-length sequences into one (B, max_len, ...) tensor filled with ``value``."""
        longest = max(seq.size(0) for seq in seqs)
        out = seqs[0].new(len(seqs), longest, *seqs[0].size()[1:]).fill_(value)
        for row, seq in enumerate(seqs):
            out[row, : seq.size(0)] = seq
        return out

    unpadded = [row[row != ignore_id] for row in labels]
    blank_prefix = labels[0].new([blank_id])

    decoder_in = _pad(
        [torch.cat([blank_prefix, seq], dim=0) for seq in unpadded], blank_id
    ).to(device)
    target = _pad(unpadded, blank_id).type(torch.int32).to(device)

    t_len = torch.IntTensor([int(n) for n in encoder_out_lens]).to(device)
    u_len = torch.IntTensor([seq.size(0) for seq in unpadded]).to(device)

    return decoder_in, target, t_len, u_len
| | | |
def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
    """Right-pad ``t`` with zeros along ``dim`` until it holds ``pad_len`` entries.

    ``t`` itself is returned (no copy) when it already has that length;
    otherwise zeros of the same dtype and device are concatenated after it.
    """
    missing = pad_len - t.size(dim)
    if missing == 0:
        return t

    filler_shape = list(t.shape)
    filler_shape[dim] = missing
    return torch.cat([t, t.new_zeros(filler_shape)], dim=dim)
| | |
| | | |
| | | """Repeat the same layer definition.""" |
| | | |
| | | from typing import Dict, List, Optional |
| | | |
| | | import torch |
| | | |
| | | |
| | |
| | | |
| | | """ |
| | | return MultiSequential(*[fn(n) for n in range(N)]) |
| | | |
| | | |
class MultiBlocks(torch.nn.Module):
    """Stack of encoder blocks followed by a final normalization.

    Args:
        block_list: Individual blocks of the encoder architecture.
        output_size: Architecture output size.
        norm_class: Normalization module class, instantiated as
            ``norm_class(output_size)``.
    """

    def __init__(
        self,
        block_list: List[torch.nn.Module],
        output_size: int,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
    ) -> None:
        """Construct a MultiBlocks object."""
        super().__init__()

        self.blocks = torch.nn.ModuleList(block_list)
        self.norm_blocks = norm_class(output_size)

        self.num_blocks = len(block_list)

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset the streaming cache of every block.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        for layer in self.blocks:
            layer.reset_streaming_cache(left_context, device)

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the full-sequence forward pass through every block.

        Each block returns ``(x, mask, pos_enc)``, which feeds the next one.

        Args:
            x: MultiBlocks input sequences. (B, T, D_block_1)
            pos_enc: Positional embedding sequences.
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Output sequences. (B, T, D_block_N)
        """
        for layer in self.blocks:
            x, mask, pos_enc = layer(x, pos_enc, mask, chunk_mask=chunk_mask)

        return self.norm_blocks(x)

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 0,
        left_context: int = 0,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Run the chunk-by-chunk forward pass through every block.

        Each block's ``chunk_forward`` returns ``(x, pos_enc)``; the mask is
        passed unchanged to every block.

        Args:
            x: MultiBlocks input sequences. (B, T, D_block_1)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att)
            mask: Source mask. (B, T_2)
            chunk_size: Number of frames in chunk.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: MultiBlocks output sequences. (B, T, D_block_N)
        """
        for layer in self.blocks:
            x, pos_enc = layer.chunk_forward(
                x,
                pos_enc,
                mask,
                chunk_size=chunk_size,
                left_context=left_context,
                right_context=right_context,
            )

        return self.norm_blocks(x)
| | |
| | | from funasr.modules.embedding import PositionalEncoding |
| | | import logging |
| | | from funasr.modules.streaming_utils.utils import sequence_mask |
| | | from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len |
| | | from typing import Optional, Tuple, Union |
| | | import math |
| | | |
| | | class TooShortUttError(Exception): |
| | | """Raised when the utt is too short for subsampling. |
| | | |
| | |
| | | var_dict_tf[name_tf].shape)) |
| | | return var_dict_torch_update |
| | | |
class StreamingConvInput(torch.nn.Module):
    """Streaming ConvInput module definition.

    Args:
        input_size: Input size (feature dimension).
        conv_size: Convolution size. Unpacked as ``(conv_size1, conv_size2)``
            when ``vgg_like`` is True; used as a single channel count otherwise.
        subsampling_factor: Subsampling factor.
        vgg_like: Whether to use a VGG-like network.
        output_size: Block output dimension. When None, no output projection
            is applied and ``self.output_size`` is the flattened conv size.
    """

    def __init__(
        self,
        input_size: int,
        conv_size: Union[int, Tuple],
        subsampling_factor: int = 4,
        vgg_like: bool = True,
        output_size: Optional[int] = None,
    ) -> None:
        """Construct a ConvInput object."""
        super().__init__()
        if vgg_like:
            if subsampling_factor == 1:
                conv_size1, conv_size2 = conv_size

                # Both pools are (1, 2): only the feature axis is reduced.
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                )

                output_proj = conv_size2 * ((input_size // 2) // 2)

                self.subsampling_factor = 1

                self.stride_1 = 1

                self.create_new_mask = self.create_new_vgg_mask

            else:
                conv_size1, conv_size2 = conv_size

                # Time axis is pooled by kernel_1, then by 2.
                kernel_1 = int(subsampling_factor / 2)

                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((kernel_1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((2, 2)),
                )

                output_proj = conv_size2 * ((input_size // 2) // 2)

                self.subsampling_factor = subsampling_factor

                self.create_new_mask = self.create_new_vgg_mask

                self.stride_1 = kernel_1

        else:
            if subsampling_factor == 1:
                # Stride (1, 2) with time padding 1: time length is preserved,
                # only the feature axis is reduced.
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, [1, 2], [1, 0]),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, 3, [1, 2], [1, 0]),
                    torch.nn.ReLU(),
                )

                output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)

                self.subsampling_factor = subsampling_factor
                self.kernel_2 = 3
                self.stride_2 = 1

                self.create_new_mask = self.create_new_conv2d_mask

            else:
                kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
                    subsampling_factor,
                    input_size,
                )

                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, 2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
                    torch.nn.ReLU(),
                )

                output_proj = conv_size * conv_2_output_size

                self.subsampling_factor = subsampling_factor
                self.kernel_2 = kernel_2
                self.stride_2 = stride_2

                self.create_new_mask = self.create_new_conv2d_mask

        self.vgg_like = vgg_like
        self.min_frame_length = 7

        if output_size is not None:
            self.output = torch.nn.Linear(output_proj, output_size)
            self.output_size = output_size
        else:
            self.output = None
            self.output_size = output_proj

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode input sequences.

        Args:
            x: ConvInput input sequences. (B, T, D_feats)
            mask: Mask of input sequences. (B, T) Entries equal to 0 are
                counted as valid frames. May be None, in which case every
                frame is treated as valid.
            chunk_size: When not None, the input is zero-padded and split into
                independent chunks of ``chunk_size * subsampling_factor``
                frames before the convolution.

        Returns:
            x: ConvInput output sequences. (B, sub(T), D_out)
            mask: Mask of output sequences. (B, sub(T)), or None when no
                input mask was given.
        """
        b, t, f = x.size()

        if mask is not None:
            mask = self.create_new_mask(mask)
            olens = max(mask.eq(0).sum(1))
        else:
            # No mask: all frames are valid, so the output length is the
            # subsampled length of the whole input. (Previously `olens` was
            # left unbound here and the method crashed.)
            olens = self.create_new_mask(
                torch.zeros(1, t, dtype=torch.bool, device=x.device)
            ).size(1)

        x = x.unsqueeze(1)  # (B, 1, T, F)

        if chunk_size is not None:
            frames_per_chunk = chunk_size * self.subsampling_factor
            max_input_length = int(
                frames_per_chunk * (math.ceil(float(t) / frames_per_chunk))
            )
            # Zero-pad the time axis to a whole number of chunks.
            x = pad_to_len(x, max_input_length, 2)
            n_chunks = max_input_length // frames_per_chunk
            # reshape (not view): copies when the input is non-contiguous.
            x = x.reshape(b * n_chunks, 1, frames_per_chunk, f)

        x = self.conv(x)

        _, c, _, f_out = x.size()
        x = x.transpose(1, 2).contiguous().view(b, -1, c * f_out)
        if chunk_size is not None:
            # Drop frames produced by the chunk padding.
            x = x[:, :olens, :]

        if self.output is not None:
            x = self.output(x)

        if mask is None:
            return x, None

        return x, mask[:, :olens][:, : x.size(1)]

    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create a new mask for VGG output sequences.

        Args:
            mask: Mask of input sequences. (B, T)

        Returns:
            mask: Mask of output sequences. (B, sub(T))
        """
        if self.subsampling_factor > 1:
            # Mirror the two time poolings: drop the remainder frames, then
            # keep every (subsampling_factor // 2)-th, then every 2nd entry.
            vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2))
            mask = mask[:, :vgg1_t_len][:, :: self.subsampling_factor // 2]

            vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
            mask = mask[:, :vgg2_t_len][:, ::2]
        else:
            mask = mask

        return mask

    def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create new conformer mask for Conv2d output sequences.

        Args:
            mask: Mask of input sequences. (B, T)

        Returns:
            mask: Mask of output sequences. (B, sub(T))
        """
        if self.subsampling_factor > 1:
            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
        else:
            return mask

    def get_size_before_subsampling(self, size: int) -> int:
        """Return the original size before subsampling for a given size.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of frames before subsampling.
        """
        return size * self.subsampling_factor
| | |
| | | from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN |
| | | from funasr.models.decoder.transformer_decoder import TransformerDecoder |
| | | from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder |
| | | from funasr.models.decoder.rnnt_decoder import RNNTDecoder |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | from funasr.models.e2e_asr import ESPnetASRModel |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerOnline, ParaformerBert, BiCifParaformer, ContextualParaformer |
| | | from funasr.models.e2e_tp import TimestampPredictor |
| | | from funasr.models.e2e_asr_mfcca import MFCCA |
| | | from funasr.models.e2e_uni_asr import UniASR |
| | | from funasr.models.e2e_asr_transducer import TransducerModel, UnifiedTransducerModel |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder, ConformerChunkEncoder |
| | | from funasr.models.encoder.data2vec_encoder import Data2VecEncoder |
| | | from funasr.models.encoder.rnn_encoder import RNNEncoder |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt |
| | |
| | | sanm_chunk_opt=SANMEncoderChunkOpt, |
| | | data2vec_encoder=Data2VecEncoder, |
| | | mfcca_enc=MFCCAEncoder, |
| | | chunk_conformer=ConformerChunkEncoder, |
| | | ), |
| | | type_check=AbsEncoder, |
| | | default="rnn", |
| | |
| | | type_check=AbsDecoder, |
| | | default="rnn", |
| | | ) |
| | | |
# Registry of RNN-T decoder implementations. Through the task's class-choices
# loop it contributes the --rnnt_decoder / --rnnt_decoder_conf options; the
# only registered choice, "rnnt", is also the default.
rnnt_decoder_choices = ClassChoices(
    name="rnnt_decoder",
    classes={"rnnt": RNNTDecoder},
    type_check=RNNTDecoder,
    default="rnnt",
)
| | | |
| | | predictor_choices = ClassChoices( |
| | | name="predictor", |
| | | classes=dict( |
| | |
| | | ) -> Tuple[str, ...]: |
| | | retval = ("speech", "text") |
| | | return retval |
| | | |
| | | |
| | | class ASRTransducerTask(AbsTask): |
| | | """ASR Transducer Task definition.""" |
| | | |
| | | num_optimizers: int = 1 |
| | | |
| | | class_choices_list = [ |
| | | frontend_choices, |
| | | specaug_choices, |
| | | normalize_choices, |
| | | encoder_choices, |
| | | rnnt_decoder_choices, |
| | | ] |
| | | |
| | | trainer = Trainer |
| | | |
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
    """Add Transducer task arguments.

    Registers three argument groups on ``parser``:

    * task options;
    * preprocessing options;
    * one ``--<name>`` / ``--<name>_conf`` pair per entry of
      ``cls.class_choices_list``.

    Args:
        cls: ASRTransducerTask object.
        parser: Transducer arguments parser.
    """
    group = parser.add_argument_group(description="Task related.")

    group.add_argument(
        "--token_list",
        type=str_or_none,
        default=None,
        help="Integer-string mapper for tokens.",
    )
    group.add_argument(
        "--split_with_space",
        type=str2bool,
        default=True,
        help="whether to split text using <space>",
    )
    group.add_argument(
        "--input_size",
        type=int_or_none,
        default=None,
        help="The number of dimensions for input features.",
    )
    group.add_argument(
        "--init",
        type=str_or_none,
        default=None,
        help="Type of model initialization to use.",
    )
    group.add_argument(
        "--model_conf",
        action=NestedDictAction,
        default=get_default_kwargs(TransducerModel),
        help="The keyword arguments for the model class.",
    )
    # --encoder_conf (and the other --<name>_conf options) are not declared
    # here: the class_choices loop at the end of this method registers them.
    group.add_argument(
        "--joint_network_conf",
        action=NestedDictAction,
        default={},
        help="The keyword arguments for the joint network class.",
    )

    # All preprocessing options go into their own group; previously most of
    # them were attached to the bare parser despite this heading.
    group = parser.add_argument_group(description="Preprocess related.")
    group.add_argument(
        "--use_preprocessor",
        type=str2bool,
        default=True,
        help="Whether to apply preprocessing to input data.",
    )
    group.add_argument(
        "--token_type",
        type=str,
        default="bpe",
        choices=["bpe", "char", "word", "phn"],
        help="The type of tokens to use during tokenization.",
    )
    group.add_argument(
        "--bpemodel",
        type=str_or_none,
        default=None,
        help="The path of the sentencepiece model.",
    )
    group.add_argument(
        "--non_linguistic_symbols",
        type=str_or_none,
        help="The 'non_linguistic_symbols' file path.",
    )
    group.add_argument(
        "--cleaner",
        type=str_or_none,
        choices=[None, "tacotron", "jaconv", "vietnamese"],
        default=None,
        help="Text cleaner to use.",
    )
    group.add_argument(
        "--g2p",
        type=str_or_none,
        choices=g2p_choices,
        default=None,
        help="g2p method to use if --token_type=phn.",
    )
    group.add_argument(
        "--speech_volume_normalize",
        type=float_or_none,
        default=None,
        help="Normalization value for maximum amplitude scaling.",
    )
    group.add_argument(
        "--rir_scp",
        type=str_or_none,
        default=None,
        help="The RIR SCP file path.",
    )
    group.add_argument(
        "--rir_apply_prob",
        type=float,
        default=1.0,
        help="The probability of the applied RIR convolution.",
    )
    group.add_argument(
        "--noise_scp",
        type=str_or_none,
        default=None,
        help="The path of noise SCP file.",
    )
    group.add_argument(
        "--noise_apply_prob",
        type=float,
        default=1.0,
        help="The probability of the applied noise addition.",
    )
    group.add_argument(
        "--noise_db_range",
        type=str,
        default="13_15",
        help="The range of the noise decibel level.",
    )

    # Component selectors get their own section instead of being listed
    # under "Preprocess related." in --help.
    group = parser.add_argument_group(description="Model components related.")
    for class_choices in cls.class_choices_list:
        # Append --<name> and --<name>_conf.
        # e.g. --encoder and --encoder_conf
        class_choices.add_arguments(group)
| | | |
@classmethod
def build_collate_fn(
    cls, args: argparse.Namespace, train: bool
) -> Callable[
    [Collection[Tuple[str, Dict[str, np.ndarray]]]],
    Tuple[List[str], Dict[str, torch.Tensor]],
]:
    """Return the function that merges loaded samples into a padded batch.

    Float arrays are padded with ``0.0`` and integer arrays with ``-1``.
    Neither ``args`` nor ``train`` influences the result.

    Args:
        cls: ASRTransducerTask object.
        args: Task arguments.
        train: Training mode.

    Return:
        : Callable collate function.
    """
    assert check_argument_types()

    padding = {"float_pad_value": 0.0, "int_pad_value": -1}
    return CommonCollateFn(**padding)
| | | |
@classmethod
def build_preprocess_fn(
    cls, args: argparse.Namespace, train: bool
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
    """Build pre-processing function.

    Args:
        cls: ASRTransducerTask object.
        args: Task arguments.
        train: Training mode.

    Return:
        : A ``CommonPreprocessor`` configured from ``args`` when
          ``args.use_preprocessor`` is true, otherwise ``None``.
    """
    assert check_argument_types()

    if not args.use_preprocessor:
        retval = None
    else:
        retval = CommonPreprocessor(
            train=train,
            token_type=args.token_type,
            token_list=args.token_list,
            bpemodel=args.bpemodel,
            non_linguistic_symbols=args.non_linguistic_symbols,
            text_cleaner=args.cleaner,
            g2p_type=args.g2p,
            # The options below fall back to a default when ``args`` does
            # not carry the attribute at all.
            split_with_space=getattr(args, "split_with_space", False),
            rir_scp=getattr(args, "rir_scp", None),
            rir_apply_prob=getattr(args, "rir_apply_prob", 1.0),
            noise_scp=getattr(args, "noise_scp", None),
            noise_apply_prob=getattr(args, "noise_apply_prob", 1.0),
            noise_db_range=getattr(args, "noise_db_range", "13_15"),
            # Keyed on its own attribute, so the value is honored even when
            # rir_scp is absent and cannot raise when only rir_scp is present.
            speech_volume_normalize=getattr(args, "speech_volume_normalize", None),
        )

    assert check_return_type(retval)
    return retval
| | | |
@classmethod
def required_data_names(
    cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
    """Names of the data entries the task cannot run without.

    Args:
        cls: ASRTransducerTask object.
        train: Training mode (not consulted).
        inference: Inference mode.

    Return:
        retval: ``("speech",)`` when ``inference`` is truthy, otherwise
            ``("speech", "text")``.
    """
    retval = ("speech",) if inference else ("speech", "text")
    return retval
| | | |
@classmethod
def optional_data_names(
    cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
    """Names of data entries the task accepts but does not require.

    The transducer task declares none, so the result is always empty,
    whatever ``train`` and ``inference`` are.

    Args:
        cls: ASRTransducerTask object.
        train: Training mode.
        inference: Inference mode.

    Return:
        retval: An empty tuple.
    """
    retval: Tuple[str, ...] = tuple()
    assert check_return_type(retval)
    return retval
| | | |
@classmethod
def build_model(cls, args: argparse.Namespace) -> TransducerModel:
    """Build the ASR Transducer model from parsed task arguments.

    Args:
        cls: ASRTransducerTask object.
        args: Task arguments.

    Return:
        model: ``UnifiedTransducerModel`` when the encoder sets a truthy
            ``unified_model_training`` attribute, otherwise ``TransducerModel``.

    Raises:
        RuntimeError: If ``args.token_list`` is neither a path nor a tuple/list.
        NotImplementedError: If ``args.init`` is set.
    """
    assert check_argument_types()

    if isinstance(args.token_list, str):
        with open(args.token_list, encoding="utf-8") as f:
            token_list = [line.rstrip() for line in f]

        # Overwriting token_list to keep it as "portable".
        args.token_list = list(token_list)
    elif isinstance(args.token_list, (tuple, list)):
        token_list = list(args.token_list)
    else:
        raise RuntimeError("token_list must be str or list")
    vocab_size = len(token_list)
    logging.info(f"Vocabulary size: {vocab_size}")

    # 1. Frontend
    if args.input_size is None:
        # Extract features in the model.
        frontend_class = frontend_choices.get_class(args.frontend)
        frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # Features are given by the data loader.
        frontend = None
        input_size = args.input_size

    # 2. Data augmentation for spectrogram
    if args.specaug is not None:
        specaug_class = specaug_choices.get_class(args.specaug)
        specaug = specaug_class(**args.specaug_conf)
    else:
        specaug = None

    # 3. Normalization layer
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None

    # 4. Encoder
    if getattr(args, "encoder", None) is not None:
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(input_size, **args.encoder_conf)
    else:
        # NOTE(review): `Encoder` is not among the imports visible in this
        # chunk — confirm it is imported at the top of the file.
        encoder = Encoder(input_size, **args.encoder_conf)
    encoder_output_size = encoder.output_size()

    # 5. RNN-T decoder
    rnnt_decoder_class = rnnt_decoder_choices.get_class(args.rnnt_decoder)
    decoder = rnnt_decoder_class(
        vocab_size,
        **args.rnnt_decoder_conf,
    )
    decoder_output_size = decoder.output_size

    # Optional auxiliary attention decoder. The guard and the class lookup
    # must read the same attribute (``args.decoder``); ``decoder_conf`` may
    # be absent because decoder_choices is not in class_choices_list.
    if getattr(args, "decoder", None) is not None:
        att_decoder_class = decoder_choices.get_class(args.decoder)
        att_decoder = att_decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **getattr(args, "decoder_conf", {}),
        )
    else:
        att_decoder = None

    # 6. Joint network
    joint_network = JointNetwork(
        vocab_size,
        encoder_output_size,
        decoder_output_size,
        **args.joint_network_conf,
    )

    # 7. Build model. Encoders that may not define `unified_model_training`
    # fall back to the plain TransducerModel instead of raising AttributeError.
    if getattr(encoder, "unified_model_training", False):
        model_class = UnifiedTransducerModel
    else:
        model_class = TransducerModel

    model = model_class(
        vocab_size=vocab_size,
        token_list=token_list,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        encoder=encoder,
        decoder=decoder,
        att_decoder=att_decoder,
        joint_network=joint_network,
        **args.model_conf,
    )

    # 8. Initialize model
    if args.init is not None:
        raise NotImplementedError(
            "Currently not supported. "
            "Initialization part will be reworked in a short future."
        )

    return model