Support eparaformer model on aishell1 recipe (#2327)
| New file |
| | |
| | | |
| | | # network architecture |
| | | model: EParaformer |
| | | model_conf: |
| | | ctc_weight: 0.0 |
| | | lsm_weight: 0.1 |
| | | length_normalized_loss: false |
| | | predictor_weight: 1.0 |
| | | predictor_bias: 2 |
| | | sampling_ratio: 0.4 |
| | | use_1st_decoder_loss: true |
| | | |
| | | # encoder |
| | | encoder: ConformerEncoder |
| | | encoder_conf: |
| | | output_size: 256 # dimension of attention |
| | | attention_heads: 4 |
| | | linear_units: 2048 # the number of units of position-wise feed forward |
| | | num_blocks: 12 # the number of encoder blocks |
| | | dropout_rate: 0.1 |
| | | positional_dropout_rate: 0.1 |
| | | attention_dropout_rate: 0.0 |
| | | input_layer: conv2d # encoder architecture type |
| | | normalize_before: true |
| | | pos_enc_layer_type: rel_pos |
| | | selfattention_layer_type: rel_selfattn |
| | | activation_type: swish |
| | | macaron_style: true |
| | | use_cnn_module: true |
| | | cnn_module_kernel: 15 |
| | | |
| | | # decoder |
| | | decoder: ParaformerSANDecoder |
| | | decoder_conf: |
| | | attention_heads: 4 |
| | | linear_units: 2048 |
| | | num_blocks: 6 |
| | | dropout_rate: 0.1 |
| | | positional_dropout_rate: 0.1 |
| | | self_attention_dropout_rate: 0.0 |
| | | src_attention_dropout_rate: 0.0 |
| | | |
| | | # predictor |
| | | predictor: PifPredictor |
| | | predictor_conf: |
| | | idim: 256 |
| | | threshold: 1.0 |
| | | l_order: 1 |
| | | r_order: 1 |
| | | sigma: 0.5 |
| | | bias: 0.0 |
| | | sigma_heads: 4 |
| | | |
| | | # frontend related |
| | | frontend: WavFrontend |
| | | frontend_conf: |
| | | fs: 16000 |
| | | window: hamming |
| | | n_mels: 80 |
| | | frame_length: 25 |
| | | frame_shift: 10 |
| | | lfr_m: 1 |
| | | lfr_n: 1 |
| | | |
| | | specaug: SpecAug |
| | | specaug_conf: |
| | | apply_time_warp: true |
| | | time_warp_window: 5 |
| | | time_warp_mode: bicubic |
| | | apply_freq_mask: true |
| | | freq_mask_width_range: |
| | | - 0 |
| | | - 30 |
| | | num_freq_mask: 2 |
| | | apply_time_mask: true |
| | | time_mask_width_range: |
| | | - 0 |
| | | - 40 |
| | | num_time_mask: 2 |
| | | |
| | | train_conf: |
| | | accum_grad: 4 |
| | | grad_clip: 5 |
| | | max_epoch: 150 |
| | | keep_nbest_models: 20 |
| | | avg_nbest_model: 15 |
| | | log_interval: 50 |
| | | |
| | | optim: adam |
| | | optim_conf: |
| | | lr: 0.0005 |
| | | scheduler: warmuplr |
| | | scheduler_conf: |
| | | warmup_steps: 30000 |
| | | |
| | | dataset: AudioDataset |
| | | dataset_conf: |
| | | index_ds: IndexDSJsonl |
| | | batch_sampler: EspnetStyleBatchSampler |
| | | batch_type: length # example or length |
| | | batch_size: 25000 # if batch_type is example, batch_size is the number of samples; if batch_type is length, batch_size is source_token_len+target_token_len |
| | | max_token_length: 2048 # filter out samples whose source_token_len+target_token_len > max_token_length |
| | | buffer_size: 1024 |
| | | shuffle: True |
| | | num_workers: 4 |
| | | preprocessor_speech: SpeechPreprocessSpeedPerturb |
| | | preprocessor_speech_conf: |
| | | speed_perturb: [0.9, 1.0, 1.1] |
| | | |
| | | tokenizer: CharTokenizer |
| | | tokenizer_conf: |
| | | unk_symbol: <unk> |
| | | |
| | | ctc_conf: |
| | | dropout_rate: 0.0 |
| | | ctc_type: builtin |
| | | reduce: true |
| | | ignore_nan_grad: true |
| | | normalize: null |
| | | |
| | | |
| New file |
| | |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | |
| | | |
# Decode one AISHELL-1 utterance with a trained checkpoint (debug helper).
# NOTE(review): every path below is machine-specific and points at a
# "paraformer" experiment directory -- adjust before running elsewhere.
exp_dir="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3"
feats_data_dir="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data"

# The last argument deliberately has no trailing backslash, so nothing that
# is appended after this command can be swallowed into it.
python -m funasr.bin.inference \
  --config-path="${exp_dir}" \
  --config-name="config.yaml" \
  ++init_param="${exp_dir}/model.pt.ep38" \
  ++tokenizer_conf.token_list="${feats_data_dir}/zh_token_list/char/tokens.txt" \
  ++frontend_conf.cmvn_file="${feats_data_dir}/train/am.mvn" \
  ++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \
  ++output_dir="./outputs/debug" \
  ++device="cuda:0"
| | | |
| New file |
| | |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | |
# GPUs used for training / fine-tuning; one torchrun worker is started per
# listed device.
export CUDA_VISIBLE_DEVICES="0,1"
gpu_num=$(echo "${CUDA_VISIBLE_DEVICES}" | awk -F "," '{print NF}')

# data dir, which contains: train.json, val.json, tokens.jsonl/tokens.txt, am.mvn
data_dir="/Users/zhifu/funasr1.0/data/list"

## generate jsonl from wav.scp and text.txt
#python -m funasr.datasets.audio_datasets.scp2jsonl \
#++scp_file_list='["/Users/zhifu/funasr1.0/test_local/wav.scp", "/Users/zhifu/funasr1.0/test_local/text.txt"]' \
#++data_type_list='["source", "target"]' \
#++jsonl_file_out=/Users/zhifu/funasr1.0/test_local/audio_datasets.jsonl

train_data="${data_dir}/train.jsonl"
val_data="${data_dir}/val.jsonl"
tokens="${data_dir}/tokens.json"
cmvn_file="${data_dir}/am.mvn"

# exp output dir
output_dir="/Users/zhifu/exp"
log_file="${output_dir}/log.txt"

workspace=$(pwd)
# NOTE(review): the recipe's run script uses a config whose name starts with
# "e_paraformer_"; confirm that this file name exists under ${workspace}/conf.
config="paraformer_conformer_12e_6d_2048_256.yaml"

# Checkpoint used to initialise the model before training.
init_param="${output_dir}/model.pt"

# Quoted so that directories containing spaces are handled correctly.
mkdir -p "${output_dir}"
echo "log_file: ${log_file}"

# The ++key=value arguments override entries of the YAML config; stdout and
# stderr of the whole run are written to ${log_file}.
torchrun \
  --nnodes 1 \
  --nproc_per_node "${gpu_num}" \
  ../../../funasr/bin/train.py \
  --config-path "${workspace}/conf" \
  --config-name "${config}" \
  ++train_data_set_list="${train_data}" \
  ++valid_data_set_list="${val_data}" \
  ++tokenizer_conf.token_list="${tokens}" \
  ++frontend_conf.cmvn_file="${cmvn_file}" \
  ++dataset_conf.batch_size=32 \
  ++dataset_conf.batch_type="example" \
  ++dataset_conf.num_workers=4 \
  ++train_conf.max_epoch=150 \
  ++optim_conf.lr=0.0002 \
  ++init_param="${init_param}" \
  ++output_dir="${output_dir}" &> "${log_file}"
| New file |
| | |
| | | #!/bin/bash |
| | | |
| | | # Copyright 2017 Xingyu Na |
| | | # Apache 2.0 |
| | | |
| | | #. ./path.sh || exit 1; |
| | | |
# Build wav.scp and text for the AISHELL-1 train/dev/test splits.
#   $1  directory that contains the extracted wav files
#   $2  directory that contains aishell_transcript_v0.8.txt
#   $3  output root; results go to $3/data/{train,dev,test}
if [ $# != 3 ]; then
  echo "Usage: $0 <audio-path> <text-path> <output-path>"
  echo " $0 /export/a05/xna/data/data_aishell/wav /export/a05/xna/data/data_aishell/transcript data"
  exit 1;
fi

aishell_audio_dir=$1
aishell_text=$2/aishell_transcript_v0.8.txt
output_dir=$3

train_dir=$output_dir/data/local/train
dev_dir=$output_dir/data/local/dev
test_dir=$output_dir/data/local/test
tmp_dir=$output_dir/data/local/tmp

mkdir -p "$train_dir"
mkdir -p "$dev_dir"
mkdir -p "$test_dir"
mkdir -p "$tmp_dir"

# data directory check
if [ ! -d "$aishell_audio_dir" ] || [ ! -f "$aishell_text" ]; then
  echo "Error: $0: audio directory '$aishell_audio_dir' or transcript file '$aishell_text' does not exist"
  exit 1;
fi

# find wav audio file for train, dev and test resp.
find "$aishell_audio_dir" -iname "*.wav" > "$tmp_dir/wav.flist"
# Strip the padding some wc implementations print so the numeric test is safe.
n=$(wc -l < "$tmp_dir/wav.flist" | tr -d '[:space:]')
[ "$n" -ne 141925 ] && \
  echo "Warning: expected 141925 data files, found $n"

# Split the file list by the sub-directory name that appears in each path.
grep -i "wav/train" "$tmp_dir/wav.flist" > "$train_dir/wav.flist" || exit 1;
grep -i "wav/dev" "$tmp_dir/wav.flist" > "$dev_dir/wav.flist" || exit 1;
grep -i "wav/test" "$tmp_dir/wav.flist" > "$test_dir/wav.flist" || exit 1;

rm -r "$tmp_dir"

# Transcriptions preparation
for dir in "$train_dir" "$dev_dir" "$test_dir"; do
  echo "Preparing $dir transcriptions"
  # Utterance id = file name without the .wav extension.
  sed -e 's/\.wav//' "$dir/wav.flist" | awk -F '/' '{print $NF}' > "$dir/utt.list"
  paste -d' ' "$dir/utt.list" "$dir/wav.flist" > "$dir/wav.scp_all"
  # Keep only transcripts that have audio, then only audio that has a transcript.
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$aishell_text" > "$dir/transcripts.txt"
  awk '{print $1}' "$dir/transcripts.txt" > "$dir/utt.list"
  utils/filter_scp.pl -f 1 "$dir/utt.list" "$dir/wav.scp_all" | sort -u > "$dir/wav.scp"
  sort -u "$dir/transcripts.txt" > "$dir/text"
done

mkdir -p "$output_dir/data/train" "$output_dir/data/dev" "$output_dir/data/test"

for f in wav.scp text; do
  cp "$train_dir/$f" "$output_dir/data/train/$f" || exit 1;
  cp "$dev_dir/$f" "$output_dir/data/dev/$f" || exit 1;
  cp "$test_dir/$f" "$output_dir/data/test/$f" || exit 1;
done

echo "$0: AISHELL data preparation succeeded"
exit 0;
| New file |
| | |
| | | #!/usr/bin/env bash |
| | | |
| | | # Copyright 2014 Johns Hopkins University (author: Daniel Povey) |
| | | # 2017 Xingyu Na |
| | | # Apache 2.0 |
| | | |
# Download <url-base>/<corpus-part>.tgz into <data-base> and extract it.
# A marker file <data-base>/<corpus-part>/.complete records a finished
# extraction so that later invocations return immediately.
remove_archive=false

if [ "$1" == --remove-archive ]; then
  remove_archive=true
  shift
fi

if [ $# -ne 3 ]; then
  echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
  echo "e.g.: $0 /export/a05/xna/data www.openslr.org/resources/33 data_aishell"
  echo "With --remove-archive it will remove the archive after successfully un-tarring it."
  echo "<corpus-part> can be one of: data_aishell, resource_aishell."
  # Stop here: continuing with missing arguments only yields confusing
  # follow-up errors.
  exit 1;
fi

data=$1
url=$2
part=$3

if [ ! -d "$data" ]; then
  echo "$0: no such directory $data"
  exit 1;
fi

# Resolve to an absolute path: the script changes directory several times
# below, which would invalidate a relative <data-base> such as ../raw_data.
data=$(cd "$data" && pwd) || exit 1

part_ok=false
list="data_aishell resource_aishell"
for x in $list; do
  if [ "$part" == "$x" ]; then part_ok=true; fi
done
if ! $part_ok; then
  echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
  exit 1;
fi

if [ -z "$url" ]; then
  echo "$0: empty URL base."
  exit 1;
fi

if [ -f "$data/$part/.complete" ]; then
  echo "$0: data part $part was already successfully extracted, nothing to do."
  exit 0;
fi

# sizes of the archive files in bytes.
sizes="15582913665 1246920"

if [ -f "$data/$part.tgz" ]; then
  size=$(/bin/ls -l "$data/$part.tgz" | awk '{print $5}')
  size_ok=false
  for s in $sizes; do if [ "$s" == "$size" ]; then size_ok=true; fi; done
  if ! $size_ok; then
    echo "$0: removing existing file $data/$part.tgz because its size in bytes $size"
    echo "does not equal the size of one of the archives."
    rm "$data/$part.tgz"
  else
    echo "$data/$part.tgz exists and appears to be complete."
  fi
fi

if [ ! -f "$data/$part.tgz" ]; then
  if ! command -v wget >/dev/null; then
    echo "$0: wget is not installed."
    exit 1;
  fi
  full_url=$url/$part.tgz
  echo "$0: downloading data from $full_url. This may take some time, please be patient."

  cd "$data" || exit 1
  if ! wget --no-check-certificate "$full_url"; then
    echo "$0: error executing wget $full_url"
    exit 1;
  fi
fi

cd "$data" || exit 1

if ! tar -xvzf "$part.tgz"; then
  echo "$0: error un-tarring archive $data/$part.tgz"
  exit 1;
fi

if [ "$part" == "data_aishell" ]; then
  # The wav directory holds further .tar.gz archives; unpack each one and
  # delete it once it has been extracted successfully.
  cd "$data/$part/wav" || exit 1
  for wav in ./*.tar.gz; do
    # An unmatched glob stays literal; skip it instead of calling tar on it.
    [ -e "$wav" ] || continue
    echo "Extracting wav from $wav"
    if ! tar -zxf "$wav"; then
      echo "$0: error un-tarring archive $wav"
      exit 1;
    fi
    rm "$wav"
  done
fi

# Write the marker only after every extraction step has succeeded, so an
# interrupted run is not mistaken for a finished one.
touch "$data/$part/.complete"

echo "$0: Successfully downloaded and un-tarred $data/$part.tgz"

if $remove_archive; then
  echo "$0: removing $data/$part.tgz file since --remove-archive option was supplied."
  rm "$data/$part.tgz"
fi

exit 0;
| New file |
| | |
| | | #!/usr/bin/env bash |
| | | |
| | | |
# GPUs used for training (stage 4) and for CUDA inference (stage 5).
CUDA_VISIBLE_DEVICES="0,1"

# general configuration
feats_dir="../DATA" # feature output directory
exp_dir=$(pwd)
lang=zh
token_type=char
stage=0
stop_stage=5

# feature configuration
nj=32

inference_device="cuda" #"cpu"
# NOTE(review): the e_paraformer config sets train_conf.avg_nbest_model: 15;
# confirm that training really produces a checkpoint named model.pt.avg10.
inference_checkpoint="model.pt.avg10"
inference_scp="wav.scp"
inference_batch_size=32

# data
raw_data=../raw_data
data_url=www.openslr.org/resources/33

# exp tag
tag="exp1"
workspace=$(pwd)

master_port=12345

. utils/parse_options.sh || exit 1;

# Make bash strict: exit on
# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'.
set -e
set -u
set -o pipefail

train_set=train
valid_set=dev
test_sets="dev test"

config=e_paraformer_conformer_12e_6d_2048_256.yaml
model_dir="baseline_$(basename "${config}" .yaml)_${lang}_${token_type}_${tag}"


if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
  echo "stage -1: Data Download"
  mkdir -p ${raw_data}
  local/download_and_untar.sh ${raw_data} ${data_url} data_aishell
  local/download_and_untar.sh ${raw_data} ${data_url} resource_aishell
fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
  echo "stage 0: Data preparation"
  # Data preparation
  local/aishell_data_prep.sh ${raw_data}/data_aishell/wav ${raw_data}/data_aishell/transcript ${feats_dir}
  for x in train dev test; do
    # Remove the spaces inside each transcript (the utterance id in column 1
    # is kept), then re-tokenise the text with text2token.py.
    cp ${feats_dir}/data/${x}/text ${feats_dir}/data/${x}/text.org
    paste -d " " <(cut -f 1 -d" " ${feats_dir}/data/${x}/text.org) <(cut -f 2- -d" " ${feats_dir}/data/${x}/text.org | tr -d " ") \
      > ${feats_dir}/data/${x}/text
    utils/text2token.py -n 1 -s 1 ${feats_dir}/data/${x}/text > ${feats_dir}/data/${x}/text.org
    mv ${feats_dir}/data/${x}/text.org ${feats_dir}/data/${x}/text

    # convert wav.scp text to jsonl
    # The argument is expanded unquoted on purpose, so the single quotes are
    # passed through literally as part of the value.
    scp_file_list_arg="++scp_file_list='[\"${feats_dir}/data/${x}/wav.scp\",\"${feats_dir}/data/${x}/text\"]'"
    python ../../../funasr/datasets/audio_datasets/scp2jsonl.py \
      ++data_type_list='["source", "target"]' \
      ++jsonl_file_out=${feats_dir}/data/${x}/audio_datasets.jsonl \
      ${scp_file_list_arg}
  done
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
  echo "stage 1: Feature and CMVN Generation"
  python ../../../funasr/bin/compute_audio_cmvn.py \
    --config-path "${workspace}/conf" \
    --config-name "${config}" \
    ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
    ++cmvn_file="${feats_dir}/data/${train_set}/cmvn.json"
fi

token_list=${feats_dir}/data/${lang}_token_list/$token_type/tokens.txt
echo "dictionary: ${token_list}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
  echo "stage 2: Dictionary Preparation"
  mkdir -p ${feats_dir}/data/${lang}_token_list/$token_type/

  echo "make a dictionary"
  # Token order: <blank>, <s>, </s>, the sorted unique training tokens, <unk>.
  echo "<blank>" > ${token_list}
  echo "<s>" >> ${token_list}
  echo "</s>" >> ${token_list}
  utils/text2token.py -s 1 -n 1 --space "" ${feats_dir}/data/$train_set/text | cut -f 2- -d" " | tr " " "\n" \
    | sort | uniq | grep -a -v -e '^\s*$' | awk '{print $0}' >> ${token_list}
  echo "<unk>" >> ${token_list}
fi

# LM Training Stage
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
  echo "stage 3: LM Training"
fi

# ASR Training Stage
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
  echo "stage 4: ASR Training"

  mkdir -p ${exp_dir}/exp/${model_dir}
  current_time=$(date "+%Y-%m-%d_%H-%M")
  log_file="${exp_dir}/exp/${model_dir}/train.log.txt.${current_time}"
  echo "log_file: ${log_file}"

  export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
  # One torchrun worker per comma-separated GPU id.
  gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
  torchrun \
    --nnodes 1 \
    --nproc_per_node ${gpu_num} \
    --master_port ${master_port} \
    ../../../funasr/bin/train.py \
    --config-path "${workspace}/conf" \
    --config-name "${config}" \
    ++train_data_set_list="${feats_dir}/data/${train_set}/audio_datasets.jsonl" \
    ++valid_data_set_list="${feats_dir}/data/${valid_set}/audio_datasets.jsonl" \
    ++tokenizer_conf.token_list="${token_list}" \
    ++frontend_conf.cmvn_file="${feats_dir}/data/${train_set}/am.mvn" \
    ++output_dir="${exp_dir}/exp/${model_dir}" &> "${log_file}"
fi



# Testing Stage
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
  echo "stage 5: Inference"

  if [ ${inference_device} == "cuda" ]; then
    # One inference job per visible GPU.
    nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
  else
    inference_batch_size=1
    # CPU decoding: give every job the device id "-1".
    CUDA_VISIBLE_DEVICES=""
    for JOB in $(seq ${nj}); do
      CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
    done
  fi

  for dset in ${test_sets}; do

    inference_dir="${exp_dir}/exp/${model_dir}/inference-${inference_checkpoint}/${dset}"
    _logdir="${inference_dir}/logdir"
    echo "inference_dir: ${inference_dir}"

    mkdir -p "${_logdir}"
    data_dir="${feats_dir}/data/${dset}"
    key_file=${data_dir}/${inference_scp}

    # Split the key file into one part per job; split_scps is expanded
    # unquoted on purpose so each path becomes its own argument.
    split_scps=
    for JOB in $(seq "${nj}"); do
      split_scps+=" ${_logdir}/keys.${JOB}.scp"
    done
    utils/split_scp.pl "${key_file}" ${split_scps}

    gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
    for JOB in $(seq ${nj}); do
      {
        id=$((JOB-1))
        gpuid=${gpuid_list_array[$id]}

        # The export only affects this background job.
        export CUDA_VISIBLE_DEVICES=${gpuid}
        python ../../../funasr/bin/inference.py \
          --config-path="${exp_dir}/exp/${model_dir}" \
          --config-name="config.yaml" \
          ++init_param="${exp_dir}/exp/${model_dir}/${inference_checkpoint}" \
          ++tokenizer_conf.token_list="${token_list}" \
          ++frontend_conf.cmvn_file="${feats_dir}/data/${train_set}/am.mvn" \
          ++input="${_logdir}/keys.${JOB}.scp" \
          ++output_dir="${inference_dir}/${JOB}" \
          ++device="${inference_device}" \
          ++ncpu=1 \
          ++disable_log=true \
          ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
      }&

    done
    wait

    mkdir -p ${inference_dir}/1best_recog
    for f in token score text; do
      # Probe the first job's output explicitly instead of relying on the
      # value JOB happened to keep after the previous loop.
      if [ -f "${inference_dir}/1/1best_recog/${f}" ]; then
        for JOB in $(seq "${nj}"); do
          cat "${inference_dir}/${JOB}/1best_recog/${f}"
        done | sort -k1 >"${inference_dir}/1best_recog/${f}"
      fi
    done

    echo "Computing WER ..."
    python utils/postprocess_text_zh.py ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc
    python utils/postprocess_text_zh.py ${data_dir}/text ${inference_dir}/1best_recog/text.ref
    python utils/compute_wer.py ${inference_dir}/1best_recog/text.ref ${inference_dir}/1best_recog/text.proc ${inference_dir}/1best_recog/text.cer
    tail -n 3 ${inference_dir}/1best_recog/text.cer
  done

fi
| New file |
| | |
| | | import os |
| | | import numpy as np |
| | | import sys |
| | | |
| | | |
def _read_transcript(path):
    """Read ``<key> <token> ...`` lines from *path* into ``{key: [tokens]}``.

    Blank lines are skipped; a later duplicate key overwrites an earlier one.
    """
    entries = {}
    with open(path, "r") as reader:
        for line in reader:
            fields = line.split()
            if not fields:
                continue
            entries[fields[0]] = fields[1:]
    return entries


def compute_wer(ref_file, hyp_file, cer_detail_file):
    """Score a hypothesis file against a reference file and write a report.

    Every hypothesis key that also occurs in the reference is aligned with
    ``compute_wer_by_line``; per-utterance details followed by overall
    ``%WER`` / ``%SER`` summary lines are written to *cer_detail_file*.

    Args:
        ref_file: path of the reference transcript (``key tok tok ...``).
        hyp_file: path of the hypothesis transcript in the same format.
        cer_detail_file: path of the report file to create or overwrite.
    """
    rst = {
        "Wrd": 0,
        "Corr": 0,
        "Ins": 0,
        "Del": 0,
        "Sub": 0,
        "Snt": 0,
        "Err": 0.0,
        "S.Err": 0.0,
        "wrong_words": 0,
        "wrong_sentences": 0,
    }

    hyp_dict = _read_transcript(hyp_file)
    ref_dict = _read_transcript(ref_file)

    # The context manager guarantees the report is flushed and closed even
    # when scoring raises part-way through.
    with open(cer_detail_file, "w") as cer_detail_writer:
        for hyp_key, hyp_tokens in hyp_dict.items():
            if hyp_key not in ref_dict:
                continue
            ref_tokens = ref_dict[hyp_key]
            out_item = compute_wer_by_line(hyp_tokens, ref_tokens)
            rst["Wrd"] += out_item["nwords"]
            rst["Corr"] += out_item["cor"]
            # int() keeps the running total a Python int, so it cannot wrap
            # around if the per-line count is a fixed-width NumPy scalar.
            rst["wrong_words"] += int(out_item["wrong"])
            rst["Ins"] += out_item["ins"]
            rst["Del"] += out_item["del"]
            rst["Sub"] += out_item["sub"]
            rst["Snt"] += 1
            if out_item["wrong"] > 0:
                rst["wrong_sentences"] += 1
            cer_detail_writer.write(hyp_key + print_cer_detail(out_item) + "\n")
            cer_detail_writer.write(
                "ref:" + "\t" + " ".join(tok.lower() for tok in ref_tokens) + "\n"
            )
            cer_detail_writer.write(
                "hyp:" + "\t" + " ".join(tok.lower() for tok in hyp_tokens) + "\n"
            )

        if rst["Wrd"] > 0:
            rst["Err"] = round(rst["wrong_words"] * 100 / rst["Wrd"], 2)
        if rst["Snt"] > 0:
            rst["S.Err"] = round(rst["wrong_sentences"] * 100 / rst["Snt"], 2)

        cer_detail_writer.write("\n")
        cer_detail_writer.write(
            "%WER "
            + str(rst["Err"])
            + " [ "
            + str(rst["wrong_words"])
            + " / "
            + str(rst["Wrd"])
            + ", "
            + str(rst["Ins"])
            + " ins, "
            + str(rst["Del"])
            + " del, "
            + str(rst["Sub"])
            + " sub ]"
            + "\n"
        )
        cer_detail_writer.write(
            "%SER "
            + str(rst["S.Err"])
            + " [ "
            + str(rst["wrong_sentences"])
            + " / "
            + str(rst["Snt"])
            + " ]"
            + "\n"
        )
        # NOTE(review): the second number counts hypothesis keys that have no
        # reference entry, although the text says "not present in hyp"; the
        # wording is left unchanged so existing report consumers keep working.
        cer_detail_writer.write(
            "Scored "
            + str(len(hyp_dict))
            + " sentences, "
            + str(len(hyp_dict) - rst["Snt"])
            + " not present in hyp."
            + "\n"
        )
| | | |
| | | |
def compute_wer_by_line(hyp, ref):
    """Align one hypothesis with one reference and count the edit operations.

    Tokens are compared case-insensitively. When several operations have the
    same cost, substitution is preferred over insertion, and insertion over
    deletion.

    Args:
        hyp: list of hypothesis tokens.
        ref: list of reference tokens.

    Returns:
        dict with plain ``int`` values: ``nwords`` (reference length), ``cor``
        (matches), ``wrong`` (edit distance), ``ins``, ``del`` and ``sub``.
    """
    hyp = [token.lower() for token in hyp]
    ref = [token.lower() for token in ref]

    len_hyp = len(hyp)
    len_ref = len(ref)

    # Codes stored in ops_matrix and followed during the backtrace.
    op_match, op_sub, op_ins, op_del = 0, 1, 2, 3

    # int32 instead of int16: distances above 32767 must not wrap around.
    cost_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int32)
    ops_matrix = np.zeros((len_hyp + 1, len_ref + 1), dtype=np.int8)

    # Borders: reaching (i, 0) takes i insertions, reaching (0, j) takes j
    # deletions.
    cost_matrix[:, 0] = np.arange(len_hyp + 1)
    cost_matrix[0, :] = np.arange(len_ref + 1)
    ops_matrix[1:, 0] = op_ins
    ops_matrix[0, 1:] = op_del

    for i in range(1, len_hyp + 1):
        for j in range(1, len_ref + 1):
            if hyp[i - 1] == ref[j - 1]:
                cost_matrix[i, j] = cost_matrix[i - 1, j - 1]
                continue
            # The list order fixes the tie-break: sub, then ins, then del.
            candidates = [
                cost_matrix[i - 1, j - 1] + 1,
                cost_matrix[i - 1, j] + 1,
                cost_matrix[i, j - 1] + 1,
            ]
            best = min(candidates)
            cost_matrix[i, j] = best
            ops_matrix[i, j] = candidates.index(best) + 1

    rst = {"nwords": len_ref, "cor": 0, "wrong": 0, "ins": 0, "del": 0, "sub": 0}

    # Walk back from the bottom-right corner to the origin, counting one
    # operation per step.
    i, j = len_hyp, len_ref
    while i > 0 or j > 0:
        op = ops_matrix[i, j]
        if op == op_match:
            rst["cor"] += 1
            i -= 1
            j -= 1
        elif op == op_sub:
            rst["sub"] += 1
            i -= 1
            j -= 1
        elif op == op_ins:
            rst["ins"] += 1
            i -= 1
        else:
            rst["del"] += 1
            j -= 1

    # Hand back a Python int so that callers summing over many utterances
    # never accumulate into a fixed-width NumPy scalar.
    rst["wrong"] = int(cost_matrix[len_hyp, len_ref])

    return rst
| | | |
| | | |
def print_cer_detail(rst):
    """Format the per-utterance counts in *rst* as one report fragment.

    Args:
        rst: dict with the keys ``nwords``, ``cor``, ``ins``, ``del``, ``sub``
            and ``wrong``.

    Returns:
        A string of the form
        ``(nwords=..,cor=..,ins=..,del=..,sub=..) corr:xx.xx%,cer:yy.yy%``.
        For an empty reference (``nwords == 0``) the rates are computed
        against a length of one instead of raising ``ZeroDivisionError``.
    """
    # Guard the denominator: a reference line may hold a key and no tokens.
    denom = rst["nwords"] if rst["nwords"] > 0 else 1
    return (
        "("
        + "nwords="
        + str(rst["nwords"])
        + ",cor="
        + str(rst["cor"])
        + ",ins="
        + str(rst["ins"])
        + ",del="
        + str(rst["del"])
        + ",sub="
        + str(rst["sub"])
        + ") corr:"
        + "{:.2%}".format(rst["cor"] / denom)
        + ",cer:"
        + "{:.2%}".format(rst["wrong"] / denom)
    )
| | | |
| | | |
if __name__ == "__main__":
    # Command line: <reference file> <hypothesis file> <report file>.
    if len(sys.argv) != 4:
        print("usage : python compute_wer.py test.ref test.hyp test.wer")
        # Exit non-zero so calling scripts notice the bad invocation.
        sys.exit(1)

    ref_file = sys.argv[1]
    hyp_file = sys.argv[2]
    cer_detail_file = sys.argv[3]
    compute_wer(ref_file, hyp_file, cer_detail_file)
| New file |
| | |
| | | from transformers import AutoTokenizer, AutoModel, pipeline |
| | | import numpy as np |
| | | import sys |
| | | import os |
| | | import torch |
| | | from kaldiio import WriteHelper |
| | | import re |
| | | |
# Command line:
#   <text file> <out ark> <out scp> <out shape file> <device index> <model path>
text_file_json = sys.argv[1]
out_ark = sys.argv[2]
out_scp = sys.argv[3]
out_shape = sys.argv[4]
device = int(sys.argv[5])
model_path = sys.argv[6]

model = AutoModel.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
extractor = pipeline(task="feature-extraction", model=model, tokenizer=tokenizer, device=device)

with open(text_file_json, "r") as f:
    js = f.readlines()

# All outputs are opened through context managers, so they are closed even if
# feature extraction fails part-way through.
with open(out_shape, "w") as f_shape, WriteHelper(
    "ark,scp:{},{}".format(out_ark, out_scp)
) as writer, torch.no_grad():
    for line in js:
        # Each line is "<utt-id> <text>"; skip blank lines and lines that
        # carry no text instead of aborting the whole run.
        fields = line.strip().split(" ", 1)
        if len(fields) != 2:
            if fields[0]:
                print("{}, no text found, skipped".format(fields[0]))
            continue
        utt_id, tokens = fields

        # Re-space the text so that every character is its own token.
        chars = tokens.strip().replace(" ", "")
        tokens = " ".join(chars)
        token_num = len(chars)

        outputs = np.array(extractor(tokens))
        # Drop the first and last position of the sequence axis.
        # NOTE(review): presumably the tokenizer's start/end special tokens -- confirm.
        embeds = outputs[0, 1:-1, :]

        token_num_embeds, dim = embeds.shape
        if token_num == token_num_embeds:
            writer(utt_id, embeds)
            f_shape.write("{} {},{}\n".format(utt_id, token_num_embeds, dim))
        else:
            # The tokenizer produced a different number of pieces than there
            # are characters; report the utterance and write nothing for it.
            print(
                "{}, size has changed, {}, {}, {}".format(
                    utt_id, token_num, token_num_embeds, tokens
                )
            )
| New file |
| | |
| | | #!/usr/bin/env perl |
| | | # Copyright 2010-2012 Microsoft Corporation |
| | | # Johns Hopkins University (author: Daniel Povey) |
| | | |
| | | # Licensed under the Apache License, Version 2.0 (the "License"); |
| | | # you may not use this file except in compliance with the License. |
| | | # You may obtain a copy of the License at |
| | | # |
| | | # http://www.apache.org/licenses/LICENSE-2.0 |
| | | # |
| | | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| | | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED |
| | | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, |
| | | # MERCHANTABLITY OR NON-INFRINGEMENT. |
| | | # See the Apache 2 License for the specific language governing permissions and |
| | | # limitations under the License. |
| | | |
| | | |
| | | # This script takes a list of utterance-ids or any file whose first field |
| | | # of each line is an utterance-id, and filters an scp |
| | | # file (or any file whose "n-th" field is an utterance id), printing |
| | | # out only those lines whose "n-th" field is in id_list. The index of |
| | | # the "n-th" field is 1, by default, but can be changed by using |
| | | # the -f <n> switch |
| | | |
# Defaults: keep the lines whose id is listed, and filter on the first field.
$exclude = 0;
$field = 1;

# Consume leading option switches; they may appear in any order.
while (1) {
  if ($ARGV[0] eq "--exclude") {
    $exclude = 1;
    shift @ARGV;
  } elsif ($ARGV[0] eq "-f") {
    $field = $ARGV[1];
    splice(@ARGV, 0, 2);
  } else {
    last;
  }
}

if(@ARGV < 1 || @ARGV > 2) {
  die "Usage: filter_scp.pl [--exclude] [-f <field-to-filter-on>] id_list [in.scp] > out.scp \n" .
    "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" .
    "Note: only the first field of each line in id_list matters. With --exclude, prints\n" .
    "only the lines that were *not* in id_list.\n" .
    "Caution: previously, the -f option was interpreted as a zero-based field index.\n" .
    "If your older scripts (written before Oct 2014) stopped working and you used the\n" .
    "-f option, add 1 to the argument.\n" .
    "See also: scripts/filter_scp.pl .\n";
}

# Load the id list: only the first whitespace-separated field of each line
# is remembered.
$idlist = shift @ARGV;
open(F, "<$idlist") or die "Could not open id-list file $idlist";
while ($line = <F>) {
  @fields = split(" ", $line);
  die "Invalid id-list file line $line" unless @fields >= 1;
  $seen{$fields[0]} = 1;
}

# Decide whether a line with the given key is printed.
sub keep_line {
  my ($key) = @_;
  return $exclude ? !defined $seen{$key} : $seen{$key};
}

# Filter the remaining input (the named file, or stdin when none is given).
while ($line = <>) {
  if ($field == 1) {
    # Common case: take the first non-blank token without a full split.
    $line =~ m/\s*(\S+)\s*/ or die "Bad line $line, could not get first field.";
    $key = $1;
  } else {
    @fields = split(" ", $line);
    die "Invalid scp file line $line" unless @fields > 0;
    die "Invalid scp file line $line" unless @fields >= $field;
    $key = $fields[$field - 1];
  }
  print $line if keep_line($key);
}
| | | |
| | | # tests: |
| | | # the following should print "foo 1" |
| | | # ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl <(echo foo) |
| | | # the following should print "bar 2". |
| | | # ( echo foo 1; echo bar 2 ) | scripts/filter_scp.pl -f 2 <(echo 2) |
| New file |
| | |
| | | #!/usr/bin/env bash |
| | | |
# Restrict wav.scp and text in <data-dir> to the utterance ids present in
# both files, and sort them uniquely by id. Copies of the original files are
# kept in <data-dir>/.backup.
echo "$0 $*"

if [ $# -lt 1 ] || [ -z "$1" ]; then
  echo "Usage: $0 <data-dir>"
  exit 1;
fi
data_dir=$1

if [ ! -f "${data_dir}/wav.scp" ]; then
  echo "$0: wav.scp is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/text" ]; then
  echo "$0: text is not found"
  exit 1;
fi



mkdir -p "${data_dir}/.backup"

# De-duplicate each id list first, so that an id repeated inside one file is
# not mistaken for an id shared by both files.
awk '{print $1}' "${data_dir}/wav.scp" | sort -u > "${data_dir}/.backup/wav_id"
awk '{print $1}' "${data_dir}/text" | sort -u > "${data_dir}/.backup/text_id"

# uniq -d keeps exactly the ids that occur in both lists.
sort "${data_dir}/.backup/wav_id" "${data_dir}/.backup/text_id" | uniq -d > "${data_dir}/.backup/id"

cp "${data_dir}/wav.scp" "${data_dir}/.backup/wav.scp"
cp "${data_dir}/text" "${data_dir}/.backup/text"

mv "${data_dir}/wav.scp" "${data_dir}/wav.scp.bak"
mv "${data_dir}/text" "${data_dir}/text.bak"

utils/filter_scp.pl -f 1 "${data_dir}/.backup/id" "${data_dir}/wav.scp.bak" | sort -k1,1 -u > "${data_dir}/wav.scp"
utils/filter_scp.pl -f 1 "${data_dir}/.backup/id" "${data_dir}/text.bak" | sort -k1,1 -u > "${data_dir}/text"

rm "${data_dir}/wav.scp.bak"
rm "${data_dir}/text.bak"
| New file |
| | |
| | | #!/usr/bin/env bash |
| | | |
# Restrict feats.scp, text, speech_shape and text_shape in <data-dir> to the
# utterance ids present in both feats.scp and text, and sort them uniquely by
# id. Copies of the original files are kept in <data-dir>/.backup.
echo "$0 $*"

if [ $# -lt 1 ] || [ -z "$1" ]; then
  echo "Usage: $0 <data-dir>"
  exit 1;
fi
data_dir=$1

if [ ! -f "${data_dir}/feats.scp" ]; then
  echo "$0: feats.scp is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/text" ]; then
  echo "$0: text is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/speech_shape" ]; then
  echo "$0: feature lengths is not found"
  exit 1;
fi

if [ ! -f "${data_dir}/text_shape" ]; then
  echo "$0: text lengths is not found"
  exit 1;
fi

mkdir -p "${data_dir}/.backup"

# De-duplicate each id list first, so that an id repeated inside one file is
# not mistaken for an id shared by both files.
awk '{print $1}' "${data_dir}/feats.scp" | sort -u > "${data_dir}/.backup/wav_id"
awk '{print $1}' "${data_dir}/text" | sort -u > "${data_dir}/.backup/text_id"

# uniq -d keeps exactly the ids that occur in both lists.
sort "${data_dir}/.backup/wav_id" "${data_dir}/.backup/text_id" | uniq -d > "${data_dir}/.backup/id"

# Back up, filter and re-sort every per-utterance file in the same way.
for f in feats.scp text speech_shape text_shape; do
  cp "${data_dir}/${f}" "${data_dir}/.backup/${f}"
done

for f in feats.scp text speech_shape text_shape; do
  mv "${data_dir}/${f}" "${data_dir}/${f}.bak"
done

for f in feats.scp text speech_shape text_shape; do
  utils/filter_scp.pl -f 1 "${data_dir}/.backup/id" "${data_dir}/${f}.bak" | sort -k1,1 -u > "${data_dir}/${f}"
done

for f in feats.scp text speech_shape text_shape; do
  rm "${data_dir}/${f}.bak"
done
| | | |
| New file |
| | |
| | | #!/usr/bin/env bash |
| | | |
| | | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); |
| | | # Arnab Ghoshal, Karel Vesely |
| | | |
| | | # Licensed under the Apache License, Version 2.0 (the "License"); |
| | | # you may not use this file except in compliance with the License. |
| | | # You may obtain a copy of the License at |
| | | # |
| | | # http://www.apache.org/licenses/LICENSE-2.0 |
| | | # |
| | | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| | | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED |
| | | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, |
| | | # MERCHANTABLITY OR NON-INFRINGEMENT. |
| | | # See the Apache 2 License for the specific language governing permissions and |
| | | # limitations under the License. |
| | | |
| | | |
| | | # Parse command-line options. |
| | | # To be sourced by another script (as in ". parse_options.sh"). |
| | | # Option format is: --option-name arg |
| | | # and shell variable "option_name" gets set to value "arg." |
| | | # The exception is --help, which takes no arguments, but prints the |
| | | # $help_message variable (if defined). |
| | | |
| | | |
| | | ### |
| | | ### The --config file options have lower priority to command line |
| | | ### options, so we need to import them first... |
| | | ### |
| | | |
# Now import all the configs specified by command-line, in left-to-right order.
# The loop stops one short of the last argument so that "--config" always has
# a following value to read.
for ((argpos=1; argpos<$#; argpos++)); do
  if [ "${!argpos}" == "--config" ]; then
    argpos_plus1=$((argpos+1))
    config=${!argpos_plus1}
    # $config is quoted: unquoted, an empty value made the test degenerate to
    # "[ ! -r ]" (which succeeds) and a path containing spaces was split.
    [ ! -r "$config" ] && echo "$0: missing config '$config'" && exit 1
    . "$config"  # source the config file.
  fi
done


###
### Now we process the command line options
###
while true; do
  [ -z "${1:-}" ] && break;  # break if there are no arguments
  case "$1" in
    # If the enclosing script is called with --help option, print the help
    # message and exit.  Scripts should put help messages in $help_message
    --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2;
      else printf "$help_message\n" 1>&2 ; fi;
      exit 0 ;;
    --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'"
      exit 1 ;;
    # If the first command-line argument begins with "--" (e.g. --foo-bar),
    # then work out the variable name as $name, which will equal "foo_bar".
    --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`;
      # Next we test whether the variable in question is undefined-- if so it's
      # an invalid option and we die.  Note: $0 evaluates to the name of the
      # enclosing script.
      # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar
      # is undefined.  We then have to wrap this test inside "eval" because
      # foo_bar is itself inside a variable ($name).
      eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;

      oldval="`eval echo \\$$name`";
      # Work out whether we seem to be expecting a Boolean argument.
      if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then
        was_bool=true;
      else
        was_bool=false;
      fi

      # Set the variable to the right value-- the escaped quotes make it work if
      # the option had spaces, like --cmd "queue.pl -sync y"
      eval $name=\"$2\";

      # Check that Boolean-valued arguments are really Boolean.
      if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
        echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
        exit 1;
      fi
      shift 2;
      ;;
    *) break;
  esac
done


# Check for an empty argument to the --cmd option, which can easily occur as a
# result of scripting errors.
[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1;


true; # so this script returns exit code 0.
| New file |
| | |
| | | import sys |
| | | import re |
| | | |
# Usage: <script> <input_file> <output_file>
# Each input line is "<id> <text>"; the text is stripped of decoder markers and
# spaces, lower-cased, and written back as space-separated characters.
in_f = sys.argv[1]
out_f = sys.argv[2]

with open(in_f, "r", encoding="utf-8") as reader:
    lines = reader.readlines()

with open(out_f, "w", encoding="utf-8") as writer:
    for raw in lines:
        fields = raw.strip().split(" ", 1)
        idx = fields[0]
        if len(fields) == 2:
            text = fields[1]
            # Removal order matters only for readability here: "@@" is removed
            # before the single "@".
            for marker in ("</s>", "<s>", "@@", "@", "<unk>", " "):
                text = re.sub(marker, "", text)
            text = text.lower()
        else:
            # Line carried an id only: emit a single space as the text.
            text = " "

        # Joining a string with " " puts one space between its characters.
        writer.write("{} {}\n".format(idx, " ".join(text)))
| New file |
| | |
| | | #!/usr/bin/env perl |
| | | |
| | | # Copyright 2013 Johns Hopkins University (author: Daniel Povey) |
| | | |
| | | # Licensed under the Apache License, Version 2.0 (the "License"); |
| | | # you may not use this file except in compliance with the License. |
| | | # You may obtain a copy of the License at |
| | | # |
| | | # http://www.apache.org/licenses/LICENSE-2.0 |
| | | # |
| | | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| | | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED |
| | | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, |
| | | # MERCHANTABLITY OR NON-INFRINGEMENT. |
| | | # See the Apache 2 License for the specific language governing permissions and |
| | | # limitations under the License. |
| | | |
| | | |
# Seed selection: "--srand N" (consumed from @ARGV) chooses the seed; without
# it a fixed seed of 0 is used so that repeated runs give the same order.
if ($ARGV[0] eq "--srand") {
  $n = $ARGV[1];
  $n =~ m/\d+/ || die "Bad argument to --srand option: \"$n\"";
  srand($n);
  splice(@ARGV, 0, 2);
} else {
  srand(0); # Gives inconsistent behavior if we don't seed.
}

# At most one remaining argument (the input file) is accepted, and it must not
# look like an option.
unless (@ARGV <= 1 && $ARGV[0] !~ m/^-.+/) {
  print "Usage: shuffle_list.pl [--srand N] [input file] > output\n",
        "randomizes the order of lines of input.\n";
  exit(1);
}

# Pair every input line with a random key (one rand() call per line, in input
# order), then emit the lines ordered by the string comparison of their keys.
my @keyed = map { my $line = $_; [ rand(), $line ] } <>;
print map { $_->[1] } sort { $a->[0] cmp $b->[0] } @keyed;
| New file |
| | |
| | | #!/usr/bin/env perl |
| | | |
| | | # Copyright 2010-2011 Microsoft Corporation |
| | | |
| | | # See ../../COPYING for clarification regarding multiple authors |
| | | # |
| | | # Licensed under the Apache License, Version 2.0 (the "License"); |
| | | # you may not use this file except in compliance with the License. |
| | | # You may obtain a copy of the License at |
| | | # |
| | | # http://www.apache.org/licenses/LICENSE-2.0 |
| | | # |
| | | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| | | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED |
| | | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, |
| | | # MERCHANTABLITY OR NON-INFRINGEMENT. |
| | | # See the Apache 2 License for the specific language governing permissions and |
| | | # limitations under the License. |
| | | |
| | | |
| | | # This program splits up any kind of .scp or archive-type file. |
| | | # If there is no utt2spk option it will work on any text file and |
| | | # will split it up with an approximately equal number of lines in |
| | | # each output file. |
| | | # With the --utt2spk option it will work on anything that has the |
| | | # utterance-id as the first entry on each line; the utt2spk file is |
| | | # of the form "utterance speaker" (on each line). |
| | | # It splits it into equal size chunks as far as it can. If you use the utt2spk |
| | | # option it will make sure these chunks coincide with speaker boundaries. In |
| | | # this case, if there are more chunks than speakers (and in some other |
| | | # circumstances), some of the resulting chunks will be empty and it will print |
| | | # an error message and exit with nonzero status. |
| | | # You will normally call this like: |
| | | # split_scp.pl scp scp.1 scp.2 scp.3 ... |
| | | # or |
| | | # split_scp.pl --utt2spk=utt2spk scp scp.1 scp.2 scp.3 ... |
| | | # Note that you can use this script to split the utt2spk file itself, |
| | | # e.g. split_scp.pl --utt2spk=utt2spk utt2spk utt2spk.1 utt2spk.2 ... |
| | | |
| | | # You can also call the scripts like: |
| | | # split_scp.pl -j 3 0 scp scp.0 |
| | | # [note: with this option, it assumes zero-based indexing of the split parts, |
| | | # i.e. the second number must be 0 <= n < num-jobs.] |
| | | |
| | | use warnings; |
| | | |
$num_jobs = 0;
$job_id = 0;
$utt2spk_file = "";
$one_based = 0;

# Consume up to three leading options, in any order:
#   -j num-jobs job-id, --utt2spk=<file>, --one-based
# Every test is guarded by @ARGV > 0 because an earlier option in the same
# pass may have emptied @ARGV; matching on an undefined $ARGV[0] would emit
# "uninitialized value" warnings under "use warnings".
for ($x = 1; $x <= 3 && @ARGV > 0; $x++) {
    if (@ARGV > 0 && $ARGV[0] eq "-j") {
        # -j needs both of its values; fail clearly instead of carrying
        # undefined values into the numeric checks below.
        @ARGV >= 3 || die "$0: option -j requires two values: num-jobs job-id\n";
        shift @ARGV;
        $num_jobs = shift @ARGV;
        $job_id = shift @ARGV;
    }
    if (@ARGV > 0 && $ARGV[0] =~ /--utt2spk=(.+)/) {
        $utt2spk_file = $1;
        shift @ARGV;
    }
    if (@ARGV > 0 && $ARGV[0] eq '--one-based') {
        $one_based = 1;
        shift @ARGV;
    }
}

# With -j, the job index (after removing the one-based offset) must lie in
# [0, num_jobs).
if ($num_jobs != 0 && ($num_jobs < 0 || $job_id - $one_based < 0 ||
                       $job_id - $one_based >= $num_jobs)) {
    die "$0: Invalid job number/index values for '-j $num_jobs $job_id" .
        ($one_based ? " --one-based" : "") . "'\n"
}

# From here on $job_id is always zero-based.
$one_based
    and $job_id--;

# Without -j: input plus at least one output. With -j: input plus at most one
# output.
if(($num_jobs == 0 && @ARGV < 2) || ($num_jobs > 0 && (@ARGV < 1 || @ARGV > 2))) {
    die
"Usage: split_scp.pl [--utt2spk=<utt2spk_file>] in.scp out1.scp out2.scp ...
   or: split_scp.pl -j num-jobs job-id [--one-based] [--utt2spk=<utt2spk_file>] in.scp [out.scp]
 ... where 0 <= job-id < num-jobs, or 1 <= job-id <= num-jobs if --one-based.\n";
}
| | | |
$error = 0;
$inscp = shift @ARGV;

# Build the list of output destinations. With -j only the slot of the selected
# job is real (a file, or "-" for stdout); every other slot is /dev/null, so the
# split itself is computed exactly as it would be for all jobs.
if ($num_jobs == 0) { # without -j option
    @OUTPUTS = @ARGV;
} else {
    for ($j = 0; $j < $num_jobs; $j++) {
        if ($j == $job_id) {
            if (@ARGV > 0) { push @OUTPUTS, $ARGV[0]; }
            else { push @OUTPUTS, "-"; }
        } else {
            push @OUTPUTS, "/dev/null";
        }
    }
}

if ($utt2spk_file ne "") {  # We have the --utt2spk option...
    open($u_fh, '<', $utt2spk_file) || die "$0: Error opening utt2spk file $utt2spk_file: $!\n";
    while(<$u_fh>) {
        @A = split;
        @A == 2 || die "$0: Bad line $_ in utt2spk file $utt2spk_file\n";
        ($u,$s) = @A;
        $utt2spk{$u} = $s;
    }
    close $u_fh;
    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";
    @spkrs = ();
    # Group the input lines by speaker, remembering speakers in order of first
    # appearance.
    while(<$i_fh>) {
        @A = split;
        if(@A == 0) { die "$0: Empty or space-only line in scp file $inscp\n"; }
        $u = $A[0];
        $s = $utt2spk{$u};
        defined $s || die "$0: No utterance $u in utt2spk file $utt2spk_file\n";
        if(!defined $spk_count{$s}) {
            push @spkrs, $s;
            $spk_count{$s} = 0;
            $spk_data{$s} = [];  # ref to new empty array.
        }
        $spk_count{$s}++;
        push @{$spk_data{$s}}, $_;
    }
    close($i_fh);
    # Now split as equally as possible ..
    # First allocate spks to files by allocating an approximately
    # equal number of speakers.
    $numspks = @spkrs;  # number of speakers.
    $numscps = @OUTPUTS; # number of output files.
    if ($numspks < $numscps) {
        die "$0: Refusing to split data because number of speakers $numspks " .
            "is less than the number of output .scp files $numscps\n";
    }
    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
        $scparray[$scpidx] = []; # [] is array reference.
    }
    for ($spkidx = 0; $spkidx < $numspks; $spkidx++) {
        $scpidx = int(($spkidx*$numscps) / $numspks);
        $spk = $spkrs[$spkidx];
        push @{$scparray[$scpidx]}, $spk;
        $scpcount[$scpidx] += $spk_count{$spk};
    }

    # Now will try to reassign beginning + ending speakers
    # to different scp's and see if it gets more balanced.
    # Suppose objf we're minimizing is sum_i (num utts in scp[i] - average)^2.
    # We can show that if considering changing just 2 scp's, we minimize
    # this by minimizing the squared difference in sizes.  This is
    # equivalent to minimizing the absolute difference in sizes.  This
    # shows this method is bound to converge.

    $changed = 1;
    while($changed) {
        $changed = 0;
        for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
            # First try to reassign ending spk of this scp.
            if($scpidx < $numscps-1) {
                $sz = @{$scparray[$scpidx]};
                if($sz > 0) {
                    $spk = $scparray[$scpidx]->[$sz-1];
                    $count = $spk_count{$spk};
                    $nutt1 = $scpcount[$scpidx];
                    $nutt2 = $scpcount[$scpidx+1];
                    if( abs( ($nutt2+$count) - ($nutt1-$count))
                        < abs($nutt2 - $nutt1))  { # Would decrease
                        # size-diff by reassigning spk...
                        $scpcount[$scpidx+1] += $count;
                        $scpcount[$scpidx] -= $count;
                        pop @{$scparray[$scpidx]};
                        unshift @{$scparray[$scpidx+1]}, $spk;
                        $changed = 1;
                    }
                }
            }
            # Then try to move the first speaker of this scp to the previous one.
            if($scpidx > 0 && @{$scparray[$scpidx]} > 0) {
                $spk = $scparray[$scpidx]->[0];
                $count = $spk_count{$spk};
                $nutt1 = $scpcount[$scpidx-1];
                $nutt2 = $scpcount[$scpidx];
                if( abs( ($nutt2-$count) - ($nutt1+$count))
                    < abs($nutt2 - $nutt1))  { # Would decrease
                    # size-diff by reassigning spk...
                    $scpcount[$scpidx-1] += $count;
                    $scpcount[$scpidx] -= $count;
                    shift @{$scparray[$scpidx]};
                    push @{$scparray[$scpidx-1]}, $spk;
                    $changed = 1;
                }
            }
        }
    }
    # Now print out the files...
    for($scpidx = 0; $scpidx < $numscps; $scpidx++) {
        $scpfile = $OUTPUTS[$scpidx];
        ($scpfile ne '-' ? open($f_fh, '>', $scpfile)
                         : open($f_fh, '>&', \*STDOUT)) ||
            die "$0: Could not open scp file $scpfile for writing: $!\n";
        $count = 0;
        if(@{$scparray[$scpidx]} == 0) {
            print STDERR "$0: error: split_scp.pl producing empty .scp file " .
                         "$scpfile (too many splits and too few speakers?)\n";
            $error = 1;
        } else {
            foreach $spk ( @{$scparray[$scpidx]} ) {
                print $f_fh @{$spk_data{$spk}};
                $count += $spk_count{$spk};
            }
            $count == $scpcount[$scpidx] || die "Count mismatch [code error]";
        }
        # Check close() as the no-utt2spk branch does, so that a failed flush
        # is reported rather than ignored.
        close($f_fh) || die "$0: Error closing scp file $scpfile: $!\n";
    }
} else {
    # This block is the "normal" case where there is no --utt2spk
    # option and we just break into equal size chunks.

    open($i_fh, '<', $inscp) || die "$0: Error opening input scp file $inscp: $!\n";

    $numscps = @OUTPUTS;  # size of array.
    @F = ();
    while(<$i_fh>) {
        push @F, $_;
    }
    close($i_fh);
    $numlines = @F;
    if($numlines == 0) {
        print STDERR "$0: error: empty input scp file $inscp\n";
        $error = 1;
    }
    $linesperscp = int( $numlines / $numscps); # the "whole part"..
    $linesperscp >= 1 || die "$0: You are splitting into too many pieces! [reduce \$nj ($numscps) to be smaller than the number of lines ($numlines) in $inscp]\n";
    $remainder = $numlines - ($linesperscp * $numscps);
    ($remainder >= 0 && $remainder < $numlines) || die "bad remainder $remainder";
    # [just doing int() rounds down].
    $n = 0;
    # The first $remainder outputs each receive one extra line.
    for($scpidx = 0; $scpidx < @OUTPUTS; $scpidx++) {
        $scpfile = $OUTPUTS[$scpidx];
        ($scpfile ne '-' ? open($o_fh, '>', $scpfile)
                         : open($o_fh, '>&', \*STDOUT)) ||
            die "$0: Could not open scp file $scpfile for writing: $!\n";
        for($k = 0; $k < $linesperscp + ($scpidx < $remainder ? 1 : 0); $k++) {
            print $o_fh $F[$n++];
        }
        close($o_fh) || die "$0: Error closing scp file $scpfile: $!\n";
    }
    $n == $numlines || die "$n != $numlines [code error]";
}

exit ($error);
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | |
| | | # Copyright 2017 Johns Hopkins University (Shinji Watanabe) |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | |
| | | import argparse |
| | | import codecs |
| | | import re |
| | | import sys |
| | | import json |
| | | |
| | | is_python2 = sys.version_info[0] == 2 |
| | | |
| | | |
def exist_or_not(i, match_pos):
    """Find the first span in ``match_pos`` that contains index ``i``.

    Args:
        i: index to look up.
        match_pos: iterable of ``[start, end]`` pairs; ``start`` is inclusive
            and ``end`` is exclusive.

    Returns:
        ``(start, end)`` of the first containing span in iteration order, or
        ``(None, None)`` when no span contains ``i``.
    """
    hit = next((span for span in match_pos if span[0] <= i < span[1]), None)
    if hit is None:
        return None, None
    return hit[0], hit[1]
| | | |
| | | |
def get_parser():
    """Build the command-line parser for the raw-text-to-token converter.

    Returns:
        argparse.ArgumentParser: parser with the options ``--nchar/-n``,
        ``--skip-ncols/-s``, ``--space``, ``--non-lang-syms/-l``,
        ``--trans_type/-t``, ``--text_format`` and the optional positional
        ``text`` (``False`` when omitted).
    """
    parser = argparse.ArgumentParser(
        description="convert raw text to tokenized text",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--nchar",
        "-n",
        default=1,
        type=int,
        # Implicit concatenation instead of a backslash continuation inside the
        # literal, which embedded the source indentation in the help text.
        help="number of characters to split, i.e., "
        "aabb -> a a b b with -n 1 and aa bb with -n 2",
    )
    parser.add_argument("--skip-ncols", "-s", default=0, type=int, help="skip first n columns")
    parser.add_argument("--space", default="<space>", type=str, help="space symbol")
    parser.add_argument(
        "--non-lang-syms",
        "-l",
        default=None,
        type=str,
        help="list of non-linguistic symbols, e.g., <NOISE> etc.",
    )
    # default=False (not None) is kept: callers test the value for truthiness.
    parser.add_argument("text", type=str, default=False, nargs="?", help="input text")
    parser.add_argument(
        "--trans_type",
        "-t",
        type=str,
        default="char",
        choices=["char", "phn"],
        help="Transcript type. char/phn. e.g., for TIMIT FADG0_SI1279 - "
        "If trans_type is char, "
        'read from SI1279.WRD file -> "bricks are an alternative" '
        "Else if trans_type is phn, "
        'read from SI1279.PHN file -> "sil b r ih sil k s aa r er n aa l '
        'sil t er n ih sil t ih v sil" ',
    )
    parser.add_argument(
        "--text_format",
        default="text",
        type=str,
        help="text, jsonl",
    )
    return parser
| | | |
| | | |
def main():
    """Tokenize text read from a file or stdin and print it to stdout as UTF-8.

    For every input line the first ``--skip-ncols`` whitespace-separated
    columns are printed unchanged. The remaining text is split into
    characters (``--trans_type char``, keeping each non-linguistic symbol from
    ``--non-lang-syms`` as one unit) or into space-separated phones
    (``--trans_type phn``), grouped ``--nchar`` units at a time, and printed
    with spaces replaced by the ``--space`` symbol. With
    ``--text_format jsonl`` each line is parsed as JSON and its ``"target"``
    field is used as the text.
    """
    parser = get_parser()
    args = parser.parse_args()

    rs = []
    if args.non_lang_syms is not None:
        with codecs.open(args.non_lang_syms, "r", encoding="utf-8") as f:
            nls = [x.rstrip() for x in f.readlines()]
        # Blank lines are skipped: an empty pattern matches everywhere with a
        # zero-width match, which made the original search loop spin forever.
        rs = [re.compile(re.escape(x)) for x in nls if x]

    if args.text:
        f = codecs.open(args.text, encoding="utf-8")
    else:
        f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer)

    sys.stdout = codecs.getwriter("utf-8")(sys.stdout if is_python2 else sys.stdout.buffer)
    n = args.nchar
    try:
        for line in f:
            if args.text_format == "jsonl":
                data = json.loads(line.strip())
                line = data["target"]
            x = line.split()
            print(" ".join(x[: args.skip_ncols]), end=" ")
            a = " ".join(x[args.skip_ncols :])

            # Spans [start, end) of every non-linguistic symbol occurrence,
            # grouped by symbol in the order the symbols were listed.
            match_pos = [[m.start(), m.end()] for r in rs for m in r.finditer(a)]

            if args.trans_type == "phn":
                a = a.split(" ")
            elif match_pos:
                # Walk the text, emitting a whole symbol when the position
                # falls inside a matched span and a single character otherwise.
                chars = []
                i = 0
                while i < len(a):
                    start_pos, end_pos = exist_or_not(i, match_pos)
                    if start_pos is not None:
                        chars.append(a[start_pos:end_pos])
                        i = end_pos
                    else:
                        chars.append(a[i])
                        i += 1
                a = chars

            # Group n consecutive units and glue each group into one token.
            a_flat = ["".join(a[j : j + n]) for j in range(0, len(a), n)]

            a_chars = [z.replace(" ", args.space) for z in a_flat]
            if args.trans_type == "phn":
                a_chars = [z.replace("sil", args.space) for z in a_chars]
            print(" ".join(a_chars))
    finally:
        # Only a file opened here is closed; the stdin wrapper is left alone.
        if args.text:
            f.close()
| | | |
| | | |
| | | if __name__ == "__main__": |
| | | main() |
| New file |
| | |
| | | import re |
| | | import argparse |
| | | |
| | | |
def load_dict(seg_file):
    """Load a segmentation dictionary from a whitespace-separated text file.

    Each non-blank line is ``<key> <piece> <piece> ...``; the pieces are
    re-joined with single spaces. A line holding only a key maps to ``""``.
    If a key is repeated, the last occurrence wins.

    Args:
        seg_file: path of the dictionary file, read as UTF-8.

    Returns:
        dict mapping each key to its space-joined pieces.
    """
    seg_dict = {}
    # Explicit UTF-8 so the result does not depend on the locale's default
    # encoding.
    with open(seg_file, "r", encoding="utf-8") as infile:
        for line in infile:
            fields = line.split()
            if not fields:
                # Blank line: nothing to index (previously an IndexError).
                continue
            seg_dict[fields[0]] = " ".join(fields[1:])
    return seg_dict
| | | |
| | | |
def forward_segment(text, dic):
    """Segment ``text`` left to right by forward maximum matching.

    At each position the longest multi-character piece found in ``dic`` is
    taken; when there is none, the single item at that position is taken
    whether or not it is in ``dic``.

    Args:
        text: sequence to segment (a string in this file's caller).
        dic: container supporting ``in`` for candidate pieces.

    Returns:
        list of the pieces, in order.
    """
    pieces = []
    pos = 0
    total = len(text)
    while pos < total:
        chosen = text[pos]
        # Candidates are produced in increasing length, so the last one is the
        # longest entry starting at pos.
        known = [text[pos:end] for end in range(pos + 1, total + 1) if text[pos:end] in dic]
        if known and len(known[-1]) > len(chosen):
            chosen = known[-1]
        pieces.append(chosen)
        pos += len(chosen)
    return pieces
| | | |
| | | |
def tokenize(txt, seg_dict):
    """Map segmented words to their dictionary pieces.

    Words whose first character is not a CJK ideograph in U+4E00..U+9FA5, an
    ASCII letter or an ASCII digit are dropped. Kept words are replaced by
    their ``seg_dict`` entry, or by ``<unk>`` when they have none.

    Args:
        txt: iterable of words.
        seg_dict: mapping from word to its space-separated pieces.

    Returns:
        The replacements, each followed by one space, concatenated and then
        stripped of surrounding whitespace.
    """
    keep = re.compile(r"([\u4E00-\u9FA5A-Za-z0-9])")
    mapped = []
    for word in txt:
        if not keep.match(word):
            continue
        mapped.append(seg_dict[word] if word in seg_dict else "<unk>")
    return "".join(piece + " " for piece in mapped).strip()
| | | |
| | | |
def get_parser():
    """Build the command-line parser for the text tokenizer.

    All four options are required: ``--text-file/-t``, ``--seg-file/-s``,
    ``--txt-index/-i`` (int) and ``--output-dir/-o``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description="text tokenize",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # (long flag, short flag, default, type, help) for each required option.
    option_table = (
        ("--text-file", "-t", False, str, "input text"),
        ("--seg-file", "-s", False, str, "seg file"),
        ("--txt-index", "-i", 1, int, "txt index"),
        ("--output-dir", "-o", False, str, "output dir"),
    )
    for long_flag, short_flag, default_value, value_type, help_text in option_table:
        parser.add_argument(
            long_flag,
            short_flag,
            default=default_value,
            required=True,
            type=value_type,
            help=help_text,
        )
    return parser
| | | |
| | | |
def main():
    """Tokenize one text split and write the tokens and their counts.

    Reads ``<id> <text...>`` lines from ``--text-file``, joins the text
    columns, lower-cases them, segments them with the dictionary loaded from
    ``--seg-file`` and writes:

    * ``<output-dir>/text.<txt-index>.txt`` with ``<id> <tokens>`` lines
    * ``<output-dir>/len.<txt-index>`` with ``<id> <number of tokens>`` lines

    Blank input lines are skipped.
    """
    parser = get_parser()
    args = parser.parse_args()

    txt_path = "{}/text.{}.txt".format(args.output_dir, args.txt_index)
    shape_path = "{}/len.{}".format(args.output_dir, args.txt_index)
    seg_dict = load_dict(args.seg_file)

    # Context managers guarantee both outputs are flushed and closed; UTF-8 is
    # explicit so the result does not depend on the locale's default encoding.
    with open(args.text_file, "r", encoding="utf-8") as infile, open(
        txt_path, "w", encoding="utf-8"
    ) as txt_writer, open(shape_path, "w", encoding="utf-8") as shape_writer:
        for line in infile:
            s = line.strip().split()
            if not s:
                # Blank line: no id to index (previously an IndexError).
                continue
            text_id = s[0]
            text_list = forward_segment("".join(s[1:]).lower(), seg_dict)
            text = tokenize(text_list, seg_dict)
            lens = len(text.strip().split())
            txt_writer.write(text_id + " " + text + "\n")
            shape_writer.write(text_id + " " + str(lens) + "\n")
| | | |
| | | |
| | | if __name__ == "__main__": |
| | | main() |
| New file |
| | |
#!/usr/bin/env bash

# Run utils/text_tokenize.py over the $nj text splits in parallel and merge the
# per-split results into <output_dir>/text and <output_dir>/text_shape.
#
# Usage: <this script> [--nj <nj>] [--cmd <cmd>] <text_dir> <seg_file> <logdir> <output_dir>
#   <text_dir>/txt/text.JOB.txt must exist for JOB = 1..nj.

# Begin configuration section.
nj=32
cmd=utils/run.pl

echo "$0 $@"

. utils/parse_options.sh || exit 1;

if [ $# -lt 4 ]; then
  echo "Usage: $0 [--nj <nj>] [--cmd <cmd>] <text_dir> <seg_file> <logdir> <output_dir>" 1>&2
  exit 1;
fi

# tokenize configuration
text_dir=$1
seg_file=$2
logdir=$3
output_dir=$4

txt_dir="${output_dir}/txt"
mkdir -p "${txt_dir}"
mkdir -p "${logdir}"

# $cmd is deliberately left unquoted: it may carry its own options
# (e.g. "queue.pl -sync y") that must be split into separate words.
$cmd JOB=1:$nj "$logdir/text_tokenize.JOB.log" \
  python utils/text_tokenize.py -t "${text_dir}/txt/text.JOB.txt" \
  -s "${seg_file}" -i JOB -o "${txt_dir}" \
  || exit 1;

# concatenate the text files together.
for n in $(seq $nj); do
  cat "${txt_dir}/text.$n.txt" || exit 1
done > "${output_dir}/text" || exit 1

for n in $(seq $nj); do
  cat "${txt_dir}/len.$n" || exit 1
done > "${output_dir}/text_shape" || exit 1

echo "$0: Succeeded text tokenize"
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # coding=utf-8 |
| | | |
| | | # Authors: |
| | | # 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) |
| | | # 2019.9 Jiayu DU |
| | | # |
| | | # requirements: |
| | | # - python 3.X |
| | | # notes: python 2.X WILL fail or produce misleading results |
| | | |
| | | import sys, os, argparse, codecs, string, re |
| | | |
| | | # ================================================================================ # |
| | | # basic constant |
| | | # ================================================================================ # |
# Digits 0-9: everyday forms, then the "big" forms in simplified and
# traditional script. Characters are paired by position across these strings.
CHINESE_DIGIS = "零一二三四五六七八九"
BIG_CHINESE_DIGIS_SIMPLIFIED = "零壹贰叁肆伍陆柒捌玖"
BIG_CHINESE_DIGIS_TRADITIONAL = "零壹貳參肆伍陸柒捌玖"
# Unit characters, each given as a simplified string and a traditional string
# of equal length.
SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = "十百千万"
SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = "拾佰仟萬"
LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "亿兆京垓秭穰沟涧正载"
LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = "億兆京垓秭穰溝澗正載"
SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = "十百千万"
SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = "拾佰仟萬"

# Alternative written forms of zero, one and two.
ZERO_ALT = "〇"
ONE_ALT = "幺"
TWO_ALTS = ["两", "兩"]

# Sign and decimal-point words as [simplified, traditional] pairs.
POSITIVE = ["正", "正"]
NEGATIVE = ["负", "負"]
POINT = ["点", "點"]
# PLUS = [u'加', u'加']
# SIL = [u'杠', u'槓']

FILLER_CHARS = ["呃", "啊"]
# Regex alternation of words containing "儿".
ER_WHITELIST = (
    "(儿女|儿子|儿孙|女儿|儿媳|妻儿|"
    "胎儿|婴儿|新生儿|婴幼儿|幼儿|少儿|小儿|儿歌|儿童|儿科|托儿所|孤儿|"
    "儿戏|儿化|台儿庄|鹿儿岛|正儿八经|吊儿郎当|生儿育女|托儿带女|养儿防老|痴儿呆女|"
    "佳儿佳妇|儿怜兽扰|儿无常父|儿不嫌母丑|儿行千里母担忧|儿大不由爷|苏乞儿)"
)

# Chinese numbering system types
NUMBERING_TYPES = ["low", "mid", "high"]

# Regex alternations (each wrapped in one group) of currency names, currency
# units and common quantifiers.
CURRENCY_NAMES = (
    "(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|"
    "里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)"
)
CURRENCY_UNITS = (
    "((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)"
)
COM_QUANTIFIERS = (
    "(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|"
    "砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|"
    "针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|"
    "毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|"
    "盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|"
    "纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)"
)
| | | |
# Punctuation information is based on the Zhon project
# (https://github.com/tsroten/zhon.git).
# These sets must hold the full-width / CJK forms. Once folded to ASCII the
# non-stop literal begins with `""#`, i.e. an empty string followed by a
# comment, which silently empties the set.
CHINESE_PUNC_STOP = "！？｡。"
CHINESE_PUNC_NON_STOP = "＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏"
CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP
| | | |
| | | |
| | | # ================================================================================ # |
| | | # basic class |
| | | # ================================================================================ # |
class ChineseChar(object):
    """
    A Chinese character.
    Every character has a simplified and a traditional form,
    e.g. simplified = '负', traditional = '負'.
    Either form can be chosen when converting.
    """

    def __init__(self, simplified, traditional):
        self.simplified = simplified
        self.traditional = traditional
        # self.__repr__ = self.__str__

    def __str__(self):
        # Prefer the simplified form, then the traditional one. Fall back to ""
        # rather than None: __str__ must return a string, otherwise str() and
        # repr() raise TypeError.
        return self.simplified or self.traditional or ""

    def __repr__(self):
        return self.__str__()
| | | |
| | | |
class ChineseNumberUnit(ChineseChar):
    """
    A Chinese number / unit character.
    Besides the simplified and traditional forms, every character also has an
    extra "big" form, e.g. '陆' and '陸'.
    """

    def __init__(self, power, simplified, traditional, big_s, big_t):
        super().__init__(simplified, traditional)
        self.power = power
        self.big_s = big_s
        self.big_t = big_t

    def __str__(self):
        return f"10^{self.power}"

    @classmethod
    def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False):
        """Build the unit at position ``index`` of a unit string.

        ``value`` is indexed as ``value[0]`` (simplified) and ``value[1]``
        (traditional). The power of ten depends on ``small_unit`` and, for
        large units, on ``numbering_type``:

        * small unit: ``index + 1``
        * ``NUMBERING_TYPES[0]``: ``index + 8``
        * ``NUMBERING_TYPES[1]``: ``(index + 2) * 4``
        * ``NUMBERING_TYPES[2]``: ``2 ** (index + 3)``

        Raises:
            ValueError: for a large unit whose ``numbering_type`` is not one
                of ``NUMBERING_TYPES``.
        """
        if small_unit:
            # Small units use the traditional character for both big forms.
            exponent, big_simplified = index + 1, value[1]
        elif numbering_type == NUMBERING_TYPES[0]:
            exponent, big_simplified = index + 8, value[0]
        elif numbering_type == NUMBERING_TYPES[1]:
            exponent, big_simplified = (index + 2) * 4, value[0]
        elif numbering_type == NUMBERING_TYPES[2]:
            exponent, big_simplified = pow(2, index + 3), value[0]
        else:
            raise ValueError(
                "Counting type should be in {0} ({1} provided).".format(
                    NUMBERING_TYPES, numbering_type
                )
            )

        return ChineseNumberUnit(
            power=exponent,
            simplified=value[0],
            traditional=value[1],
            big_s=big_simplified,
            big_t=value[1],
        )
| | | |
| | | |
class ChineseNumberDigit(ChineseChar):
    """
    A Chinese digit character.
    """

    def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None):
        super().__init__(simplified, traditional)
        self.value = value
        # "Big" forms and optional alternative forms, simplified / traditional.
        self.big_s, self.big_t = big_s, big_t
        self.alt_s, self.alt_t = alt_s, alt_t

    def __str__(self):
        return str(self.value)

    @classmethod
    def create(cls, i, v):
        """Build the digit with value ``i`` from the first four items of ``v``.

        ``v[0]``..``v[3]`` are passed as simplified, traditional, big_s and
        big_t; the alternative forms are left as ``None``.
        """
        simplified, traditional, big_s, big_t = v[0], v[1], v[2], v[3]
        return ChineseNumberDigit(i, simplified, traditional, big_s, big_t)
| | | |
| | | |
class ChineseMath(ChineseChar):
    """
    A Chinese math symbol character (sign or decimal point).

    ``symbol`` is the ASCII counterpart and ``expression`` an optional
    callable attached to the symbol.  The "big" forms simply mirror the
    simplified/traditional forms.
    """

    def __init__(self, simplified, traditional, symbol, expression=None):
        super().__init__(simplified, traditional)
        self.symbol, self.expression = symbol, expression
        # Math symbols have no dedicated capital form.
        self.big_s, self.big_t = simplified, traditional
| | | |
| | | |
# Short aliases for the character classes, used by the conversion helpers.
CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath
| | | |
| | | |
class NumberSystem(object):
    """
    Chinese number system.

    A plain attribute container; its fields are attached after construction.
    """
| | | |
| | | |
class MathSymbol(object):
    """
    Math symbols (simplified/traditional) used by the Chinese number system, e.g.
    positive = ['正', '正']
    negative = ['负', '負']
    point = ['点', '點']
    """

    def __init__(self, positive, negative, point):
        # Assignment order matters: __iter__ walks the attributes in this order.
        self.positive, self.negative, self.point = positive, negative, point

    def __iter__(self):
        """Iterate over the stored symbols in attribute-insertion order."""
        yield from self.__dict__.values()
| | | |
| | | |
| | | # class OtherSymbol(object): |
| | | # """ |
| | | # 其他符号 |
| | | # """ |
| | | # |
| | | # def __init__(self, sil): |
| | | # self.sil = sil |
| | | # |
| | | # def __iter__(self): |
| | | # for v in self.__dict__.values(): |
| | | # yield v |
| | | |
| | | |
| | | # ================================================================================ # |
| | | # basic utils |
| | | # ================================================================================ # |
def create_system(numbering_type=NUMBERING_TYPES[1]):
    """
    Create the number system for the given numbering type (default: mid).

    NUMBERING_TYPES = ['low', 'mid', 'high']: Chinese numbering system types
        low:  '兆' = '亿' * '十' = $10^{9}$,  '京' = '兆' * '十', etc.
        mid:  '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc.
        high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc.

    Returns a NumberSystem with ``units``, ``digits`` and ``math`` attached.
    """
    # Units of '亿' and above; their power depends on the numbering type.
    big_unit_pairs = zip(
        LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL
    )
    larger_units = [
        CNU.create(idx, pair, numbering_type, False) for idx, pair in enumerate(big_unit_pairs)
    ]

    # Units '十, 百, 千, 万'.
    small_unit_pairs = zip(
        SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL
    )
    smaller_units = [
        CNU.create(idx, pair, small_unit=True) for idx, pair in enumerate(small_unit_pairs)
    ]

    # Digits: plain digits serve as both the simplified and traditional form.
    digit_rows = zip(
        CHINESE_DIGIS, CHINESE_DIGIS, BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL
    )
    digits = [CND.create(idx, row) for idx, row in enumerate(digit_rows)]
    # Alternative spellings exist for zero, one and two only.
    digits[0].alt_s = ZERO_ALT
    digits[0].alt_t = ZERO_ALT
    digits[1].alt_s = ONE_ALT
    digits[1].alt_t = ONE_ALT
    digits[2].alt_s = TWO_ALTS[0]
    digits[2].alt_t = TWO_ALTS[1]

    # Math symbols: sign characters and the decimal point.
    positive_cn = CM(POSITIVE[0], POSITIVE[1], "+", lambda x: x)
    negative_cn = CM(NEGATIVE[0], NEGATIVE[1], "-", lambda x: -x)
    point_cn = CM(POINT[0], POINT[1], ".", lambda x, y: float(str(x) + "." + str(y)))

    system = NumberSystem()
    system.units = smaller_units + larger_units
    system.digits = digits
    system.math = MathSymbol(positive_cn, negative_cn, point_cn)
    return system
| | | |
| | | |
def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]):
    """
    Convert a Chinese numeral string into its Arabic-digit string.

    Args:
        chinese_string: numeral written with Chinese digit/unit characters,
            optionally containing one decimal-point character.
        numbering_type: numbering system passed to create_system().

    Returns:
        str: the integer part, followed by "." and the decimal digits when a
        decimal part is present.
    """

    def get_symbol(char, system):
        # Look the character up among units, then digits, then math symbols.
        # Falls through and returns None when the character is unknown.
        for u in system.units:
            if char in [u.traditional, u.simplified, u.big_s, u.big_t]:
                return u
        for d in system.digits:
            if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]:
                return d
        for m in system.math:
            if char in [m.traditional, m.simplified]:
                return m

    def string2symbols(chinese_string, system):
        # Split on the first point spelling that occurs, then map every
        # character of each part to its symbol object.
        int_string, dec_string = chinese_string, ""
        for p in [system.math.point.simplified, system.math.point.traditional]:
            if p in chinese_string:
                # Two-way unpacking: more than one point raises ValueError.
                int_string, dec_string = chinese_string.split(p)
                break
        return [get_symbol(c, system) for c in int_string], [
            get_symbol(c, system) for c in dec_string
        ]

    def correct_symbols(integer_symbols, system):
        """
        Normalise colloquial forms into fully specified symbol sequences, e.g.
        一百八 to 一百八十
        一亿一千三百万 to 一亿 一千万 三百万
        """

        # A leading unit of power 1 gets an explicit digit one in front.
        if integer_symbols and isinstance(integer_symbols[0], CNU):
            if integer_symbols[0].power == 1:
                integer_symbols = [system.digits[1]] + integer_symbols

        # A trailing digit right after a unit gets the next lower unit appended.
        if len(integer_symbols) > 1:
            if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU):
                integer_symbols.append(CNU(integer_symbols[-2].power - 1, None, None, None, None))

        result = []
        unit_count = 0  # number of consecutive units seen since the last digit
        for s in integer_symbols:
            if isinstance(s, CND):
                result.append(s)
                unit_count = 0
            elif isinstance(s, CNU):
                current_unit = CNU(s.power, None, None, None, None)
                unit_count += 1

            if unit_count == 1:
                result.append(current_unit)
            elif unit_count > 1:
                # A unit directly following another unit is not appended;
                # instead its power is added to every smaller unit already
                # collected (scanning result from the end).
                for i in range(len(result)):
                    if (
                        isinstance(result[-i - 1], CNU)
                        and result[-i - 1].power < current_unit.power
                    ):
                        result[-i - 1] = CNU(
                            result[-i - 1].power + current_unit.power, None, None, None, None
                        )
        return result

    def compute_value(integer_symbols):
        """
        Compute the value.
        When current unit is larger than previous unit, current unit * all previous units will be used as all previous units.
        e.g. '两千万' = 2000 * 10000 not 2000 + 10000
        """
        value = [0]  # one partial sum per (digit, unit) group
        last_power = 0
        for s in integer_symbols:
            if isinstance(s, CND):
                value[-1] = s.value
            elif isinstance(s, CNU):
                value[-1] *= pow(10, s.power)
                if s.power > last_power:
                    # Scale all earlier groups by the new, larger unit.
                    value[:-1] = list(map(lambda v: v * pow(10, s.power), value[:-1]))
                    last_power = s.power
                value.append(0)
        return sum(value)

    system = create_system(numbering_type)
    int_part, dec_part = string2symbols(chinese_string, system)
    int_part = correct_symbols(int_part, system)
    int_str = str(compute_value(int_part))
    # Decimal digits are read one by one; no units are interpreted there.
    dec_str = "".join([str(d.value) for d in dec_part])
    if dec_part:
        return "{0}.{1}".format(int_str, dec_str)
    else:
        return int_str
| | | |
| | | |
def num2chn(
    number_string,
    numbering_type=NUMBERING_TYPES[1],
    big=False,
    traditional=False,
    alt_zero=False,
    alt_one=False,
    alt_two=True,
    use_zeros=True,
    use_units=True,
):
    """
    Convert an Arabic-digit string into Chinese numeral text.

    Args:
        number_string: digits with at most one "." separating the decimal part.
        numbering_type: numbering system passed to create_system().
        big: use the "big" (capital) character forms.
        traditional: use traditional instead of simplified characters.
        alt_zero: replace the zero character with its alternative spelling.
        alt_one: replace the one character with its alternative spelling.
        alt_two: use the alternative "two" before units (see the block below).
        use_zeros: accepted, but not forwarded to get_value() here, which
            therefore always runs with its own default (True).
        use_units: when False (or for a single integer digit) every digit is
            read out individually without units.

    Returns:
        str: the Chinese text.

    Raises:
        ValueError: if number_string contains more than one ".".
    """

    def get_value(value_string, use_zeros=True):
        # Recursively turn a digit string into a list of digit/unit symbols.

        striped_string = value_string.lstrip("0")

        # record nothing if all zeros
        if not striped_string:
            return []

        # record one digits
        elif len(striped_string) == 1:
            # Leading zeros were stripped: emit one zero before the digit.
            if use_zeros and len(value_string) != len(striped_string):
                return [system.digits[0], system.digits[int(striped_string)]]
            else:
                return [system.digits[int(striped_string)]]

        # recursively record multiple digits
        else:
            # Largest unit whose power is smaller than the number of digits.
            result_unit = next(u for u in reversed(system.units) if u.power < len(striped_string))
            # The head is cut from the unstripped string, the tail from the stripped one.
            result_string = value_string[: -result_unit.power]
            return (
                get_value(result_string)
                + [result_unit]
                + get_value(striped_string[-result_unit.power :])
            )

    system = create_system(numbering_type)

    int_dec = number_string.split(".")
    if len(int_dec) == 1:
        int_string = int_dec[0]
        dec_string = ""
    elif len(int_dec) == 2:
        int_string = int_dec[0]
        dec_string = int_dec[1]
    else:
        raise ValueError(
            "invalid input num string with more than one dot: {}".format(number_string)
        )

    if use_units and len(int_string) > 1:
        result_symbols = get_value(int_string)
    else:
        result_symbols = [system.digits[int(c)] for c in int_string]
    # Decimal digits are always read out one by one.
    dec_symbols = [system.digits[int(c)] for c in dec_string]
    if dec_string:
        result_symbols += [system.math.point] + dec_symbols

    if alt_two:
        # Digit "two" whose simplified/traditional forms are the alternative
        # spellings; the big forms stay the regular ones.
        liang = CND(
            2,
            system.digits[2].alt_s,
            system.digits[2].alt_t,
            system.digits[2].big_s,
            system.digits[2].big_t,
        )
        for i, v in enumerate(result_symbols):
            if isinstance(v, CND) and v.value == 2:
                next_symbol = result_symbols[i + 1] if i < len(result_symbols) - 1 else None
                previous_symbol = result_symbols[i - 1] if i > 0 else None
                # Only a "two" followed by a unit and preceded by a unit (or
                # nothing) qualifies, and neither neighbour may be the
                # power-1 unit.
                if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))):
                    if next_symbol.power != 1 and (
                        (previous_symbol is None) or (previous_symbol.power != 1)
                    ):
                        result_symbols[i] = liang

    # if big is True, '两' will not be used and `alt_two` has no impact on output
    if big:
        attr_name = "big_"
        if traditional:
            attr_name += "t"
        else:
            attr_name += "s"
    else:
        if traditional:
            attr_name = "traditional"
        else:
            attr_name = "simplified"

    result = "".join([getattr(s, attr_name) for s in result_symbols])

    # if not use_zeros:
    #     result = result.strip(getattr(system.digits[0], attr_name))

    # The alternative spellings always use the simplified variant (alt_s).
    if alt_zero:
        result = result.replace(getattr(system.digits[0], attr_name), system.digits[0].alt_s)

    if alt_one:
        result = result.replace(getattr(system.digits[1], attr_name), system.digits[1].alt_s)

    # A result starting with the point gets a leading zero character.
    for i, p in enumerate(POINT):
        if result.startswith(p):
            return CHINESE_DIGIS[0] + result

    # ^10, 11, .., 19
    # Drop a leading "one" that is immediately followed by the tens unit.
    if (
        len(result) >= 2
        and result[1]
        in [
            SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
            SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0],
        ]
        and result[0]
        in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]
    ):
        result = result[1:]

    return result
| | | |
| | | |
| | | # ================================================================================ # |
| | | # different types of rewriters |
| | | # ================================================================================ # |
class Cardinal:
    """
    CARDINAL: a plain number, held as digits (``cardinal``) and/or Chinese text (``chntext``).
    """

    def __init__(self, cardinal=None, chntext=None):
        self.cardinal, self.chntext = cardinal, chntext

    def chntext2cardinal(self):
        """Return the digit string for ``self.chntext``."""
        digits = chn2num(self.chntext)
        return digits

    def cardinal2chntext(self):
        """Return the Chinese reading of ``self.cardinal``."""
        spoken = num2chn(self.cardinal)
        return spoken
| | | |
| | | |
class Digit:
    """
    DIGIT: a digit sequence that is read out one digit at a time.
    """

    def __init__(self, digit=None, chntext=None):
        self.digit, self.chntext = digit, chntext

    # NOTE: the inverse conversion (chntext -> digit) is not implemented.

    def digit2chntext(self):
        """Return ``self.digit`` spelled digit by digit (no units, no alternative "two")."""
        spoken = num2chn(self.digit, alt_two=False, use_units=False)
        return spoken
| | | |
| | | |
class TelePhone:
    """
    TELEPHONE: a mobile or fixed-line phone number.
    """

    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
        self.telephone, self.raw_chntext, self.chntext = telephone, raw_chntext, chntext

    # NOTE: the inverse conversion (chntext -> telephone) is not implemented.

    def telephone2chntext(self, fixed=False):
        """
        Spell ``self.telephone`` digit by digit.

        Fixed-line numbers are split on "-" and joined with "<SIL>" in
        ``raw_chntext``; other numbers have "+" stripped from both ends, are
        split on whitespace and joined with "<SP>".  ``chntext`` is
        ``raw_chntext`` with the separator removed; it is also returned.
        """
        if fixed:
            separator = "<SIL>"
            pieces = self.telephone.split("-")
        else:
            separator = "<SP>"
            pieces = self.telephone.strip("+").split()

        spoken = [num2chn(piece, alt_two=False, use_units=False) for piece in pieces]
        self.raw_chntext = separator.join(spoken)
        self.chntext = self.raw_chntext.replace(separator, "")
        return self.chntext
| | | |
| | | |
class Fraction:
    """
    FRACTION: "numerator/denominator" digits and/or the "<denominator>分之<numerator>" text.
    """

    def __init__(self, fraction=None, chntext=None):
        self.fraction, self.chntext = fraction, chntext

    def chntext2fraction(self):
        """Return "numerator/denominator" for ``self.chntext``; Chinese names the denominator first."""
        parts = self.chntext.split("分之")
        denominator, numerator = parts
        top = chn2num(numerator)
        bottom = chn2num(denominator)
        return top + "/" + bottom

    def fraction2chntext(self):
        """Return the Chinese reading of ``self.fraction`` (denominator first)."""
        parts = self.fraction.split("/")
        numerator, denominator = parts
        bottom = num2chn(denominator)
        top = num2chn(numerator)
        return bottom + "分之" + top
| | | |
| | | |
class Date:
    """
    DATE: a date written with digits and the 年 / 月 / 日(号) markers.
    """

    def __init__(self, date=None, chntext=None):
        self.date = date
        self.chntext = chntext

    # NOTE: the inverse conversion (chntext -> date) is not implemented.

    def date2chntext(self):
        """
        Convert ``self.date`` to Chinese text.

        The year is read digit by digit; month and day are read as cardinal
        numbers.  Each of the three parts is optional.  The result is stored
        in ``self.chntext`` and returned.
        """
        date = self.date
        try:
            year, other = date.strip().split("年", 1)
            year = Digit(digit=year).digit2chntext() + "年"
        except ValueError:
            # No year marker: the whole string is month/day material.
            other = date
            year = ""
        if other:
            try:
                month, day = other.strip().split("月", 1)
                month = Cardinal(cardinal=month).cardinal2chntext() + "月"
            except ValueError:
                # No month marker: what is left after the year is the day.
                # (Using the full date here would feed the already-converted
                # year text into the day conversion.)
                day = other
                month = ""
            if day:
                # The last character is the day marker (日 / 号); keep it as is.
                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
        else:
            month = ""
            day = ""
        chntext = year + month + day
        self.chntext = chntext
        return self.chntext
| | | |
| | | |
class Money:
    """
    MONEY: an amount of money, digits mixed with currency words.
    """

    def __init__(self, money=None, chntext=None):
        self.money = money
        self.chntext = chntext

    # NOTE: the inverse conversion (chntext -> money) is not implemented.

    def money2chntext(self):
        """
        Replace every number in ``self.money`` with its Chinese reading.

        Non-numeric characters (currency words etc.) are kept untouched.  The
        result is stored in ``self.chntext`` and returned.
        """
        pattern = re.compile(r"(\d+(\.\d+)?)")
        # Substitute each match in place, in a single pass.  Replacing the
        # matched strings one after another with str.replace() corrupts the
        # text when one number is a substring of a later one
        # (e.g. "1块15": replacing "1" first also rewrites part of "15").
        self.chntext = pattern.sub(
            lambda match: Cardinal(cardinal=match.group(1)).cardinal2chntext(), self.money
        )
        return self.chntext
| | | |
| | | |
class Percentage:
    """
    PERCENTAGE: "<number>%" digits and/or the "百分之<number>" text.
    """

    def __init__(self, percentage=None, chntext=None):
        self.percentage = percentage
        self.chntext = chntext

    def chntext2percentage(self):
        """Return "<digits>%" for ``self.chntext``."""
        text = self.chntext.strip()
        prefix = "百分之"
        # Remove the prefix as a whole.  str.strip(prefix) would treat it as a
        # character set and also eat a trailing '百', turning
        # "百分之一百" into "一".
        if text.startswith(prefix):
            text = text[len(prefix) :]
        return chn2num(text) + "%"

    def percentage2chntext(self):
        """Return the Chinese reading "百分之<number>" of ``self.percentage``."""
        return "百分之" + num2chn(self.percentage.strip().strip("%"))
| | | |
| | | |
def remove_erhua(text, er_whitelist):
    """
    Remove the erhua suffix "儿" from words, keeping whitelisted words intact:
    他女儿在那边儿 -> 他女儿在那边

    Args:
        text: input string.
        er_whitelist: regex pattern (as a string) matching the words whose
            "儿" must be kept.

    Returns:
        str: ``text`` without "儿", except inside whitelist matches.
    """
    er_pattern = re.compile(er_whitelist)
    kept_parts = []
    while True:
        er_pos = text.find("儿")
        if er_pos < 0:
            break
        protected = er_pattern.search(text)
        # A whitelist hit protects the current "儿" only when it starts at or
        # before it.  A zero-width hit at position 0 (e.g. an empty whitelist)
        # consumes nothing and would loop forever, so it is ignored.
        if protected is not None and protected.start() <= er_pos and protected.end() > 0:
            # Keep everything up to the end of the whitelisted word.
            kept_parts.append(text[: protected.end()])
            text = text[protected.end() :]
        else:
            # Keep the text before "儿" and drop the character itself.
            kept_parts.append(text[:er_pos])
            text = text[er_pos + 1 :]

    return "".join(kept_parts) + text
| | | |
| | | |
| | | # ================================================================================ # |
| | | # NSW Normalizer |
| | | # ================================================================================ # |
class NSWNormalizer:
    """
    Non-standard-word normaliser: rewrites dates, money, phone numbers,
    fractions, percentages and plain numbers in a text as Chinese words.
    """

    def __init__(self, raw_text):
        # The text is wrapped in "^" / "$" sentinels, which normalize() strips
        # again; presumably so that patterns requiring a neighbouring
        # non-digit (\D) can also match at the very start/end — TODO confirm.
        self.raw_text = "^" + raw_text + "$"
        self.norm_text = ""

    def _particular(self):
        """Turn a '二' sitting between Latin letters back into the digit 2 (e.g. O二O -> O2O)."""
        text = self.norm_text
        pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))")
        matchers = pattern.findall(text)
        if matchers:
            # print('particular')
            for matcher in matchers:
                text = text.replace(matcher[0], matcher[1] + "2" + matcher[2], 1)
        self.norm_text = text
        return self.norm_text

    def normalize(self):
        """
        Run all rewriting rules on the raw text and return the result.

        The rules run in a fixed order, from the most specific pattern to the
        most general one; each rule replaces the first occurrence of every
        match it found.  The result is also kept in ``self.norm_text`` (with
        the sentinels still attached).
        """
        text = self.raw_text

        # Normalise dates.
        pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)")
        matchers = pattern.findall(text)
        if matchers:
            # print('date')
            for matcher in matchers:
                text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1)

        # Normalise amounts of money.
        pattern = re.compile(
            r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)"
        )
        matchers = pattern.findall(text)
        if matchers:
            # print('money')
            for matcher in matchers:
                text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1)

        # Normalise fixed-line / mobile phone numbers.
        # Mobile
        # http://www.jihaoba.com/news/show/13680
        # China Mobile: 139, 138, 137, 136, 135, 134, 159, 158, 157, 150, 151, 152, 188, 187, 182, 183, 184, 178, 198
        # China Unicom: 130, 131, 132, 156, 155, 186, 185, 176
        # China Telecom: 133, 153, 189, 180, 181, 177
        pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D")
        matchers = pattern.findall(text)
        if matchers:
            # print('telephone')
            for matcher in matchers:
                text = text.replace(
                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1
                )
        # Fixed line
        pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D")
        matchers = pattern.findall(text)
        if matchers:
            # print('fixed telephone')
            for matcher in matchers:
                text = text.replace(
                    matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1
                )

        # Normalise fractions.
        pattern = re.compile(r"(\d+/\d+)")
        matchers = pattern.findall(text)
        if matchers:
            # print('fraction')
            for matcher in matchers:
                text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1)

        # Normalise percentages.
        # NOTE(review): as written this replace is a no-op (both arguments are
        # the same character); presumably the first one was meant to be the
        # full-width percent sign — confirm against the upstream file.
        text = text.replace("%", "%")
        pattern = re.compile(r"(\d+(\.\d+)?%)")
        matchers = pattern.findall(text)
        if matchers:
            # print('percentage')
            for matcher in matchers:
                text = text.replace(
                    matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1
                )

        # Normalise plain number + quantifier.
        pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS)
        matchers = pattern.findall(text)
        if matchers:
            # print('cardinal+quantifier')
            for matcher in matchers:
                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)

        # Normalise digit sequences (IDs / serial numbers): 4 to 32 digits are
        # read digit by digit.
        pattern = re.compile(r"(\d{4,32})")
        matchers = pattern.findall(text)
        if matchers:
            # print('digit')
            for matcher in matchers:
                text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1)

        # Normalise any remaining plain numbers.
        pattern = re.compile(r"(\d+(\.\d+)?)")
        matchers = pattern.findall(text)
        if matchers:
            # print('cardinal')
            for matcher in matchers:
                text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1)

        self.norm_text = text
        self._particular()

        # Strip the sentinels; note that lstrip/rstrip remove *all* leading
        # "^" and trailing "$", including ones that were in the input.
        return self.norm_text.lstrip("^").rstrip("$")
| | | |
| | | |
def nsw_test_case(raw_text):
    """Print ``raw_text`` and its normalised form, followed by an empty line."""
    print("I:" + raw_text)
    normalized = NSWNormalizer(raw_text).normalize()
    print("O:" + normalized)
    print("")
| | | |
| | | |
def nsw_test():
    """Run the built-in demo sentences through nsw_test_case, in order."""
    samples = (
        "固话:0595-23865596或23880880。",
        "固话:0595-23865596或23880880。",
        "手机:+86 19859213959或15659451527。",
        "分数:32477/76391。",
        "百分数:80.03%。",
        "编号:31520181154418。",
        "纯数:2983.07克或12345.60米。",
        "日期:1999年2月20日或09年3月15号。",
        "金钱:12块5,34.5元,20.1万",
        "特殊:O2O或B2C。",
        "3456万吨",
        "2938个",
        "938",
        "今天吃了115个小笼包231个馒头",
        "有62%的概率",
    )
    for sample in samples:
        nsw_test_case(sample)
| | | |
| | | |
if __name__ == "__main__":
    # nsw_test()

    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
        non-empty value (including "False") becomes True.
        """
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        # The empty string stays False, as it was with ``type=bool``.
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    p = argparse.ArgumentParser()
    p.add_argument("ifile", help="input filename, assume utf-8 encoding")
    p.add_argument("ofile", help="output filename")
    p.add_argument("--to_upper", action="store_true", help="convert to upper case")
    p.add_argument("--to_lower", action="store_true", help="convert to lower case")
    p.add_argument(
        "--has_key", action="store_true", help="input text has Kaldi's key as first field."
    )
    p.add_argument(
        "--remove_fillers",
        type=_str2bool,
        default=True,
        help='remove filler chars such as "呃, 啊"',
    )
    p.add_argument(
        "--remove_erhua", type=_str2bool, default=True, help='remove erhua chars such as "这儿"'
    )
    p.add_argument(
        "--log_interval", type=int, default=10000, help="log interval in number of processed lines"
    )
    args = p.parse_args()

    # The two case options are mutually exclusive; fail before touching any file.
    if args.to_upper and args.to_lower:
        sys.stderr.write("text norm: to_upper OR to_lower?")
        sys.exit(1)

    # Every Chinese and ASCII punctuation mark is mapped to a space.
    # The table does not depend on the line, so it is built once.
    old_chars = CHINESE_PUNC_LIST + string.punctuation  # includes all CN and EN punctuations
    new_chars = " " * len(old_chars)
    del_chars = ""
    punc_table = str.maketrans(old_chars, new_chars, del_chars)

    n = 0
    # ``with`` closes both files even when a line fails to normalise.
    with codecs.open(args.ifile, "r", "utf8") as ifile, codecs.open(
        args.ofile, "w+", "utf8"
    ) as ofile:
        for line in ifile:
            key = ""
            text = ""
            if args.has_key:
                cols = line.split(maxsplit=1)
                # A blank line has no key at all; keep it as an empty record
                # instead of failing on cols[0].
                key = cols[0] if cols else ""
                if len(cols) == 2:
                    text = cols[1].strip()
                else:
                    text = ""
            else:
                text = line.strip()

            # cases
            if args.to_upper:
                text = text.upper()
            if args.to_lower:
                text = text.lower()

            # Filler chars removal
            if args.remove_fillers:
                for ch in FILLER_CHARS:
                    text = text.replace(ch, "")

            if args.remove_erhua:
                text = remove_erhua(text, ER_WHITELIST)

            # NSW(Non-Standard-Word) normalization
            text = NSWNormalizer(text).normalize()

            # Punctuations removal
            text = text.translate(punc_table)

            if args.has_key:
                ofile.write(key + "\t" + text + "\n")
            else:
                ofile.write(text + "\n")

            n += 1
            if n % args.log_interval == 0:
                sys.stderr.write("text norm: {} lines done.\n".format(n))

    sys.stderr.write("text norm: {} lines done in total.\n".format(n))
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import torch |
| | | from typing import List, Tuple |
| | | |
| | | from funasr.register import tables |
| | | from funasr.models.scama import utils as myutils |
| | | from funasr.models.transformer.utils.repeat import repeat |
| | | from funasr.models.transformer.decoder import DecoderLayer |
| | | from funasr.models.transformer.layer_norm import LayerNorm |
| | | from funasr.models.transformer.embedding import PositionalEncoding |
| | | from funasr.models.transformer.attention import MultiHeadedAttention |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.transformer.decoder import BaseTransformerDecoder |
| | | from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward |
| | | from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM |
| | | from funasr.models.sanm.attention import ( |
| | | MultiHeadedAttentionSANMDecoder, |
| | | MultiHeadedAttentionCrossAtt, |
| | | ) |
| | | |
| | | |
class DecoderLayerSANM(torch.nn.Module):
    """Single decoder layer module.

    The layer applies, in this order: feed-forward, (optional) self-attention
    with a residual connection, (optional) source attention over the encoder
    memory with a residual connection.

    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module or None): Self-attention module instance.
            When None, the self-attention step and its layer norm are skipped.
        src_attn (torch.nn.Module or None): Source (encoder-decoder) attention
            module instance. When None, the source-attention step and its
            layer norm are skipped.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to apply layer norm before each sub-block.
        concat_after (bool): When True, two `Linear(2 * size, size)` layers
            (`concat_linear1`, `concat_linear2`) are created. None of the
            forward methods of this class uses them.

    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a DecoderLayerSANM object."""
        super(DecoderLayerSANM, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        # norm2 / norm3 only exist when the matching attention module does.
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = torch.nn.Linear(size + size, size)
            self.concat_linear2 = torch.nn.Linear(size + size, size)
        # When reserve_attn is set to True (by the caller), forward() appends
        # the source-attention matrix of every call to attn_mat. The list is
        # never cleared inside this class.
        self.reserve_attn = False
        self.attn_mat = []

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache: Not used by this method; returned unchanged.

        Returns:
            torch.Tensor: Output tensor(#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
            The ``cache`` argument, unchanged.

        """
        # tgt = self.dropout(tgt)
        # The residual is taken before the feed-forward block, so the
        # feed-forward output only reaches the result through self-attention.
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            # The second return value of self_attn is discarded here.
            x, _ = self.self_attn(tgt, tgt_mask)
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)
            if self.reserve_attn:
                # Also request and keep the attention matrix.
                x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
                self.attn_mat.append(attn_mat)
            else:
                x_src_attn = self.src_attn(x, memory, memory_mask, ret_attn=False)
            x = residual + self.dropout(x_src_attn)
            # x = residual + self.dropout(self.src_attn(x, memory, memory_mask))

        return x, tgt_mask, memory, memory_mask, cache

    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Return the source-attention matrix for the given input.

        Unlike forward(), this method always applies the layer norms
        (``normalize_before`` is not consulted), applies no dropout, and
        requires ``src_attn`` (and ``norm3``) to exist.

        Returns:
            The second value returned by ``src_attn(..., ret_attn=True)``.
        """
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x

        residual = x
        x = self.norm3(x)
        x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
        return attn_mat

    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features, passing a cache through self-attention.

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache: Cache forwarded to ``self_attn``; it is reset to None
                while the module is in training mode.

        Returns:
            torch.Tensor: Output tensor(#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
            The cache returned by ``self_attn`` (or the input cache when
            there is no self-attention).

        """
        # tgt = self.dropout(tgt)
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            if self.training:
                cache = None
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)

            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))

        return x, tgt_mask, memory, memory_mask, cache

    def forward_chunk(
        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
    ):
        """Compute decoded features for one chunk (no masks are used).

        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            fsmn_cache: Cache passed to and returned by ``self_attn``.
            opt_cache: Cache passed to and returned by ``src_attn.forward_chunk``.
            chunk_size: Forwarded to ``src_attn.forward_chunk``.
            look_back: Forwarded to ``src_attn.forward_chunk``.

        Returns:
            torch.Tensor: Output tensor(#batch, maxlen_out, size).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            The updated ``fsmn_cache``.
            The updated ``opt_cache``.

        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        x = tgt
        if self.self_attn:
            if self.normalize_before:
                tgt = self.norm2(tgt)
            # No target mask in chunk mode.
            x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
            x = residual + self.dropout(x)

        if self.src_attn is not None:
            residual = x
            if self.normalize_before:
                x = self.norm3(x)

            x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
            # Unlike the other forward methods, no dropout is applied here.
            x = residual + x

        return x, memory, fsmn_cache, opt_cache
| | | |
| | | |
@tables.register("decoder_classes", "ParaformerSANMDecoder")
class ParaformerSANMDecoder(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713

    The decoder is made of three stacks of ``DecoderLayerSANM``:

    * ``decoders``  -- ``att_layer_num`` layers with self-attention and
      cross-attention,
    * ``decoders2`` -- ``num_blocks - att_layer_num`` layers with
      self-attention only (``None`` when that count is not positive),
    * ``decoders3`` -- a single feed-forward-only layer.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        wo_input_layer: bool = False,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = 0,
        lora_list: List[str] = None,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        chunk_multiply_factor: tuple = (1,),
        tf2torch_tensor_name_prefix_torch: str = "decoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
    ):
        """Build the embedding, the three layer stacks and the output layer.

        Args:
            vocab_size: Size of the output vocabulary.
            encoder_output_size: Encoder feature size; also used as the
                attention dimension of every decoder layer.
            attention_heads: Number of heads of the cross-attention.
            linear_units: Hidden size of the position-wise feed-forward.
            num_blocks: Total number of layers across ``decoders`` and
                ``decoders2``.
            dropout_rate: Dropout rate of the layers and feed-forward.
            positional_dropout_rate: Dropout rate of the positional encoding.
            self_attention_dropout_rate: Dropout rate of the self-attention.
            src_attention_dropout_rate: Dropout rate of the cross-attention.
            input_layer: ``"embed"`` or ``"linear"``.
            use_output_layer: Whether to create the final linear projection.
            wo_input_layer: When true no input layer is built and
                ``self.embed`` is ``None``.
            pos_enc_class: Positional-encoding class (only instantiated for
                the ``"linear"`` input layer).
            normalize_before: Whether layers use pre-normalization.
            concat_after: Forwarded to every ``DecoderLayerSANM``.
            att_layer_num: Number of layers that own a cross-attention.
            kernel_size: Kernel size of the SANM self-attention.
            sanm_shfit: Shift of the SANM self-attention in ``decoders``;
                ``None`` selects ``(kernel_size - 1) // 2``.
            lora_list: Forwarded to the cross-attention.
            lora_rank: Forwarded to the cross-attention.
            lora_alpha: Forwarded to the cross-attention.
            lora_dropout: Forwarded to the cross-attention.
            chunk_multiply_factor: Stored on the instance as-is.
            tf2torch_tensor_name_prefix_torch: Stored on the instance as-is.
            tf2torch_tensor_name_prefix_tf: Stored on the instance as-is.

        Raises:
            ValueError: If ``input_layer`` is neither ``"embed"`` nor
                ``"linear"`` (and ``wo_input_layer`` is false).
        """
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size
        if wo_input_layer:
            self.embed = None
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                # pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")

        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None

        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2

        # Stack 1: SANM self-attention + cross-attention + feed-forward.
        self.decoders = repeat(
            att_layer_num,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads,
                    attention_dim,
                    src_attention_dropout_rate,
                    lora_list,
                    lora_rank,
                    lora_alpha,
                    lora_dropout,
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

        # Stack 2: SANM self-attention + feed-forward, no cross-attention.
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )

        # Stack 3: one feed-forward-only layer.
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
        self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
        self.chunk_multiply_factor = chunk_multiply_factor

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        chunk_mask: torch.Tensor = None,
        return_hidden: bool = False,
        return_both: bool = False,
    ) -> Tuple[torch.Tensor, ...]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad: decoder input; it is fed to the layers directly
                (``self.embed`` is not applied here)
            ys_in_lens: (batch)
            chunk_mask: optional mask multiplied into the memory mask
            return_hidden: return the hidden states instead of the scores
            return_both: return scores and hidden states; only consulted
                when the first return path is not taken
        Returns:
            (tuple): one of

                ``(x, olens)`` -- token scores before softmax, when an output
                layer exists and ``return_hidden`` is False,
                ``(x, hidden, olens)`` -- when ``return_both`` is set,
                ``(hidden, olens)`` -- otherwise.
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
            memory_mask = memory_mask * chunk_mask
            if tgt_mask.size(1) != memory_mask.size(1):
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)

        # Fall back to the raw layer output so `hidden` is always bound, even
        # when the decoder was built with normalize_before=False.
        hidden = self.after_norm(x) if self.normalize_before else x

        olens = tgt_mask.sum(1)
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(hidden)
            return x, olens
        if return_both:
            x = self.output_layer(hidden)
            return x, hidden, olens
        return hidden, olens

    def score(self, ys, state, x):
        """Score the sequence ``ys`` with one ``forward_one_step`` call.

        Args:
            ys: token sequence; a batch axis is added before decoding.
            state: cache handed to ``forward_one_step``.
            x: encoder memory for this utterance; a batch axis is added.

        Returns:
            tuple: ``(logp, state)`` with the batch axis of ``logp`` removed.
        """
        ys_mask = myutils.sequence_mask(
            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
        )[:, :, None]
        logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
        return logp.squeeze(0), state

    def forward_asf2(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Return the attention matrix of the second layer of ``decoders``.

        The first layer is run normally and its output is fed to
        ``get_attn_mat`` of the second layer.
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
        # This class owns the layers directly; there is no `self.model` here.
        attn_mat = self.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat

    def forward_asf6(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Return the attention matrix of the sixth layer of ``decoders``.

        The first five layers are run normally and the result is fed to
        ``get_attn_mat`` of the sixth layer.
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        for layer_idx in range(5):
            tgt, tgt_mask, memory, memory_mask, _ = self.decoders[layer_idx](
                tgt, tgt_mask, memory, memory_mask
            )
        attn_mat = self.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat

    def forward_chunk(
        self,
        memory: torch.Tensor,
        tgt: torch.Tensor,
        cache: dict = None,
    ) -> torch.Tensor:
        """Decode one chunk and update ``cache`` in place.

        Args:
            memory: encoder output for the chunk.
            tgt: decoder input for the chunk; fed to the layers directly.
            cache: streaming state. The keys ``"decode_fsmn"``, ``"opt"``,
                ``"chunk_size"`` and ``"decoder_chunk_look_back"`` are read,
                so a dict must be supplied despite the ``None`` default.

        Returns:
            The decoder output after the optional final norm and output layer.
        """
        x = tgt
        if cache["decode_fsmn"] is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            fsmn_cache = [None] * cache_layer_num
        else:
            fsmn_cache = cache["decode_fsmn"]

        if cache["opt"] is None:
            opt_cache = [None] * len(self.decoders)
        else:
            opt_cache = cache["opt"]

        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
                x,
                memory,
                fsmn_cache=fsmn_cache[i],
                opt_cache=opt_cache[i],
                chunk_size=cache["chunk_size"],
                look_back=cache["decoder_chunk_look_back"],
            )

        # Run the self-attention-only stack whenever it exists, matching
        # `forward`; a single-layer stack must not be skipped.
        if self.decoders2 is not None:
            for i, decoder in enumerate(self.decoders2):
                j = i + self.att_layer_num
                x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
                    x, memory, fsmn_cache=fsmn_cache[j]
                )

        for decoder in self.decoders3:
            x, memory, _, _ = decoder.forward_chunk(x, memory)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        cache["decode_fsmn"] = fsmn_cache
        # The cross-attention cache is only persisted when look-back is
        # enabled (positive) or set to -1.
        if cache["decoder_chunk_look_back"] > 0 or cache["decoder_chunk_look_back"] == -1:
            cache["opt"] = opt_cache
        return x

    def forward_one_step(
        self,
        tgt: torch.Tensor,
        tgt_mask: torch.Tensor,
        memory: torch.Tensor,
        cache: List[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Forward one step.

        Args:
            tgt: input token ids, int64 (batch, maxlen_out)
            tgt_mask: input token mask,  (batch, maxlen_out)
                      dtype=torch.uint8 in PyTorch 1.2-
                      dtype=torch.bool in PyTorch 1.2+ (include 1.2)
            memory: encoded memory, float32  (batch, maxlen_in, feat)
            cache: per-layer caches for ``decoders`` followed by
                ``decoders2``; ``None`` starts with empty caches
        Returns:
            y, cache: NN output value and the new per-layer caches.
            ``y`` is computed from the last position only and is a
            log-softmax when an output layer exists.
        """
        x = self.embed(tgt)
        if cache is None:
            cache_layer_num = len(self.decoders)
            if self.decoders2 is not None:
                cache_layer_num += len(self.decoders2)
            cache = [None] * cache_layer_num

        new_cache = []
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            x, tgt_mask, memory, _, c_ret = decoder.forward_one_step(
                x, tgt_mask, memory, None, cache=cache[i]
            )
            new_cache.append(c_ret)

        # Run the self-attention-only stack whenever it exists, matching
        # `forward`; this also keeps `new_cache` as long as `cache`.
        if self.decoders2 is not None:
            for i, decoder in enumerate(self.decoders2):
                j = i + self.att_layer_num
                x, tgt_mask, memory, _, c_ret = decoder.forward_one_step(
                    x, tgt_mask, memory, None, cache=cache[j]
                )
                new_cache.append(c_ret)

        for decoder in self.decoders3:
            x, tgt_mask, memory, _, _ = decoder.forward_one_step(
                x, tgt_mask, memory, None, cache=None
            )

        if self.normalize_before:
            y = self.after_norm(x[:, -1])
        else:
            y = x[:, -1]
        if self.output_layer is not None:
            y = torch.log_softmax(self.output_layer(y), dim=-1)

        return y, new_cache
| | | |
| | | |
class DecoderLayerSANMExport(torch.nn.Module):
    """Export-time wrapper around a SANM decoder layer.

    Re-uses the sub-modules of the wrapped layer. Compared with the wrapped
    layer's own forward pass as written here, no dropout is applied and the
    norm layers are always applied (there is no ``normalize_before`` switch).
    """

    def __init__(self, model):
        """Borrow the sub-modules of ``model``.

        Args:
            model: layer exposing ``self_attn``, ``src_attn``,
                ``feed_forward``, ``norm1``, ``size`` and, optionally,
                ``norm2`` / ``norm3``.
        """
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        # norm2 / norm3 may be missing on the wrapped layer; store None then.
        self.norm2 = getattr(model, "norm2", None)
        self.norm3 = getattr(model, "norm3", None)
        self.size = model.size

    def _ffn_and_self_attn(self, tgt, tgt_mask, cache):
        """Apply norm1 + feed-forward, then the optional self-attention.

        Returns ``(x, cache)``. Without self-attention ``x`` is the plain
        feed-forward output (no residual) and ``cache`` is passed through.
        """
        residual = tgt
        tgt = self.feed_forward(self.norm1(tgt))
        if self.self_attn is None:
            return tgt, cache
        x, cache = self.self_attn(self.norm2(tgt), tgt_mask, cache=cache)
        return residual + x, cache

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Run the layer.

        Args:
            tgt: layer input.
            tgt_mask: mask forwarded to the self-attention.
            memory: encoder output forwarded to the source attention.
            memory_mask: mask forwarded to the source attention.
            cache: self-attention cache.

        Returns:
            tuple: ``(x, tgt_mask, memory, memory_mask, cache)``.
        """
        x, cache = self._ffn_and_self_attn(tgt, tgt_mask, cache)

        if self.src_attn is not None:
            x = x + self.src_attn(self.norm3(x), memory, memory_mask)

        return x, tgt_mask, memory, memory_mask, cache

    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Return the second value produced by the source attention.

        The layer is run up to the source attention, which is called with
        ``ret_attn=True``; its first return value is discarded. The layer must
        have a source attention.
        """
        x, _ = self._ffn_and_self_attn(tgt, tgt_mask, cache)
        _, attn_mat = self.src_attn(self.norm3(x), memory, memory_mask, ret_attn=True)
        return attn_mat
| | | |
| | | |
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
class ParaformerSANMDecoderExport(torch.nn.Module):
    """Export wrapper for ``ParaformerSANMDecoder``.

    On construction every layer of the wrapped decoder is replaced in place by
    a ``DecoderLayerSANMExport``; SANM self-attention and cross-attention
    modules are first swapped for their ``*Export`` counterparts.
    """

    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
        """Wrap ``model`` and convert its layers (mutates ``model``).

        Args:
            model: decoder whose ``decoders`` / ``decoders2`` / ``decoders3``
                stacks are converted.
            max_seq_len: passed to ``sequence_mask`` to build the mask maker.
            model_name: stored as ``self.model_name``.
            onnx: accepted but not used here.
            **kwargs: accepted but not used here.
        """
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)

        from funasr.utils.torch_function import sequence_mask
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport

        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        # Layers with self-attention and cross-attention.
        for idx, layer in enumerate(self.model.decoders):
            if isinstance(layer.self_attn, MultiHeadedAttentionSANMDecoder):
                layer.self_attn = MultiHeadedAttentionSANMDecoderExport(layer.self_attn)
            if isinstance(layer.src_attn, MultiHeadedAttentionCrossAtt):
                layer.src_attn = MultiHeadedAttentionCrossAttExport(layer.src_attn)
            self.model.decoders[idx] = DecoderLayerSANMExport(layer)

        # Optional stack; only the self-attention is converted here.
        if self.model.decoders2 is not None:
            for idx, layer in enumerate(self.model.decoders2):
                if isinstance(layer.self_attn, MultiHeadedAttentionSANMDecoder):
                    layer.self_attn = MultiHeadedAttentionSANMDecoderExport(layer.self_attn)
                self.model.decoders2[idx] = DecoderLayerSANMExport(layer)

        # Final stack; layers are wrapped without touching their attention.
        for idx, layer in enumerate(self.model.decoders3):
            self.model.decoders3[idx] = DecoderLayerSANMExport(layer)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
        """Derive two mask views from ``mask``.

        Returns:
            tuple: ``mask`` with a trailing axis appended, and ``(1 - mask)``
            with extra axes inserted after the first one, scaled by -10000.0.
            Only 2-D and 3-D masks are handled.
        """
        with_trailing_axis = mask[:, :, None]
        rank = len(mask.shape)
        if rank == 2:
            inverted = 1 - mask[:, None, None, :]
        elif rank == 3:
            inverted = 1 - mask[:, None, :]
        additive = inverted * -10000.0

        return with_trailing_axis, additive

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        return_hidden: bool = False,
        return_both: bool = False,
    ):
        """Run the three converted stacks and the final norm.

        Returns:
            ``(x, ys_in_lens)`` with the output-layer result when an output
            layer exists and ``return_hidden`` is False;
            ``(x, hidden, ys_in_lens)`` when ``return_both`` is set;
            ``(hidden, ys_in_lens)`` otherwise.
        """
        tgt_mask, _ = self.prepare_mask(self.make_pad_mask(ys_in_lens))
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        _, memory_mask = self.prepare_mask(self.make_pad_mask(hlens))
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        memory = hs_pad
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            ys_in_pad, tgt_mask, memory, memory_mask
        )
        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
        hidden = self.after_norm(x)
        # x = self.output_layer(x)

        if self.output_layer is not None and return_hidden is False:
            return self.output_layer(hidden), ys_in_lens
        if return_both:
            return self.output_layer(hidden), hidden, ys_in_lens
        return hidden, ys_in_lens

    def forward_asf2(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Return the attention matrix of the second layer of ``decoders``."""
        state = ys_in_pad
        state_mask = myutils.sequence_mask(ys_in_lens, device=state.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)

        layers = self.model.decoders
        state, state_mask, memory, memory_mask, _ = layers[0](
            state, state_mask, memory, memory_mask
        )
        return layers[1].get_attn_mat(state, state_mask, memory, memory_mask)

    def forward_asf6(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Return the attention matrix of the sixth layer of ``decoders``."""
        state = ys_in_pad
        state_mask = myutils.sequence_mask(ys_in_lens, device=state.device)[:, :, None]

        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)

        # Layers 0..4 run normally; layer 5 only reports its attention.
        for layer_idx in range(5):
            state, state_mask, memory, memory_mask, _ = self.model.decoders[layer_idx](
                state, state_mask, memory, memory_mask
            )
        return self.model.decoders[5].get_attn_mat(state, state_mask, memory, memory_mask)

    """
    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        return True

    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(cache_num)]

    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
               + ['out_cache_%d' % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret
    """
| | | |
| | | |
@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
    """Streaming export wrapper for ``ParaformerSANMDecoder``.

    Same layer conversion as the offline export wrapper, but ``forward`` takes
    one cache per self-attention layer as extra positional arguments and
    returns the updated caches.
    """

    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
        """Wrap ``model`` and convert its layers (mutates ``model``).

        Args:
            model: decoder whose ``decoders`` / ``decoders2`` / ``decoders3``
                stacks are converted.
            max_seq_len: passed to ``sequence_mask`` to build the mask maker.
            model_name: stored as ``self.model_name``.
            onnx: accepted but not used here.
            **kwargs: accepted but not used here.
        """
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)

        from funasr.utils.torch_function import sequence_mask
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport

        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        # Layers with self-attention and cross-attention.
        for idx, layer in enumerate(self.model.decoders):
            if isinstance(layer.self_attn, MultiHeadedAttentionSANMDecoder):
                layer.self_attn = MultiHeadedAttentionSANMDecoderExport(layer.self_attn)
            if isinstance(layer.src_attn, MultiHeadedAttentionCrossAtt):
                layer.src_attn = MultiHeadedAttentionCrossAttExport(layer.src_attn)
            self.model.decoders[idx] = DecoderLayerSANMExport(layer)

        # Optional stack; only the self-attention is converted here.
        if self.model.decoders2 is not None:
            for idx, layer in enumerate(self.model.decoders2):
                if isinstance(layer.self_attn, MultiHeadedAttentionSANMDecoder):
                    layer.self_attn = MultiHeadedAttentionSANMDecoderExport(layer.self_attn)
                self.model.decoders2[idx] = DecoderLayerSANMExport(layer)

        # Final stack; layers are wrapped without touching their attention.
        for idx, layer in enumerate(self.model.decoders3):
            self.model.decoders3[idx] = DecoderLayerSANMExport(layer)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def _cache_layer_count(self):
        """Number of cached layers: ``decoders`` plus ``decoders2`` if present."""
        count = len(self.model.decoders)
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            count += len(self.model.decoders2)
        return count

    def prepare_mask(self, mask):
        """Derive two mask views from ``mask``.

        Returns:
            tuple: ``mask`` with a trailing axis appended, and ``(1 - mask)``
            with extra axes inserted after the first one, scaled by -10000.0.
            Only 2-D and 3-D masks are handled.
        """
        with_trailing_axis = mask[:, :, None]
        rank = len(mask.shape)
        if rank == 2:
            inverted = 1 - mask[:, None, None, :]
        elif rank == 3:
            inverted = 1 - mask[:, None, :]
        additive = inverted * -10000.0

        return with_trailing_axis, additive

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        *args,
    ):
        """Decode with explicit per-layer caches.

        Args:
            hs_pad: encoder output used as memory.
            hlens: lengths used to build the memory mask.
            ys_in_pad: decoder input, fed to the layers directly.
            ys_in_lens: lengths used to build the target mask.
            *args: input caches, one per layer of ``decoders`` followed by
                one per layer of ``decoders2``.

        Returns:
            tuple: the output-layer result and the list of output caches in
            the same order as the input caches.
        """
        tgt_mask, _ = self.prepare_mask(self.make_pad_mask(ys_in_lens))
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        _, memory_mask = self.prepare_mask(self.make_pad_mask(hlens))
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = ys_in_pad
        memory = hs_pad
        out_caches = []
        for idx, layer in enumerate(self.model.decoders):
            x, tgt_mask, memory, memory_mask, new_cache = layer(
                x, tgt_mask, memory, memory_mask, cache=args[idx]
            )
            out_caches.append(new_cache)
        if self.model.decoders2 is not None:
            # Caches of the second stack follow those of the first one.
            offset = len(self.model.decoders)
            for idx, layer in enumerate(self.model.decoders2):
                x, tgt_mask, memory, memory_mask, new_cache = layer(
                    x, tgt_mask, memory, memory_mask, cache=args[idx + offset]
                )
                out_caches.append(new_cache)
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
        x = self.output_layer(self.after_norm(x))

        return x, out_caches

    def get_dummy_inputs(self, enc_size):
        """Build example inputs (batch of 2) for ``forward``, caches included."""
        enc = torch.randn(2, 100, enc_size).type(torch.float32)
        enc_len = torch.tensor([30, 100], dtype=torch.int32)
        acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
        acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)

        first_layer = self.model.decoders[0]
        cache_shape = (2, first_layer.size, first_layer.self_attn.kernel_size - 1)
        cache = [
            torch.zeros(cache_shape, dtype=torch.float32)
            for _ in range(self._cache_layer_count())
        ]
        return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)

    def get_input_names(self):
        """Names of the inputs: four tensors followed by one name per cache."""
        names = ["enc", "enc_len", "acoustic_embeds", "acoustic_embeds_len"]
        names.extend("in_cache_%d" % i for i in range(self._cache_layer_count()))
        return names

    def get_output_names(self):
        """Names of the outputs followed by one name per cache."""
        # NOTE(review): ``forward`` in this class returns only the logits and
        # the caches; "sample_ids" presumably comes from an outer wrapper -
        # confirm against the caller.
        names = ["logits", "sample_ids"]
        names.extend("out_cache_%d" % i for i in range(self._cache_layer_count()))
        return names

    def get_dynamic_axes(self):
        """Dynamic-axis mapping for the named inputs and caches."""
        axes = {
            "enc": {0: "batch_size", 1: "enc_length"},
            "acoustic_embeds": {0: "batch_size", 1: "token_length"},
            "enc_len": {0: "batch_size"},
            "acoustic_embeds_len": {0: "batch_size"},
        }
        cache_num = self._cache_layer_count()
        # Only the first axis of every cache is dynamic.
        for d in range(cache_num):
            axes["in_cache_%d" % d] = {0: "batch_size"}
        for d in range(cache_num):
            axes["out_cache_%d" % d] = {0: "batch_size"}
        return axes
| | | |
| | | |
@tables.register("decoder_classes", "ParaformerSANDecoder")
class ParaformerSANDecoder(BaseTransformerDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713

    A stack of ``num_blocks`` ``DecoderLayer`` modules, each with multi-head
    self-attention, multi-head source attention and a position-wise
    feed-forward.
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        embeds_id: int = -1,
    ):
        """Build the layer stack on top of the base transformer decoder.

        Args:
            vocab_size: forwarded to the base class.
            encoder_output_size: forwarded to the base class; also the
                attention dimension of every layer.
            attention_heads: number of heads of both attentions.
            linear_units: hidden size of the position-wise feed-forward.
            num_blocks: number of decoder layers.
            dropout_rate: dropout rate of the layers and feed-forward.
            positional_dropout_rate: forwarded to the base class.
            self_attention_dropout_rate: dropout rate of the self-attention.
            src_attention_dropout_rate: dropout rate of the source attention.
            input_layer: forwarded to the base class.
            use_output_layer: forwarded to the base class.
            pos_enc_class: forwarded to the base class.
            normalize_before: forwarded to the base class and the layers.
            concat_after: forwarded to the layers.
            embeds_id: index of the layer whose output ``forward`` also
                returns; with the default of -1 no layer index matches.
        """
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )

        attention_dim = encoder_output_size

        def build_layer(lnum):
            # Sub-modules are created in the same order as the constructor
            # arguments: self-attention, source attention, feed-forward.
            return DecoderLayer(
                attention_dim,
                MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            )

        self.decoders = repeat(num_blocks, build_layer)
        self.embeds_id = embeds_id
        self.attention_dim = attention_dim

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.

        Args:
            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad: decoder input; it is fed to the layers directly
                (``self.embed`` is not applied here)
            ys_in_lens: (batch)
        Returns:
            (tuple): tuple containing:

            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: the target mask summed over dim 1,
            embeds_outputs: output of layer ``embeds_id``; only present as a
                third element when that layer exists.
        """
        tgt = ys_in_pad
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)

        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
        # Padding for Longformer
        missing = memory.shape[1] - memory_mask.shape[-1]
        if missing != 0:
            memory_mask = torch.nn.functional.pad(memory_mask, (0, missing), "constant", False)

        # x = self.embed(tgt)
        x = tgt
        embeds_outputs = None
        for layer_id, decoder in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, memory_mask)
            if layer_id == self.embeds_id:
                # Keep the output of the selected intermediate layer.
                embeds_outputs = x
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
            x = self.output_layer(x)

        olens = tgt_mask.sum(1)
        if embeds_outputs is None:
            return x, olens
        return x, olens, embeds_outputs
| | | |
| | | |
@tables.register("decoder_classes", "ParaformerDecoderSANExport")
class ParaformerDecoderSANExport(torch.nn.Module):
    """Export-time wrapper around a decoder with a ``decoders`` layer stack.

    On construction every layer in ``model.decoders`` is replaced by a
    ``DecoderLayerExport`` wrapper (and its ``src_attn`` by
    ``MultiHeadedAttentionExport`` when it is a ``MultiHeadedAttention``).
    Note that this mutates the wrapped ``model`` in place.
    """

    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name="decoder",
        onnx: bool = True,
    ):
        """Wrap ``model`` for export.

        Args:
            model: decoder to wrap; must expose ``decoders``,
                ``output_layer`` and ``after_norm``.
            max_seq_len: passed to ``sequence_mask`` to build the mask helper.
            model_name: stored as ``self.model_name``.
            onnx: accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.model = model

        from funasr.utils.torch_function import sequence_mask

        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        from funasr.models.transformer.decoder import DecoderLayerExport
        from funasr.models.transformer.attention import MultiHeadedAttentionExport

        for i, d in enumerate(self.model.decoders):
            if isinstance(d.src_attn, MultiHeadedAttention):
                d.src_attn = MultiHeadedAttentionExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerExport(d)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

    def prepare_mask(self, mask):
        """Derive the two mask layouts used by the exported layers.

        Args:
            mask: rank-2 or rank-3 mask tensor with 1 on valid positions.

        Returns:
            ``(mask_3d_btd, mask_4d_bhlt)``: the input with a trailing
            singleton axis, and an additive mask that is ``-10000.0`` where
            the input is 0 and ``0`` where it is 1.

        Raises:
            ValueError: if ``mask`` is neither rank 2 nor rank 3.
        """
        mask_3d_btd = mask[:, :, None]
        rank = len(mask.shape)
        if rank == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif rank == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        else:
            # Previously this fell through and failed on an unbound local.
            raise ValueError(f"mask must have rank 2 or 3, got rank {rank}")
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        """Decode ``ys_in_pad`` against ``hs_pad``.

        Returns:
            ``(x, ys_in_lens)`` where ``x`` is the result of the decoder
            stack followed by ``after_norm`` and ``output_layer``.
        """
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        # Target side uses the multiplicative (3-D) form of the mask.
        tgt_mask, _ = self.prepare_mask(tgt_mask)

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        # Memory side uses the additive (4-D) form of the mask.
        _, memory_mask = self.prepare_mask(memory_mask)

        x = tgt
        x, tgt_mask, memory, memory_mask = self.model.decoders(x, tgt_mask, memory, memory_mask)
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    def _cache_num(self):
        """Count cache slots: layers in ``decoders`` plus ``decoders2`` if any.

        ``decoders2`` is optional: a wrapped model that lacks the attribute
        (or has it set to ``None``) contributes zero extra slots instead of
        raising.
        """
        extra_layers = getattr(self.model, "decoders2", None)
        extra = len(extra_layers) if extra_layers is not None else 0
        return len(self.model.decoders) + extra

    def get_dummy_inputs(self, enc_size):
        """Build a tuple ``(tgt, memory, pre_acoustic_embeds, cache)`` of dummy tensors."""
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = self._cache_num()
        # NOTE(review): this relies on the first layer exposing ``size`` and
        # ``self_attn.kernel_size``; nothing visible here guarantees that for
        # the layers wrapped in __init__ -- confirm before relying on it.
        cache = [
            torch.zeros(
                (1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size)
            )
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        """Always report the module as optimizable."""
        return True

    def get_input_names(self):
        """Return the input names, including one ``cache_%d`` per cache slot."""
        cache_num = self._cache_num()
        return ["tgt", "memory", "pre_acoustic_embeds"] + ["cache_%d" % i for i in range(cache_num)]

    def get_output_names(self):
        """Return the output names, including one ``out_cache_%d`` per cache slot."""
        cache_num = self._cache_num()
        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        """Return the dynamic-axes mapping for the inputs listed above."""
        ret = {
            "tgt": {0: "tgt_batch", 1: "tgt_length"},
            "memory": {0: "memory_batch", 1: "memory_length"},
            "pre_acoustic_embeds": {
                0: "acoustic_embeds_batch",
                1: "acoustic_embeds_length",
            },
        }
        cache_num = self._cache_num()
        ret.update(
            {
                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
                for d in range(cache_num)
            }
        )
        return ret
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import types |
| | | import torch |
| | | from funasr.register import tables |
| | | |
| | | |
def export_rebuild_model(model, **kwargs):
    """Rewire ``model`` in place for export and return it.

    The encoder, predictor and decoder are replaced by the classes
    registered under ``"<name>Export"``, a ``sequence_mask`` helper is
    attached as ``model.make_pad_mask``, and the export entry points defined
    in this module are bound onto the instance.

    Keyword Args:
        device: stored as ``model.device`` (``None`` when absent).
        type: export type; the sub-modules are built with ``onnx=True``
            when it equals ``"onnx"`` (the default).
        encoder / predictor / decoder: registry names of the original
            sub-module classes (required).
        max_seq_len: maximum length for the mask helper; defaults to 512,
            matching the default applied by ``EParaformer.export``.

    Returns:
        The same ``model`` object, modified in place.
    """
    model.device = kwargs.get("device")
    is_onnx = kwargs.get("type", "onnx") == "onnx"

    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    model.encoder = encoder_class(model.encoder, onnx=is_onnx)

    predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
    model.predictor = predictor_class(model.predictor, onnx=is_onnx)

    decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
    model.decoder = decoder_class(model.decoder, onnx=is_onnx)

    from funasr.utils.torch_function import sequence_mask

    model.make_pad_mask = sequence_mask(kwargs.get("max_seq_len", 512), flip=False)

    model.forward = types.MethodType(export_forward, model)
    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
    model.export_input_names = types.MethodType(export_input_names, model)
    model.export_output_names = types.MethodType(export_output_names, model)
    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)

    # ``export_name`` ends up as a plain string. The original code first bound
    # the module-level ``export_name`` function here and then immediately
    # overwrote it with this value, so the binding was a dead store.
    model.export_name = "model"
    return model
| | | |
| | | |
def export_forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
):
    """Export-time forward pass: encoder -> predictor -> decoder.

    Returns:
        ``(decoder_out, pre_token_length)`` where ``decoder_out`` is the
        log-softmax (over the last axis) of the decoder output and
        ``pre_token_length`` is the predictor's token count floored and
        cast to int32.
    """
    # The encoder is called with keyword arguments.
    enc, enc_len = self.encoder(speech=speech, speech_lengths=speech_lengths)

    # Add a singleton middle axis to the length mask for the predictor.
    enc_mask = self.make_pad_mask(enc_len)[:, None, :]
    predictor_outs = self.predictor(enc, enc_mask)
    pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs
    pre_token_length = pre_token_length.floor().type(torch.int32)

    decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
    log_probs = torch.log_softmax(decoder_out, dim=-1)

    return log_probs, pre_token_length
| | | |
| | | |
def export_dummy_inputs(self):
    """Return a ``(speech, speech_lengths)`` pair of dummy export inputs.

    ``speech`` is random with shape (2, 30, 560); ``speech_lengths`` is the
    int32 tensor ``[6, 30]``.
    """
    # NOTE(review): the feature dimension is hard-coded to 560; confirm it
    # matches the input size of the encoder actually being exported.
    dummy_feats = torch.randn(2, 30, 560)
    dummy_lengths = torch.tensor([6, 30], dtype=torch.int32)
    return (dummy_feats, dummy_lengths)
| | | |
| | | |
def export_input_names(self):
    """Return the names of the exported model's inputs, in order."""
    names = ["speech"]
    names.append("speech_lengths")
    return names
| | | |
| | | |
def export_output_names(self):
    """Return the names of the exported model's outputs, in order."""
    names = ["logits"]
    names.append("token_num")
    return names
| | | |
| | | |
def export_dynamic_axes(self):
    """Return the dynamic-axis names for the exported inputs and outputs.

    Only ``speech``, ``speech_lengths`` and ``logits`` are listed.
    """
    axes = {}
    axes["speech"] = {0: "batch_size", 1: "feats_length"}
    axes["speech_lengths"] = {0: "batch_size"}
    axes["logits"] = {0: "batch_size", 1: "logits_length"}
    return axes
| | | |
| | | |
def export_name(
    self,
):
    """Return the file name used for the exported model."""
    file_name = "model.onnx"
    return file_name
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # Copyright 2024 Kun Zou (chinazoukun@gmail.com). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import time |
| | | import copy |
| | | import torch |
| | | import logging |
| | | from torch.cuda.amp import autocast |
| | | from typing import Union, Dict, List, Tuple, Optional |
| | | |
| | | from funasr.register import tables |
| | | from funasr.models.ctc.ctc import CTC |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.train_utils.device_funcs import to_device |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.models.paraformer.search import Hypothesis |
| | | from funasr.models.paraformer.cif_predictor import mae_loss |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos, add_sos_and_eos |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | |
| | | @tables.register("model_classes", "EParaformer") |
| | | class EParaformer(torch.nn.Module): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition |
| | | https://arxiv.org/abs/2206.08317 |
| | | Author: Kun Zou, chinazoukun@gmail.com |
| | | E-Paraformer: A Faster and Better Parallel Transformer for Non-autoregressive End-to-End Mandarin Speech Recognition |
| | | https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf |
| | | """ |
| | | |
def __init__(
    self,
    specaug: Optional[str] = None,
    specaug_conf: Optional[Dict] = None,
    normalize: str = None,
    normalize_conf: Optional[Dict] = None,
    encoder: str = None,
    encoder_conf: Optional[Dict] = None,
    decoder: str = None,
    decoder_conf: Optional[Dict] = None,
    ctc: str = None,
    ctc_conf: Optional[Dict] = None,
    predictor: str = None,
    predictor_conf: Optional[Dict] = None,
    ctc_weight: float = 0.5,
    input_size: int = 80,
    vocab_size: int = -1,
    ignore_id: int = -1,
    blank_id: int = 0,
    sos: int = 1,
    eos: int = 2,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    predictor_weight: float = 0.0,
    predictor_bias: int = 2,
    sampling_ratio: float = 0.2,
    share_embedding: bool = False,
    use_1st_decoder_loss: bool = True,
    **kwargs,
):
    """Build the sub-modules and loss criteria from registry names and configs.

    Args:
        specaug / normalize / encoder / decoder / predictor: registry names
            looked up in the corresponding ``tables.*_classes`` table.
            ``specaug``, ``normalize``, ``decoder`` and ``predictor`` are
            skipped when ``None``; the encoder is always built.
        *_conf: keyword arguments for the matching class; ``None`` is
            treated as an empty dict.
        ctc: ignored when ``ctc_weight > 0`` (a ``CTC`` module is built
            then); otherwise stored as-is unless ``ctc_weight == 0``.
        ctc_weight: weight of the CTC loss; ``0.0`` disables the CTC
            branch and ``1.0`` drops the decoder.
        input_size: passed to the encoder as ``input_size``.
        vocab_size: size of the output vocabulary.
        ignore_id: padding label, used as ``padding_idx`` of the losses.
        blank_id: id filtered out of hypotheses as the blank symbol.
        sos / eos: start/end token ids; ``None`` falls back to
            ``vocab_size - 1``.
        lsm_weight: label-smoothing factor of the attention losses.
        length_normalized_loss: forwarded to the loss criteria.
        predictor_weight: weight of the predictor (token count) loss.
        predictor_bias: number of extra tokens added to the targets
            (1 or 2 select the sos/eos padding helper used in training).
        sampling_ratio: ratio used by the samplers; ``0`` disables them.
        share_embedding: when true, ``decoder.embed`` is set to ``None``
            and the output-layer weights are used as embeddings.
        use_1st_decoder_loss: adds a second label-smoothing criterion for
            the first decoder pass.
        **kwargs: accepted and ignored.
    """
    super().__init__()

    # A missing config means "no extra arguments"; splatting None would
    # raise a TypeError.
    if specaug is not None:
        specaug_class = tables.specaug_classes.get(specaug)
        specaug = specaug_class(**(specaug_conf or {}))
    if normalize is not None:
        normalize_class = tables.normalize_classes.get(normalize)
        normalize = normalize_class(**(normalize_conf or {}))
    encoder_class = tables.encoder_classes.get(encoder)
    encoder = encoder_class(input_size=input_size, **(encoder_conf or {}))
    encoder_output_size = encoder.output_size()

    if decoder is not None:
        decoder_class = tables.decoder_classes.get(decoder)
        decoder = decoder_class(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            **(decoder_conf or {}),
        )
    if ctc_weight > 0.0:
        if ctc_conf is None:
            ctc_conf = {}

        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
    if predictor is not None:
        predictor_class = tables.predictor_classes.get(predictor)
        predictor = predictor_class(**(predictor_conf or {}))

    # sos/eos default to distinct ids (1 and 2); only an explicit None falls
    # back to the last vocabulary entry.
    self.blank_id = blank_id
    self.sos = sos if sos is not None else vocab_size - 1
    self.eos = eos if eos is not None else vocab_size - 1
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    self.ctc_weight = ctc_weight
    self.specaug = specaug
    self.normalize = normalize
    self.encoder = encoder

    # A pure-CTC configuration has no use for the attention decoder.
    if ctc_weight == 1.0:
        self.decoder = None
    else:
        self.decoder = decoder

    self.criterion_att = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )

    if use_1st_decoder_loss:
        # Separate criterion applied to the first decoder pass.
        self.criterion_att_1st = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )

    if ctc_weight == 0.0:
        self.ctc = None
    else:
        self.ctc = ctc

    self.predictor = predictor
    self.predictor_weight = predictor_weight
    self.predictor_bias = predictor_bias
    self.sampling_ratio = sampling_ratio
    self.criterion_pre = mae_loss(normalize_length=length_normalized_loss)

    self.share_embedding = share_embedding
    if self.share_embedding:
        self.decoder.embed = None

    self.use_1st_decoder_loss = use_1st_decoder_loss
    self.length_normalized_loss = length_normalized_loss
    self.beam_search = None
    self.error_calculator = None
| | | |
def forward(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    text: torch.Tensor,
    text_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """Encoder + Decoder + Calc loss

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )
        text: (Batch, Length)
        text_lengths: (Batch,)

    Returns:
        ``(loss, stats, weight)`` as produced by ``force_gatherable``.
    """
    # Length tensors with more than one axis are reduced to their first column.
    if text_lengths.dim() > 1:
        text_lengths = text_lengths[:, 0]
    if speech_lengths.dim() > 1:
        speech_lengths = speech_lengths[:, 0]

    batch_size = speech.shape[0]

    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)

    loss_ctc, cer_ctc = None, None
    loss_pre = None
    stats = dict()

    # CTC branch (skipped entirely when its weight is zero)
    use_ctc = self.ctc_weight != 0.0
    if use_ctc:
        loss_ctc, cer_ctc = self._calc_ctc_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        stats["loss_ctc"] = None if loss_ctc is None else loss_ctc.detach()
        stats["cer_ctc"] = cer_ctc

    # Attention decoder branch
    att_outs = self._calc_att_loss(encoder_out, encoder_out_lens, text, text_lengths)
    loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att = att_outs

    # Combine the losses
    weighted_pre = loss_pre * self.predictor_weight
    if self.ctc_weight == 0.0:
        loss = loss_att + weighted_pre
    else:
        loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + weighted_pre
    if pre_loss_att is not None:
        # First-pass decoder loss is added unweighted.
        loss += pre_loss_att

    # Attention branch stats
    stats["loss_att"] = None if loss_att is None else loss_att.detach()
    stats["pre_loss_att"] = None if pre_loss_att is None else pre_loss_att.detach()
    stats["acc"] = acc_att
    stats["cer"] = cer_att
    stats["wer"] = wer_att
    stats["loss_pre"] = None if loss_pre is None else loss_pre.detach().cpu()

    stats["loss"] = torch.clone(loss.detach())
    stats["batch_size"] = batch_size

    # force_gatherable: to-device and to-tensor if scalar for DataParallel
    if self.length_normalized_loss:
        # Weight by the number of target tokens (including the bias tokens).
        batch_size = (text_lengths + self.predictor_bias).sum()
    loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
    return loss, stats, weight
| | | |
def encode(
    self,
    speech: torch.Tensor,
    speech_lengths: torch.Tensor,
    **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply augmentation/normalization and run the encoder.

    Args:
        speech: (Batch, Length, ...)
        speech_lengths: (Batch, )

    Returns:
        ``(encoder_out, encoder_out_lens)``; when the encoder returns a
        tuple as its first output only its first element is kept.
    """
    # Feature-level processing runs with autocast disabled.
    with autocast(False):
        # SpecAug only while training
        apply_specaug = self.specaug is not None and self.training
        if apply_specaug:
            speech, speech_lengths = self.specaug(speech, speech_lengths)

        # Feature normalization, e.g. Global-CMVN, Utterance-CMVN
        if self.normalize is not None:
            speech, speech_lengths = self.normalize(speech, speech_lengths)

    enc_outputs = self.encoder(speech, speech_lengths)
    encoder_out, encoder_out_lens, _ = enc_outputs
    if isinstance(encoder_out, tuple):
        encoder_out = encoder_out[0]

    return encoder_out, encoder_out_lens
| | | |
def calc_predictor(self, encoder_out, encoder_out_lens):
    """Run the predictor on the encoder output without target labels.

    Returns:
        The predictor's four outputs:
        ``(pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index)``.
    """
    # True on valid frames, with a singleton middle axis.
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    encoder_out_mask = (~pad_mask[:, None, :]).to(encoder_out.device)

    predictor_outs = self.predictor(
        encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
    )
    pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs
    return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index
| | | |
def cal_decoder_with_predictor(
    self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
):
    """Decode the given embeddings and return log-probabilities.

    Returns:
        ``(decoder_out, ys_pad_lens)`` where ``decoder_out`` is the
        log-softmax over the last axis of the decoder's first output and
        ``ys_pad_lens`` is returned unchanged.
    """
    first_output = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)[0]
    log_probs = torch.log_softmax(first_output, dim=-1)
    return log_probs, ys_pad_lens
| | | |
def _calc_att_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    """Compute the decoder and predictor losses for one batch.

    Returns:
        ``(loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att)``.
        ``pre_loss_att`` is ``None`` unless ``use_1st_decoder_loss`` is set;
        ``cer_att``/``wer_att`` are ``None`` while training or when no error
        calculator is configured.
    """
    pad_mask = make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))
    encoder_out_mask = (~pad_mask[:, None, :]).to(encoder_out.device)

    # Extend the targets with sos/eos according to the predictor bias.
    if self.predictor_bias == 1:
        _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_pad_lens = ys_pad_lens + self.predictor_bias
    elif self.predictor_bias == 2:
        _, ys_pad = add_sos_and_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_pad_lens = ys_pad_lens + self.predictor_bias

    pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(
        encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id
    )

    # 0. sampler: optionally mix target embeddings into the acoustic ones
    decoder_out_1st = None
    pre_loss_att = None
    if self.sampling_ratio > 0.0:
        # The with-grad variant keeps the first pass differentiable so it
        # can receive its own loss below.
        sample = self.sampler_with_grad if self.use_1st_decoder_loss else self.sampler
        sematic_embeds, decoder_out_1st = sample(
            encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds
        )
    else:
        sematic_embeds = pre_acoustic_embeds

    # 1. Forward decoder
    decoder_outs = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens)
    decoder_out = decoder_outs[0]

    if decoder_out_1st is None:
        decoder_out_1st = decoder_out

    # 2. Compute attention loss
    if self.use_1st_decoder_loss:
        pre_loss_att = self.criterion_att_1st(decoder_out_1st, ys_pad)
    loss_att = self.criterion_att(decoder_out, ys_pad)
    # Accuracy is measured on the first-pass output.
    acc_att = th_accuracy(
        decoder_out_1st.view(-1, self.vocab_size),
        ys_pad,
        ignore_label=self.ignore_id,
    )
    loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)

    # Compute cer/wer using attention-decoder
    cer_att, wer_att = None, None
    if not self.training and self.error_calculator is not None:
        ys_hat = decoder_out_1st.argmax(dim=-1)
        cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())

    return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
| | | |
def _teacher_forced_embeds(self, ys_pad, ys_pad_lens):
    """Return ``(tgt_mask, ys_pad_embed)`` shared by both samplers.

    ``tgt_mask`` is True on valid target positions with a trailing singleton
    axis; ``ys_pad_embed`` holds the embeddings of the (mask-zeroed) targets,
    taken from the output-layer weights when ``share_embedding`` is set and
    from ``decoder.embed`` otherwise.
    """
    tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(
        ys_pad.device
    )
    # Zero the padded ids so they are valid indices for the lookup.
    ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
    if self.share_embedding:
        ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
    else:
        ys_pad_embed = self.decoder.embed(ys_pad_masked)
    return tgt_mask, ys_pad_embed


def _glancing_keep_mask(self, decoder_out, ys_pad, device):
    """Choose which positions keep the acoustic embedding.

    For each utterance, ``sampling_ratio`` times the number of wrongly
    predicted non-pad tokens is computed, and that many positions are drawn
    at random and marked False (to be replaced by the target embedding).
    Padded positions are always False.

    Returns:
        Boolean mask with a trailing singleton axis, moved to ``device``.
    """
    pred_tokens = decoder_out.argmax(-1)
    nonpad_positions = ys_pad.ne(self.ignore_id)
    seq_lens = nonpad_positions.sum(1)
    same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
    input_mask = torch.ones_like(nonpad_positions)
    bsz, seq_len = ys_pad.size()
    for li in range(bsz):
        target_num = ((seq_lens[li] - same_num[li]).float() * self.sampling_ratio).long()
        if target_num > 0:
            input_mask[li].scatter_(
                dim=0,
                index=torch.randperm(seq_lens[li])[:target_num].to(input_mask.device),
                value=0,
            )
    input_mask = input_mask.eq(1)
    input_mask = input_mask.masked_fill(~nonpad_positions, False)
    return input_mask.unsqueeze(2).to(device)


def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
    """Mix target embeddings into the acoustic ones; first pass without grad.

    Returns:
        ``(sematic_embeds, decoder_out)``, both multiplied by the target
        mask. ``decoder_out`` is the first-pass decoder output computed
        under ``torch.no_grad()``.
    """
    # Embeddings are computed before the decoder call, as in the original.
    tgt_mask, ys_pad_embed = self._teacher_forced_embeds(ys_pad, ys_pad_lens)
    with torch.no_grad():
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
        )
        decoder_out, _ = decoder_outs[0], decoder_outs[1]
        input_mask_expand_dim = self._glancing_keep_mask(
            decoder_out, ys_pad, pre_acoustic_embeds.device
        )

    sematic_embeds = pre_acoustic_embeds.masked_fill(
        ~input_mask_expand_dim, 0
    ) + ys_pad_embed.masked_fill(input_mask_expand_dim, 0)
    return sematic_embeds * tgt_mask, decoder_out * tgt_mask


def sampler_with_grad(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
    """Same as ``sampler`` but the first decoder pass is run with gradients.

    Returns:
        ``(sematic_embeds, decoder_out)``, both multiplied by the target
        mask.
    """
    tgt_mask, ys_pad_embed = self._teacher_forced_embeds(ys_pad, ys_pad_lens)
    decoder_outs = self.decoder(
        encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens
    )
    decoder_out, _ = decoder_outs[0], decoder_outs[1]
    input_mask_expand_dim = self._glancing_keep_mask(
        decoder_out, ys_pad, pre_acoustic_embeds.device
    )

    sematic_embeds = pre_acoustic_embeds.masked_fill(
        ~input_mask_expand_dim, 0
    ) + ys_pad_embed.masked_fill(input_mask_expand_dim, 0)
    return sematic_embeds * tgt_mask, decoder_out * tgt_mask
| | | |
| | | |
def _calc_ctc_loss(
    self,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_pad: torch.Tensor,
    ys_pad_lens: torch.Tensor,
):
    """Compute the CTC loss and, outside training, the CTC-based CER.

    Returns:
        ``(loss_ctc, cer_ctc)``; ``cer_ctc`` is ``None`` while training or
        when no error calculator is configured.
    """
    loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)

    # CER is only evaluated in eval mode with an error calculator present.
    if self.training or self.error_calculator is None:
        return loss_ctc, None

    ys_hat = self.ctc.argmax(encoder_out).data
    cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
    return loss_ctc, cer_ctc
| | | |
def init_beam_search(
    self,
    **kwargs,
):
    """Create the beam-search object and store it in ``self.beam_search``.

    Keyword Args:
        token_list: vocabulary; its length sizes the length-bonus scorer
            and the search vocabulary.
        decoding_ctc_weight: weight of the CTC scorer (default 0.0); the
            decoder weight is ``1.0 - decoding_ctc_weight``.
        lm_weight, ngram_weight, penalty: remaining scorer weights
            (default 0.0 each).
        beam_size: beam width (default 2).
    """
    from funasr.models.paraformer.search import BeamSearchPara
    from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
    from funasr.models.transformer.scorers.length_bonus import LengthBonus

    # 1. Build ASR model
    scorers = {}

    if self.ctc is not None:
        ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
        scorers.update(ctc=ctc)
    token_list = kwargs.get("token_list")
    scorers.update(
        length_bonus=LengthBonus(len(token_list)),
    )

    # 3. Build ngram model
    # ngram is not supported now
    ngram = None
    scorers["ngram"] = ngram

    # Read the CTC weight once with a default: beam search can be enabled
    # for LM-only decoding, where "decoding_ctc_weight" is absent and
    # ``1.0 - None`` would raise a TypeError.
    decoding_ctc_weight = kwargs.get("decoding_ctc_weight", 0.0)
    weights = dict(
        decoder=1.0 - decoding_ctc_weight,
        ctc=decoding_ctc_weight,
        lm=kwargs.get("lm_weight", 0.0),
        ngram=kwargs.get("ngram_weight", 0.0),
        length_bonus=kwargs.get("penalty", 0.0),
    )
    beam_search = BeamSearchPara(
        beam_size=kwargs.get("beam_size", 2),
        weights=weights,
        scorers=scorers,
        sos=self.sos,
        eos=self.eos,
        vocab_size=len(token_list),
        token_list=token_list,
        pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
    )
    self.beam_search = beam_search
| | | |
def inference(
    self,
    data_in,
    data_lengths=None,
    key: list = None,
    tokenizer=None,
    frontend=None,
    **kwargs,
):
    """Recognize a batch of utterances.

    Args:
        data_in: either an fbank tensor (when ``data_type == "fbank"``) or
            raw input handed to ``load_audio_text_image_video``.
        data_lengths: lengths for the fbank tensor; when ``None`` every
            item is assumed to span the full time axis.
        key: utterance ids; a single id is repeated for the whole batch.
        tokenizer: converts token ids to tokens/text; when ``None`` only
            ``token_int`` is returned per result.
        frontend: feature extractor used for non-fbank input.
        **kwargs: ``device`` is required; other options read here are
            ``data_type``, ``fs``, ``fp16``, ``pred_timestamp``,
            ``begin_time``, ``output_dir``, ``nbest``, ``maxlenratio``,
            ``minlenratio``, ``decoding_ctc_weight``, ``lm_weight`` and
            ``lm_file``.

    Returns:
        ``(results, meta_data)``; or an empty list when the predictor
        yields no token for any utterance.
    """
    # init beamsearch
    is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
    is_use_lm = (
        kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
    )
    pred_timestamp = kwargs.get("pred_timestamp", False)
    if self.beam_search is None and (is_use_lm or is_use_ctc):
        logging.info("enable beam_search")
        self.init_beam_search(**kwargs)
        self.nbest = kwargs.get("nbest", 1)

    meta_data = {}
    if (
        isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
    ):  # fbank
        speech, speech_lengths = data_in, data_lengths
        if len(speech.shape) < 3:
            speech = speech[None, :, :]
        if speech_lengths is not None:
            speech_lengths = speech_lengths.squeeze(-1)
        else:
            # Build a real length tensor: a bare int cannot be moved to the
            # device below.
            speech_lengths = torch.full(
                (speech.shape[0],), speech.shape[1], dtype=torch.int32
            )
    else:
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in,
            fs=frontend.fs,
            audio_fs=kwargs.get("fs", 16000),
            data_type=kwargs.get("data_type", "sound"),
            tokenizer=tokenizer,
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(
            audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
        )
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = (
            speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        )

    speech = speech.to(device=kwargs["device"])
    speech_lengths = speech_lengths.to(device=kwargs["device"])
    # Encoder
    if kwargs.get("fp16", False):
        speech = speech.half()
    encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
    if isinstance(encoder_out, tuple):
        encoder_out = encoder_out[0]

    # predictor
    predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
    pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = (
        predictor_outs[0],
        predictor_outs[1],
        predictor_outs[2],
        predictor_outs[3],
    )

    pre_token_length = pre_token_length.round().long()
    if torch.max(pre_token_length) < 1:
        # NOTE(review): this early exit returns a bare list rather than the
        # usual (results, meta_data) pair; kept as-is because callers may
        # rely on it -- confirm before changing.
        return []
    decoder_outs = self.cal_decoder_with_predictor(
        encoder_out, encoder_out_lens, pre_acoustic_embeds, pre_token_length
    )
    decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]

    results = []
    b, n, d = decoder_out.size()
    if isinstance(key[0], (list, tuple)):
        key = key[0]
    if len(key) < b:
        key = key * b
    for i in range(b):
        x = encoder_out[i, : encoder_out_lens[i], :]
        am_scores = decoder_out[i, : pre_token_length[i], :]
        if self.beam_search is not None:
            nbest_hyps = self.beam_search(
                x=x,
                am_scores=am_scores,
                maxlenratio=kwargs.get("maxlenratio", 0.0),
                minlenratio=kwargs.get("minlenratio", 0.0),
            )

            nbest_hyps = nbest_hyps[: self.nbest]
        else:
            # Greedy decoding: best token per position.
            yseq = am_scores.argmax(dim=-1)
            score = am_scores.max(dim=-1)[0]
            score = torch.sum(score, dim=-1)
            # pad with mask tokens to ensure compatibility with sos/eos tokens
            yseq = torch.tensor([self.sos] + yseq.tolist() + [self.eos], device=yseq.device)
            nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
        for nbest_idx, hyp in enumerate(nbest_hyps):
            ibest_writer = None
            if kwargs.get("output_dir") is not None:
                if not hasattr(self, "writer"):
                    self.writer = DatadirWriter(kwargs.get("output_dir"))
                ibest_writer = self.writer[f"{nbest_idx+1}best_recog"]
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
                token_int = hyp.yseq[1:last_pos]
            else:
                token_int = hyp.yseq[1:last_pos].tolist()

            # remove blank symbol id, which is assumed to be 0
            token_int = list(
                filter(
                    lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
                )
            )

            if tokenizer is not None:
                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text_postprocessed = tokenizer.tokens2text(token)

                if pred_timestamp:
                    timestamp_str, timestamp = ts_prediction_lfr6_standard(
                        pre_peak_index[i],
                        alphas[i],
                        copy.copy(token),
                        vad_offset=kwargs.get("begin_time", 0),
                        upsample_rate=1,
                    )
                    # Fall back to the raw timestamps when sentence
                    # post-processing is skipped (BPE tokenizers); the
                    # variable was otherwise unbound on that path.
                    time_stamp_postprocessed = timestamp
                    if not hasattr(tokenizer, "bpemodel"):
                        text_postprocessed, time_stamp_postprocessed, _ = (
                            postprocess_utils.sentence_postprocess(token, timestamp)
                        )
                    result_i = {
                        "key": key[i],
                        "text": text_postprocessed,
                        "timestamp": time_stamp_postprocessed,
                    }
                else:
                    if not hasattr(tokenizer, "bpemodel"):
                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    result_i = {"key": key[i], "text": text_postprocessed}

                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text_postprocessed
            else:
                result_i = {"key": key[i], "token_int": token_int}
            results.append(result_i)

    return results, meta_data
| | | |
def export(self, **kwargs):
    """Forward this model to ``export_rebuild_model`` and return its result.

    Args:
        **kwargs: Options passed through unchanged to ``export_rebuild_model``.
            ``max_seq_len`` is set to 512 when the caller did not provide it.

    Returns:
        Whatever ``export_rebuild_model`` returns for this model.
    """
    # Imported lazily, exactly as in the original, so the export helper is
    # only loaded when an export is actually requested.
    from .export_meta import export_rebuild_model

    kwargs.setdefault("max_seq_len", 512)
    return export_rebuild_model(model=self, **kwargs)
| New file |
| | |
| | | #!/usr/bin/env python3
|
| | | # -*- encoding: utf-8 -*-
|
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
| | | # Copyright 2024 Kun Zou (chinazoukun@gmail.com). All Rights Reserved.
|
| | | # MIT License (https://opensource.org/licenses/MIT)
|
| | |
|
| | | import torch
|
| | | import logging
|
| | | import numpy as np
|
| | |
|
| | | from funasr.register import tables
|
| | | from funasr.train_utils.device_funcs import to_device
|
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
| | | from torch.cuda.amp import autocast
|
| | |
|
| | |
|
@tables.register("predictor_classes", "PifPredictor")
class PifPredictor(torch.nn.Module):
    """
    Author: Kun Zou, chinazoukun@gmail.com
    E-Paraformer: A Faster and Better Parallel Transformer for Non-autoregressive End-to-End Mandarin Speech Recognition
    https://www.isca-archive.org/interspeech_2024/zou24_interspeech.pdf

    The module predicts one non-negative weight (``alpha``) per encoder frame,
    accumulates the weights into a monotonic alignment with ``cumsum`` and, for
    every output position ``k`` (centred at ``k + 0.5``), pools the encoder
    frames with a softmax over ``-((k + 0.5 - alignment) * sigma) ** 2 + bias``.
    ``sigma`` and ``bias`` are learned per head; the feature dimension of
    ``hidden`` is split evenly across ``sigma_heads`` heads.
    """

    def __init__(
        self,
        idim,
        l_order,
        r_order,
        threshold=1.0,
        dropout=0.1,
        smooth_factor=1.0,
        noise_threshold=0,
        sigma=0.5,
        bias=0.0,
        sigma_heads=4,
    ):
        """Build the predictor.

        Args:
            idim: Feature size of the encoder output; also the channel count of
                the depthwise convolution.
            l_order: Number of zero frames padded on the left before the
                convolution.
            r_order: Number of zero frames padded on the right before the
                convolution.
            threshold: Stored on the module; not read by ``forward``.
            dropout: Dropout probability applied after the residual convolution.
            smooth_factor: Multiplier applied to the sigmoid output.
            noise_threshold: Value subtracted from the scaled sigmoid output
                before the ReLU clamp.
            sigma: Initial value of every per-head ``sigma`` parameter.
            bias: Initial value of every per-head ``bias`` parameter.
            sigma_heads: Number of heads; the last dimension of ``hidden`` is
                reshaped into this many groups in ``forward``.
        """
        super().__init__()

        self.pad = torch.nn.ConstantPad1d((l_order, r_order), 0)
        # groups=idim makes this a depthwise convolution; the kernel covers
        # exactly the padded context, so the time length is preserved.
        self.cif_conv1d = torch.nn.Conv1d(idim, idim, l_order + r_order + 1, groups=idim)
        self.cif_output = torch.nn.Linear(idim, 1)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.threshold = threshold
        self.smooth_factor = smooth_factor
        self.noise_threshold = noise_threshold
        self.sigma = torch.nn.Parameter(torch.tensor([sigma] * sigma_heads))
        self.bias = torch.nn.Parameter(torch.tensor([bias] * sigma_heads))
        self.sigma_heads = sigma_heads

    def forward(
        self,
        hidden,
        target_label=None,
        mask=None,
        ignore_id=-1,
        mask_chunk_predictor=None,
        target_label_length=None,
    ):
        """Predict frame weights and pool ``hidden`` into token-level embeddings.

        Args:
            hidden: Encoder output of shape (batch, time, idim).
            target_label: Optional padded target ids; positions equal to
                ``ignore_id`` are treated as padding. Used only when
                ``target_label_length`` is not given.
            mask: Optional frame mask. It is transposed on its last two axes and
                multiplied with the (batch, time, 1) weights, so it is presumably
                (batch, 1, time) -- TODO confirm against the caller. When
                omitted, no frame is masked.
            ignore_id: Padding id inside ``target_label``.
            mask_chunk_predictor: Optional extra multiplicative mask applied to
                the frame weights.
            target_label_length: Optional per-utterance target lengths; takes
                precedence over ``target_label``.

        Returns:
            Tuple ``(acoustic_embeds, token_num, alphas, cif_peak)``:
            ``acoustic_embeds`` is (batch, max_token_num, idim); ``token_num``
            is the per-utterance sum of the weights *before* rescaling;
            ``alphas`` are the rescaled (batch, time) weights; ``cif_peak`` is
            always ``None``.
        """
        # Bound up front: the original only assigned it on two of the three
        # branches below, so passing `target_label_length` raised
        # UnboundLocalError at the final masking step.
        target_mask = None

        with autocast(False):
            context = hidden.transpose(1, 2)
            queries = self.pad(context)
            memory = self.cif_conv1d(queries)
            output = memory + context
            output = self.dropout(output)
            output = output.transpose(1, 2)
            output = torch.relu(output)
            output = self.cif_output(output)
            alphas = torch.sigmoid(output)
            alphas = torch.nn.functional.relu(alphas * self.smooth_factor - self.noise_threshold)
            if mask is not None:
                mask = mask.transpose(-1, -2).float()
                alphas = alphas * mask
            if mask_chunk_predictor is not None:
                alphas = alphas * mask_chunk_predictor
            alphas = alphas.squeeze(-1)
            # `mask` is optional in the signature; the original dereferenced it
            # unconditionally and crashed when it was left at its default.
            if mask is not None:
                mask = mask.squeeze(-1)

            if target_label_length is not None:
                target_length = target_label_length
            elif target_label is not None:
                target_mask = (target_label != ignore_id).float()
                target_length = target_mask.sum(-1)
            else:
                target_length = None

            token_num = alphas.sum(-1)
            # Rescale the weights so they sum to the desired token count.
            # Out-of-place multiplication: when no mask was applied, `alphas`
            # is a view of the ReLU output, which autograd needs intact for the
            # backward pass.
            if target_length is not None:
                alphas = alphas * (target_length / token_num)[:, None].repeat(1, alphas.size(1))
                max_token_num = torch.max(target_length)
            else:
                token_num_int = token_num.round()
                alphas = alphas * (token_num_int / token_num)[:, None]
                max_token_num = torch.max(token_num_int)

            alignment = torch.cumsum(alphas, dim=-1)
            # Output position k is centred at k + 0.5 on the cumulative-weight axis.
            fire_positions = (torch.arange(max_token_num) + 0.5).type_as(alphas).unsqueeze(0)
            offsets = fire_positions[:, None, :, None] - alignment[:, None, None, :]
            scores = (
                -((offsets * self.sigma[None, :, None, None]) ** 2)
                + self.bias[None, :, None, None]
            )
            if mask is not None:
                scores = scores.masked_fill(~(mask[:, None, None, :].to(torch.bool)), float("-inf"))
            weights = torch.softmax(scores, dim=-1)

            batch_size, feat_dim = hidden.size(0), hidden.size(-1)
            # (batch, time, idim) -> (batch, heads, time, idim // heads)
            n_hidden = hidden.view(
                batch_size, -1, self.sigma_heads, feat_dim // self.sigma_heads
            ).transpose(1, 2)
            acoustic_embeds = (
                torch.matmul(weights, n_hidden)
                .transpose(1, 2)
                .contiguous()
                .view(batch_size, -1, feat_dim)
            )

        if target_mask is not None:
            acoustic_embeds *= target_mask[:, :, None]
        cif_peak = None
        return acoustic_embeds, token_num, alphas, cif_peak
|
| | |
|
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import torch |
| | | import logging |
| | | from itertools import chain |
| | | from typing import Any, Dict, List, NamedTuple, Tuple, Union |
| | | |
| | | from funasr.metrics.common import end_detect |
| | | from funasr.models.transformer.scorers.scorer_interface import ( |
| | | PartialScorerInterface, |
| | | ScorerInterface, |
| | | ) |
| | | |
| | | |
class Hypothesis(NamedTuple):
    """Immutable record describing one beam-search hypothesis."""

    # Token ids of the hypothesis so far.
    yseq: torch.Tensor
    # Accumulated total score.
    score: Union[float, torch.Tensor] = 0
    # Per-scorer accumulated scores, keyed by scorer name.
    # NOTE: the default dict is a single object shared by every instance that
    # does not pass its own `scores`.
    scores: Dict[str, Union[float, torch.Tensor]] = {}
    # Per-scorer states, keyed by scorer name (default dict is shared likewise).
    states: Dict[str, Any] = {}

    def asdict(self) -> dict:
        """Return the fields as a dict of plain Python values.

        ``yseq`` becomes a list, ``score`` and every entry of ``scores`` become
        ``float``; ``states`` is passed through untouched.
        """
        token_ids = self.yseq.tolist()
        total = float(self.score)
        per_scorer = {name: float(value) for name, value in self.scores.items()}
        plain = self._replace(yseq=token_ids, score=total, scores=per_scorer)
        return plain._asdict()
| | | |
| | | |
| | | class BeamSearchPara(torch.nn.Module): |
| | | """Beam search implementation.""" |
| | | |
def __init__(
    self,
    scorers: Dict[str, ScorerInterface],
    weights: Dict[str, float],
    beam_size: int,
    vocab_size: int,
    sos: int,
    eos: int,
    token_list: List[str] = None,
    pre_beam_ratio: float = 1.5,
    pre_beam_score_key: str = None,
):
    """Set up the scorers and the search configuration.

    Args:
        scorers (dict[str, ScorerInterface]): Scorer modules by name.
            Entries that are `None` are skipped.
        weights (dict[str, float]): Weight of each scorer by name.
            Scorers whose weight is missing or 0 are skipped.
        beam_size (int): Number of hypotheses kept after each step.
        vocab_size (int): Size of the vocabulary.
        sos (int): Start-of-sequence id.
        eos (int): End-of-sequence id.
        token_list (list[str]): Token strings, used only for log messages.
        pre_beam_ratio (float): The pre-beam keeps
            `int(pre_beam_ratio * beam_size)` candidates.
        pre_beam_score_key (str): Name of the full scorer whose scores drive
            the pre-beam, or "full" to use the weighted sum.

    Raises:
        KeyError: If `pre_beam_score_key` is neither `None`, "full", nor
            the name of a registered full scorer.

    """
    super().__init__()

    self.weights = weights
    self.scorers = {}
    self.full_scorers = {}
    self.part_scorers = {}
    # Scorers that are torch modules are also kept in a ModuleDict so that
    # recursive casts such as `self.to(device, dtype)` reach them.
    self.nn_dict = torch.nn.ModuleDict()

    for name, scorer in scorers.items():
        if scorer is None or weights.get(name, 0) == 0:
            continue
        assert isinstance(
            scorer, ScorerInterface
        ), f"{name} ({type(scorer)}) does not implement ScorerInterface"
        self.scorers[name] = scorer
        bucket = (
            self.part_scorers
            if isinstance(scorer, PartialScorerInterface)
            else self.full_scorers
        )
        bucket[name] = scorer
        if isinstance(scorer, torch.nn.Module):
            self.nn_dict[name] = scorer

    # search configuration
    self.sos = sos
    self.eos = eos
    self.token_list = token_list
    self.pre_beam_size = int(pre_beam_ratio * beam_size)
    self.beam_size = beam_size
    self.n_vocab = vocab_size

    key_is_valid = (
        pre_beam_score_key is None
        or pre_beam_score_key == "full"
        or pre_beam_score_key in self.full_scorers
    )
    if not key_is_valid:
        raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
    self.pre_beam_score_key = pre_beam_score_key

    # The pre-beam only pays off when a key is set, it actually prunes the
    # vocabulary, and there is at least one partial scorer to feed.
    self.do_pre_beam = (
        self.pre_beam_score_key is not None
        and self.pre_beam_size < self.n_vocab
        and len(self.part_scorers) > 0
    )
| | | |
def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
    """Build the starting beam for one utterance.

    Args:
        x (torch.Tensor): The encoder output feature

    Returns:
        List[Hypothesis]: A single-element list holding a hypothesis that
            contains only `sos`, with a zero score and the initial state
            of every registered scorer.

    """
    initial_states = {name: scorer.init_state(x) for name, scorer in self.scorers.items()}
    initial_scores = dict.fromkeys(self.scorers, 0.0)
    start_tokens = torch.tensor([self.sos], device=x.device)
    root = Hypothesis(
        yseq=start_tokens,
        score=0.0,
        scores=initial_scores,
        states=initial_states,
    )
    return [root]
| | | |
| | | @staticmethod |
| | | def append_token(xs: torch.Tensor, x: int) -> torch.Tensor: |
| | | """Append new token to prefix tokens. |
| | | |
| | | Args: |
| | | xs (torch.Tensor): The prefix token |
| | | x (int): The new token to append |
| | | |
| | | Returns: |
| | | torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device |
| | | |
| | | """ |
| | | x = torch.tensor([x], dtype=xs.dtype, device=xs.device) |
| | | return torch.cat((xs, x)) |
| | | |
def score_full(
    self, hyp: Hypothesis, x: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """Run every scorer in `self.full_scorers` on one hypothesis.

    Args:
        hyp (Hypothesis): Hypothesis with prefix tokens to score
        x (torch.Tensor): Corresponding input feature

    Returns:
        Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Two dicts keyed by
            scorer name: the scores returned by each scorer's `score`
            call, and the new scorer states returned alongside them.

    """
    token_scores = {}
    next_states = {}
    for name, scorer in self.full_scorers.items():
        result, state = scorer.score(hyp.yseq, hyp.states[name], x)
        token_scores[name] = result
        next_states[name] = state
    return token_scores, next_states
| | | |
def score_partial(
    self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
    """Run every scorer in `self.part_scorers` on a subset of tokens.

    Args:
        hyp (Hypothesis): Hypothesis with prefix tokens to score
        ids (torch.Tensor): 1D tensor of new partial tokens to score
        x (torch.Tensor): Corresponding input feature

    Returns:
        Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Two dicts keyed by
            scorer name: the scores returned by each scorer's
            `score_partial` call, and the states returned alongside them.

    """
    token_scores = {}
    next_states = {}
    for name, scorer in self.part_scorers.items():
        result, state = scorer.score_partial(hyp.yseq, ids, hyp.states[name], x)
        token_scores[name] = result
        next_states[name] = state
    return token_scores, next_states
| | | |
def beam(
    self, weighted_scores: torch.Tensor, ids: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Pick the `beam_size` best tokens, honouring a pre-beam if one was used.

    Args:
        weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
            Modified in place when a pre-beam was applied: every entry not
            listed in `ids` is overwritten with `-inf`.
        ids (torch.Tensor): The partial token ids to compute topk

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            The top-k indices into `weighted_scores` and the top-k
            indices into `ids` (positions within the pre-beam).
            Both have `self.beam_size` entries.

    """
    k = self.beam_size

    # Same length means `ids` spans every score, i.e. no pre-beam happened,
    # so global and local indices coincide.
    if weighted_scores.size(0) == ids.size(0):
        best = torch.topk(weighted_scores, k).indices
        return best, best

    # Wipe everything, then restore only the pre-beam survivors so pruned
    # tokens can never win the top-k.
    survivors = weighted_scores[ids]
    weighted_scores.fill_(-float("inf"))
    weighted_scores[ids] = survivors

    global_best = torch.topk(weighted_scores, k).indices
    local_best = torch.topk(weighted_scores[ids], k).indices
    return global_best, local_best
| | | |
| | | @staticmethod |
| | | def merge_scores( |
| | | prev_scores: Dict[str, float], |
| | | next_full_scores: Dict[str, torch.Tensor], |
| | | full_idx: int, |
| | | next_part_scores: Dict[str, torch.Tensor], |
| | | part_idx: int, |
| | | ) -> Dict[str, torch.Tensor]: |
| | | """Merge scores for new hypothesis. |
| | | |
| | | Args: |
| | | prev_scores (Dict[str, float]): |
| | | The previous hypothesis scores by `self.scorers` |
| | | next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers` |
| | | full_idx (int): The next token id for `next_full_scores` |
| | | next_part_scores (Dict[str, torch.Tensor]): |
| | | scores of partial tokens by `self.part_scorers` |
| | | part_idx (int): The new token id for `next_part_scores` |
| | | |
| | | Returns: |
| | | Dict[str, torch.Tensor]: The new score dict. |
| | | Its keys are names of `self.full_scorers` and `self.part_scorers`. |
| | | Its values are scalar tensors by the scorers. |
| | | |
| | | """ |
| | | new_scores = dict() |
| | | for k, v in next_full_scores.items(): |
| | | new_scores[k] = prev_scores[k] + v[full_idx] |
| | | for k, v in next_part_scores.items(): |
| | | new_scores[k] = prev_scores[k] + v[part_idx] |
| | | return new_scores |
| | | |
def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
    """Combine full-scorer states with the selected partial-scorer states.

    Args:
        states: states of `self.full_scorers`
        part_states: states of `self.part_scorers`
        part_idx (int): Index handed to each partial scorer's `select_state`

    Returns:
        Dict[str, Any]: Every entry of `states` unchanged, plus one entry
            per scorer in `self.part_scorers` holding the result of its
            `select_state(part_states[name], part_idx)`.

    """
    merged = {name: state for name, state in states.items()}
    for name, scorer in self.part_scorers.items():
        merged[name] = scorer.select_state(part_states[name], part_idx)
    return merged
| | | |
def search(
    self, running_hyps: List[Hypothesis], x: torch.Tensor, am_score: torch.Tensor
) -> List[Hypothesis]:
    """Extend every running hypothesis by one token and keep the best ones.

    Args:
        running_hyps (List[Hypothesis]): Running hypotheses on beam
        x (torch.Tensor): Encoded speech feature (T, D)
        am_score (torch.Tensor): Scores for the current output position,
            added to the per-token scores before any scorer contributes
            (presumably of shape `(self.n_vocab,)` -- TODO confirm)

    Returns:
        List[Hypotheses]: Best sorted hypotheses

    """
    candidates = []
    # Without a pre-beam the partial scorers see the whole vocabulary.
    part_ids = torch.arange(self.n_vocab, device=x.device)

    for hyp in running_hyps:
        # start from the acoustic scores, then add the weighted full scorers
        weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
        weighted_scores += am_score
        scores, states = self.score_full(hyp, x)
        for name in self.full_scorers:
            weighted_scores += self.weights[name] * scores[name]

        # optional pre-beam: narrow the tokens handed to the partial scorers
        if self.do_pre_beam:
            if self.pre_beam_score_key == "full":
                pre_beam_scores = weighted_scores
            else:
                pre_beam_scores = scores[self.pre_beam_score_key]
            part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]

        part_scores, part_states = self.score_partial(hyp, part_ids, x)
        for name in self.part_scorers:
            weighted_scores[part_ids] += self.weights[name] * part_scores[name]

        # carry over the score accumulated by the prefix
        weighted_scores += hyp.score

        # each running hypothesis contributes up to `beam_size` extensions
        top_ids, local_ids = self.beam(weighted_scores, part_ids)
        for j, part_j in zip(top_ids, local_ids):
            extended = Hypothesis(
                score=weighted_scores[j],
                yseq=self.append_token(hyp.yseq, j),
                scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
                states=self.merge_states(states, part_states, part_j),
            )
            candidates.append(extended)

    # highest score first, then cut back down to the beam width
    candidates.sort(key=lambda h: h.score, reverse=True)
    return candidates[: self.beam_size]
| | | |
def forward(
    self,
    x: torch.Tensor,
    am_scores: torch.Tensor,
    maxlenratio: float = 0.0,
    minlenratio: float = 0.0,
) -> List[Hypothesis]:
    """Perform beam search.

    Args:
        x (torch.Tensor): Encoded speech feature (T, D)
        am_scores (torch.Tensor): Per-position scores; row `i` is handed to
            `search` at step `i`, and the number of rows is the maximum
            output length.
        maxlenratio (float): When equal to 0.0 (default), `end_detect` is
            consulted after every step and may stop the search early.
            It does not change the maximum length, which always comes
            from `am_scores`.
        minlenratio (float): Only consulted when no hypothesis ended: the
            search is retried with this value lowered by 0.1 until it
            drops below 0.1, at which point an empty list is returned.

    Returns:
        list[Hypothesis]: N-best decoding results, best first

    """
    # set length bounds
    maxlen = am_scores.shape[0]
    logging.info("decoder input length: " + str(x.shape[0]))
    logging.info("max output length: " + str(maxlen))

    # main loop of prefix search
    running_hyps = self.init_hyp(x)
    ended_hyps = []
    for i in range(maxlen):
        logging.debug("position " + str(i))
        best = self.search(running_hyps, x, am_scores[i])
        # post process of one iteration
        running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
        # end detection
        if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
            logging.info(f"end detected at {i}")
            break
        if len(running_hyps) == 0:
            logging.info("no hypothesis. Finish decoding.")
            break
        else:
            logging.debug(f"remained hypotheses: {len(running_hyps)}")

    nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
    # check the number of hypotheses reaching to eos
    if len(nbest_hyps) == 0:
        logging.warning(
            "there is no N-best results, perform recognition " "again with smaller minlenratio."
        )
        if minlenratio < 0.1:
            return []
        # `am_scores` has to be forwarded explicitly: the previous call
        # omitted it, which shifted `maxlenratio` into the `am_scores` slot
        # and made the retry fail on `am_scores.shape`.
        return self.forward(x, am_scores, maxlenratio, max(0.0, minlenratio - 0.1))

    # report the best result
    best = nbest_hyps[0]
    for k, v in best.scores.items():
        logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
    logging.info(f"total log probability: {best.score:.2f}")
    logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
    logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
    if self.token_list is not None:
        logging.info(
            "best hypo: " + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]]) + "\n"
        )
    return nbest_hyps
| | | |
def post_process(
    self,
    i: int,
    maxlen: int,
    maxlenratio: float,
    running_hyps: List[Hypothesis],
    ended_hyps: List[Hypothesis],
) -> List[Hypothesis]:
    """Split the beam into finished and still-running hypotheses.

    Args:
        i (int): The length of hypothesis tokens.
        maxlen (int): The maximum length of tokens in beam search.
        maxlenratio (int): The maximum length ratio in beam search.
            Not read by this method.
        running_hyps (List[Hypothesis]): The running hypotheses in beam search.
        ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
            Finished hypotheses are appended to this list in place.

    Returns:
        List[Hypothesis]: The new running hypotheses.

    """
    logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
    if self.token_list is not None:
        best_tokens = "".join(
            self.token_list[tok.item()] for tok in running_hyps[0].yseq[1:]
        )
        logging.debug("best hypo: " + best_tokens)

    # On the last step every hypothesis is closed with <eos>, which
    # guarantees that at least one hypothesis ends.
    if i == maxlen - 1:
        logging.info("adding <eos> in the last position in the loop")
        running_hyps = [
            h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
        ]

    # Finished hypotheses move to `ended_hyps`; the rest stay on the beam
    # (so the beam may shrink below `beam_size`).
    still_running = []
    for hyp in running_hyps:
        if hyp.yseq[-1] != self.eos:
            still_running.append(hyp)
            continue
        # e.g., a word LM needs to add its final <eos> score
        for name, scorer in chain(self.full_scorers.items(), self.part_scorers.items()):
            final = scorer.final_score(hyp.states[name])
            hyp.scores[name] += final
            hyp = hyp._replace(score=hyp.score + self.weights[name] * final)
        ended_hyps.append(hyp)
    return still_running