Merge remote-tracking branch 'origin/main' into dev_yf
| New file |
| | |
| | | |
| | | |
| | | python -m funasr.bin.inference \ |
| | | --config-path="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3" \ |
| | | --config-name="config.yaml" \ |
| | | ++init_param="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38" \ |
| | | ++tokenizer_conf.token_list="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/zh_token_list/char/tokens.txt" \ |
| | | ++frontend_conf.cmvn_file="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/train/am.mvn" \ |
| | | ++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \ |
| | | ++output_dir="./outputs/debug" \ |
| | | ++device="cuda:0" \ |
| | | |
| | |
| | | |
| | | # general configuration |
| | | feats_dir="../DATA" #feature output dictionary |
| | | exp_dir="." |
| | | exp_dir=`pwd` |
| | | lang=zh |
| | | token_type=char |
| | | stage=0 |
| | |
| | | |
| | | config=branchformer_12e_6d_2048_256.yaml |
| | | model_dir="baseline_$(basename "${config}" .yaml)_${lang}_${token_type}_${tag}" |
| | | |
| | | |
| | | |
| | | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then |
| | |
| | | log_file="${exp_dir}/exp/${model_dir}/train.log.txt.${current_time}" |
| | | echo "log_file: ${log_file}" |
| | | |
| | | export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES |
| | | gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | torchrun \ |
| | | --nnodes 1 \ |
| | |
| | | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then |
| | | echo "stage 5: Inference" |
| | | |
| | | if ${inference_device} == "cuda"; then |
| | | if [ ${inference_device} == "cuda" ]; then |
| | | nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | else |
| | | inference_batch_size=1 |
| | |
| | | |
| | | for dset in ${test_sets}; do |
| | | |
| | | inference_dir="${exp_dir}/exp/${model_dir}/${inference_checkpoint}/${dset}" |
| | | inference_dir="${exp_dir}/exp/${model_dir}/inference-${inference_checkpoint}/${dset}" |
| | | _logdir="${inference_dir}/logdir" |
| | | echo "inference_dir: ${inference_dir}" |
| | | |
| | | mkdir -p "${_logdir}" |
| | | data_dir="${feats_dir}/data/${dset}" |
| | |
| | | done |
| | | utils/split_scp.pl "${key_file}" ${split_scps} |
| | | |
| | | gpuid_list_array=(${gpuid_list//,/ }) |
| | | gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ }) |
| | | for JOB in $(seq ${nj}); do |
| | | { |
| | | id=$((JOB-1)) |
| | |
| | | ++input="${_logdir}/keys.${JOB}.scp" \ |
| | | ++output_dir="${inference_dir}/${JOB}" \ |
| | | ++device="${inference_device}" \ |
| | | ++batch_size="${inference_batch_size}" |
| | | ++ncpu=1 \ |
| | | ++disable_log=true \ |
| | | ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt |
| | | }& |
| | | |
| | | done |
| | |
| | | done |
| | | |
| | | echo "Computing WER ..." |
| | | cp ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc |
| | | cp ${data_dir}/text ${inference_dir}/1best_recog/text.ref |
| | | python utils/postprocess_text_zh.py ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc |
| | | python utils/postprocess_text_zh.py ${data_dir}/text ${inference_dir}/1best_recog/text.ref |
| | | python utils/compute_wer.py ${inference_dir}/1best_recog/text.ref ${inference_dir}/1best_recog/text.proc ${inference_dir}/1best_recog/text.cer |
| | | tail -n 3 ${inference_dir}/1best_recog/text.cer |
| | | done |
| New file |
| | |
| | | |
| | | |
| | | python -m funasr.bin.inference \ |
| | | --config-path="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3" \ |
| | | --config-name="config.yaml" \ |
| | | ++init_param="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38" \ |
| | | ++tokenizer_conf.token_list="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/zh_token_list/char/tokens.txt" \ |
| | | ++frontend_conf.cmvn_file="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/train/am.mvn" \ |
| | | ++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \ |
| | | ++output_dir="./outputs/debug" \ |
| | | ++device="cuda:0" \ |
| | | |
| | |
| | | |
| | | # general configuration |
| | | feats_dir="../DATA" #feature output dictionary |
| | | exp_dir="." |
| | | exp_dir=`pwd` |
| | | lang=zh |
| | | token_type=char |
| | | stage=0 |
| | |
| | | # feature configuration |
| | | nj=32 |
| | | |
| | | inference_device="cuda" #"cpu" |
| | | inference_device="cuda" #"cpu", "cuda:0", "cuda:1" |
| | | inference_checkpoint="model.pt" |
| | | inference_scp="wav.scp" |
| | | inference_batch_size=32 |
| | | inference_batch_size=1 |
| | | |
| | | # data |
| | | raw_data=../raw_data |
| | |
| | | log_file="${exp_dir}/exp/${model_dir}/train.log.txt.${current_time}" |
| | | echo "log_file: ${log_file}" |
| | | |
| | | export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES |
| | | gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | torchrun \ |
| | | --nnodes 1 \ |
| | |
| | | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then |
| | | echo "stage 5: Inference" |
| | | |
| | | if ${inference_device} == "cuda"; then |
| | | if [ ${inference_device} == "cuda" ]; then |
| | | nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | else |
| | | inference_batch_size=1 |
| | |
| | | |
| | | for dset in ${test_sets}; do |
| | | |
| | | inference_dir="${exp_dir}/exp/${model_dir}/${inference_checkpoint}/${dset}" |
| | | inference_dir="${exp_dir}/exp/${model_dir}/inference-${inference_checkpoint}/${dset}" |
| | | _logdir="${inference_dir}/logdir" |
| | | echo "inference_dir: ${inference_dir}" |
| | | |
| | | mkdir -p "${_logdir}" |
| | | data_dir="${feats_dir}/data/${dset}" |
| | |
| | | done |
| | | utils/split_scp.pl "${key_file}" ${split_scps} |
| | | |
| | | gpuid_list_array=(${gpuid_list//,/ }) |
| | | gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ }) |
| | | for JOB in $(seq ${nj}); do |
| | | { |
| | | id=$((JOB-1)) |
| | |
| | | ++input="${_logdir}/keys.${JOB}.scp" \ |
| | | ++output_dir="${inference_dir}/${JOB}" \ |
| | | ++device="${inference_device}" \ |
| | | ++batch_size="${inference_batch_size}" |
| | | ++ncpu=1 \ |
| | | ++disable_log=true \ |
| | | ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt |
| | | }& |
| | | |
| | | done |
| | |
| | | done |
| | | |
| | | echo "Computing WER ..." |
| | | cp ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc |
| | | cp ${data_dir}/text ${inference_dir}/1best_recog/text.ref |
| | | python utils/postprocess_text_zh.py ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc |
| | | python utils/postprocess_text_zh.py ${data_dir}/text ${inference_dir}/1best_recog/text.ref |
| | | python utils/compute_wer.py ${inference_dir}/1best_recog/text.ref ${inference_dir}/1best_recog/text.proc ${inference_dir}/1best_recog/text.cer |
| | | tail -n 3 ${inference_dir}/1best_recog/text.cer |
| | | done |
| New file |
| | |
| | | |
| | | |
| | | python -m funasr.bin.inference \ |
| | | --config-path="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3" \ |
| | | --config-name="config.yaml" \ |
| | | ++init_param="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38" \ |
| | | ++tokenizer_conf.token_list="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/zh_token_list/char/tokens.txt" \ |
| | | ++frontend_conf.cmvn_file="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/train/am.mvn" \ |
| | | ++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \ |
| | | ++output_dir="./outputs/debug" \ |
| | | ++device="cuda:0" \ |
| | | |
| | |
| | | |
| | | # general configuration |
| | | feats_dir="../DATA" #feature output dictionary |
| | | exp_dir="." |
| | | exp_dir=`pwd` |
| | | lang=zh |
| | | token_type=char |
| | | stage=0 |
| | |
| | | |
| | | config=e_branchformer_12e_6d_2048_256.yaml |
| | | model_dir="baseline_$(basename "${config}" .yaml)_${lang}_${token_type}_${tag}" |
| | | |
| | | |
| | | |
| | | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then |
| | |
| | | log_file="${exp_dir}/exp/${model_dir}/train.log.txt.${current_time}" |
| | | echo "log_file: ${log_file}" |
| | | |
| | | export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES |
| | | gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | torchrun \ |
| | | --nnodes 1 \ |
| | |
| | | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then |
| | | echo "stage 5: Inference" |
| | | |
| | | if ${inference_device} == "cuda"; then |
| | | if [ ${inference_device} == "cuda" ]; then |
| | | nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | else |
| | | inference_batch_size=1 |
| | |
| | | |
| | | for dset in ${test_sets}; do |
| | | |
| | | inference_dir="${exp_dir}/exp/${model_dir}/${inference_checkpoint}/${dset}" |
| | | inference_dir="${exp_dir}/exp/${model_dir}/inference-${inference_checkpoint}/${dset}" |
| | | _logdir="${inference_dir}/logdir" |
| | | echo "inference_dir: ${inference_dir}" |
| | | |
| | | mkdir -p "${_logdir}" |
| | | data_dir="${feats_dir}/data/${dset}" |
| | |
| | | done |
| | | utils/split_scp.pl "${key_file}" ${split_scps} |
| | | |
| | | gpuid_list_array=(${gpuid_list//,/ }) |
| | | gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ }) |
| | | for JOB in $(seq ${nj}); do |
| | | { |
| | | id=$((JOB-1)) |
| | |
| | | ++input="${_logdir}/keys.${JOB}.scp" \ |
| | | ++output_dir="${inference_dir}/${JOB}" \ |
| | | ++device="${inference_device}" \ |
| | | ++batch_size="${inference_batch_size}" |
| | | ++ncpu=1 \ |
| | | ++disable_log=true \ |
| | | ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt |
| | | }& |
| | | |
| | | done |
| | |
| | | done |
| | | |
| | | echo "Computing WER ..." |
| | | cp ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc |
| | | cp ${data_dir}/text ${inference_dir}/1best_recog/text.ref |
| | | python utils/postprocess_text_zh.py ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc |
| | | python utils/postprocess_text_zh.py ${data_dir}/text ${inference_dir}/1best_recog/text.ref |
| | | python utils/compute_wer.py ${inference_dir}/1best_recog/text.ref ${inference_dir}/1best_recog/text.proc ${inference_dir}/1best_recog/text.cer |
| | | tail -n 3 ${inference_dir}/1best_recog/text.cer |
| | | done |
| New file |
| | |
| | | |
| | | |
| | | python -m funasr.bin.inference \ |
| | | --config-path="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3" \ |
| | | --config-name="config.yaml" \ |
| | | ++init_param="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38" \ |
| | | ++tokenizer_conf.token_list="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/zh_token_list/char/tokens.txt" \ |
| | | ++frontend_conf.cmvn_file="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/train/am.mvn" \ |
| | | ++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \ |
| | | ++output_dir="./outputs/debug" \ |
| | | ++device="cuda:0" \ |
| | | |
| | |
| | | |
| | | # general configuration |
| | | feats_dir="../DATA" #feature output dictionary |
| | | exp_dir="." |
| | | exp_dir=`pwd` |
| | | lang=zh |
| | | token_type=char |
| | | stage=0 |
| | |
| | | log_file="${exp_dir}/exp/${model_dir}/train.log.txt.${current_time}" |
| | | echo "log_file: ${log_file}" |
| | | |
| | | export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES |
| | | gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | torchrun \ |
| | | --nnodes 1 \ |
| | |
| | | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then |
| | | echo "stage 5: Inference" |
| | | |
| | | if ${inference_device} == "cuda"; then |
| | | if [ ${inference_device} == "cuda" ]; then |
| | | nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | else |
| | | inference_batch_size=1 |
| | |
| | | |
| | | for dset in ${test_sets}; do |
| | | |
| | | inference_dir="${exp_dir}/exp/${model_dir}/${inference_checkpoint}/${dset}" |
| | | inference_dir="${exp_dir}/exp/${model_dir}/inference-${inference_checkpoint}/${dset}" |
| | | _logdir="${inference_dir}/logdir" |
| | | echo "inference_dir: ${inference_dir}" |
| | | |
| | | mkdir -p "${_logdir}" |
| | | data_dir="${feats_dir}/data/${dset}" |
| | |
| | | done |
| | | utils/split_scp.pl "${key_file}" ${split_scps} |
| | | |
| | | gpuid_list_array=(${gpuid_list//,/ }) |
| | | gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ }) |
| | | for JOB in $(seq ${nj}); do |
| | | { |
| | | id=$((JOB-1)) |
| | |
| | | ++input="${_logdir}/keys.${JOB}.scp" \ |
| | | ++output_dir="${inference_dir}/${JOB}" \ |
| | | ++device="${inference_device}" \ |
| | | ++batch_size="${inference_batch_size}" |
| | | ++ncpu=1 \ |
| | | ++disable_log=true \ |
| | | ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt |
| | | }& |
| | | |
| | | done |
| | |
| | | done |
| | | |
| | | echo "Computing WER ..." |
| | | cp ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc |
| | | cp ${data_dir}/text ${inference_dir}/1best_recog/text.ref |
| | | python utils/postprocess_text_zh.py ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc |
| | | python utils/postprocess_text_zh.py ${data_dir}/text ${inference_dir}/1best_recog/text.ref |
| | | python utils/compute_wer.py ${inference_dir}/1best_recog/text.ref ${inference_dir}/1best_recog/text.proc ${inference_dir}/1best_recog/text.cer |
| | | tail -n 3 ${inference_dir}/1best_recog/text.cer |
| | | done |
| New file |
| | |
| | | |
| | | import sys |
| | | import re |
| | | |
| | | in_f = sys.argv[1] |
| | | out_f = sys.argv[2] |
| | | |
| | | |
| | | with open(in_f, "r", encoding="utf-8") as f: |
| | | lines = f.readlines() |
| | | |
| | | with open(out_f, "w", encoding="utf-8") as f: |
| | | for line in lines: |
| | | outs = line.strip().split(" ", 1) |
| | | if len(outs) == 2: |
| | | idx, text = outs |
| | | text = re.sub("</s>", "", text) |
| | | text = re.sub("<s>", "", text) |
| | | text = re.sub("@@", "", text) |
| | | text = re.sub("@", "", text) |
| | | text = re.sub("<unk>", "", text) |
| | | text = re.sub(" ", "", text) |
| | | text = text.lower() |
| | | else: |
| | | idx = outs[0] |
| | | text = " " |
| | | |
| | | text = [x for x in text] |
| | | text = " ".join(text) |
| | | out = "{} {}\n".format(idx, text) |
| | | f.write(out) |
| New file |
| | |
| | | |
| | | |
| | | python -m funasr.bin.inference \ |
| | | --config-path="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3" \ |
| | | --config-name="config.yaml" \ |
| | | ++init_param="/mnt/workspace/FunASR/examples/aishell/paraformer/exp/baseline_paraformer_conformer_12e_6d_2048_256_zh_char_exp3/model.pt.ep38" \ |
| | | ++tokenizer_conf.token_list="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/zh_token_list/char/tokens.txt" \ |
| | | ++frontend_conf.cmvn_file="/mnt/nfs/zhifu.gzf/data/AISHELL-1-feats/DATA/data/train/am.mvn" \ |
| | | ++input="/mnt/nfs/zhifu.gzf/data/AISHELL-1/data_aishell/wav/train/S0002/BAC009S0002W0122.wav" \ |
| | | ++output_dir="./outputs/debug" \ |
| | | ++device="cuda:0" \ |
| | | |
| | |
| | | |
| | | # general configuration |
| | | feats_dir="../DATA" #feature output dictionary |
| | | exp_dir="." |
| | | exp_dir=`pwd` |
| | | lang=zh |
| | | token_type=char |
| | | stage=0 |
| | |
| | | valid_set=dev |
| | | test_sets="dev test" |
| | | |
| | | config=paraformer_conformer_12e_6d_2048_256.yaml |
| | | config=transformer_12e_6d_2048_256.yaml |
| | | model_dir="baseline_$(basename "${config}" .yaml)_${lang}_${token_type}_${tag}" |
| | | |
| | | |
| | | |
| | | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then |
| | |
| | | log_file="${exp_dir}/exp/${model_dir}/train.log.txt.${current_time}" |
| | | echo "log_file: ${log_file}" |
| | | |
| | | export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES |
| | | gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | torchrun \ |
| | | --nnodes 1 \ |
| | |
| | | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then |
| | | echo "stage 5: Inference" |
| | | |
| | | if ${inference_device} == "cuda"; then |
| | | if [ ${inference_device} == "cuda" ]; then |
| | | nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | else |
| | | inference_batch_size=1 |
| | |
| | | |
| | | for dset in ${test_sets}; do |
| | | |
| | | inference_dir="${exp_dir}/exp/${model_dir}/${inference_checkpoint}/${dset}" |
| | | inference_dir="${exp_dir}/exp/${model_dir}/inference-${inference_checkpoint}/${dset}" |
| | | _logdir="${inference_dir}/logdir" |
| | | echo "inference_dir: ${inference_dir}" |
| | | |
| | | mkdir -p "${_logdir}" |
| | | data_dir="${feats_dir}/data/${dset}" |
| | |
| | | done |
| | | utils/split_scp.pl "${key_file}" ${split_scps} |
| | | |
| | | gpuid_list_array=(${gpuid_list//,/ }) |
| | | gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ }) |
| | | for JOB in $(seq ${nj}); do |
| | | { |
| | | id=$((JOB-1)) |
| | |
| | | ++input="${_logdir}/keys.${JOB}.scp" \ |
| | | ++output_dir="${inference_dir}/${JOB}" \ |
| | | ++device="${inference_device}" \ |
| | | ++batch_size="${inference_batch_size}" |
| | | ++ncpu=1 \ |
| | | ++disable_log=true \ |
| | | ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt |
| | | }& |
| | | |
| | | done |
| | |
| | | done |
| | | |
| | | echo "Computing WER ..." |
| | | cp ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc |
| | | cp ${data_dir}/text ${inference_dir}/1best_recog/text.ref |
| | | python utils/postprocess_text_zh.py ${inference_dir}/1best_recog/text ${inference_dir}/1best_recog/text.proc |
| | | python utils/postprocess_text_zh.py ${data_dir}/text ${inference_dir}/1best_recog/text.ref |
| | | python utils/compute_wer.py ${inference_dir}/1best_recog/text.ref ${inference_dir}/1best_recog/text.proc ${inference_dir}/1best_recog/text.cer |
| | | tail -n 3 ${inference_dir}/1best_recog/text.cer |
| | | done |
| | |
| | | vad_model_revision="v2.0.4", |
| | | punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", |
| | | punc_model_revision="v2.0.4", |
| | | spk_model="damo/speech_campplus_sv_zh-cn_16k-common", |
| | | spk_model_revision="v2.0.4", |
| | | # spk_model="damo/speech_campplus_sv_zh-cn_16k-common", |
| | | # spk_model_revision="v2.0.2", |
| | | ) |
| | | |
| | | res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60) |
| | |
| | | |
| | | from funasr import AutoModel |
| | | |
| | | model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4", |
| | | # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", |
| | | # vad_model_revision="v2.0.4", |
| | | # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", |
| | | # punc_model_revision="v2.0.4", |
| | | model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", |
| | | model_revision="v2.0.4", |
| | | vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", |
| | | vad_model_revision="v2.0.4", |
| | | punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", |
| | | punc_model_revision="v2.0.4", |
| | | # spk_model="iic/speech_campplus_sv_zh-cn_16k-common", |
| | | # spk_model_revision="v2.0.2", |
| | | ) |
| | | |
| | | res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav") |
| | |
| | | vad_model_revision="v2.0.4", |
| | | punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", |
| | | punc_model_revision="v2.0.4", |
| | | spk_model="damo/speech_campplus_sv_zh-cn_16k-common", |
| | | spk_model_revision="v2.0.2", |
| | | # spk_model="damo/speech_campplus_sv_zh-cn_16k-common", |
| | | # spk_model_revision="v2.0.2", |
| | | ) |
| | | |
| | | |
| | | # example1 |
| | | res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", |
| | | hotword='达摩院 魔搭', |
| | | # preset_spk_num=2, |
| | | # return_raw_text=True, # return raw text recognition results splited by space of equal length with timestamp |
| | | # preset_spk_num=2, # preset speaker num for speaker cluster model |
| | | # sentence_timestamp=True, # return sentence level information when spk_model is not given |
| | | ) |
| | | print(res) |
| | |
| | | import json |
| | | import time |
| | | import copy |
| | | import torch |
| | | import hydra |
| | | import random |
| | | import string |
| | | import logging |
| | | import os.path |
| | | import numpy as np |
| | | from tqdm import tqdm |
| | | from omegaconf import DictConfig, OmegaConf, ListConfig |
| | | |
| | | from funasr.register import tables |
| | | from funasr.utils.load_utils import load_bytes |
| | |
| | | from funasr.utils.vad_utils import slice_padding_audio_samples |
| | | from funasr.train_utils.set_all_random_seed import set_all_random_seed |
| | | from funasr.train_utils.load_pretrained_model import load_pretrained_model |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.utils.load_utils import load_audio_text_image_video |
| | | from funasr.utils.timestamp_tools import timestamp_sentence |
| | | from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk |
| | | try: |
| | |
| | | else: |
| | | result[k] += restored_data[j][k] |
| | | |
| | | return_raw_text = kwargs.get('return_raw_text', False) |
| | | # step.3 compute punc model |
| | | if self.punc_model is not None: |
| | | self.punc_kwargs.update(cfg) |
| | | punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg) |
| | | import copy; raw_text = copy.copy(result["text"]) |
| | | raw_text = copy.copy(result["text"]) |
| | | if return_raw_text: result['raw_text'] = raw_text |
| | | result["text"] = punc_res[0]["text"] |
| | | else: |
| | | raw_text = None |
| | | |
| | | # speaker embedding cluster after resorted |
| | | if self.spk_model is not None and kwargs.get('return_spk_res', True): |
| | | if raw_text is None: |
| | | logging.error("Missing punc_model, which is required by spk_model.") |
| | | all_segments = sorted(all_segments, key=lambda x: x[0]) |
| | | spk_embedding = result['spk_embedding'] |
| | | labels = self.cb_model(spk_embedding.cpu(), oracle_num=kwargs.get('preset_spk_num', None)) |
| | |
| | | if self.spk_mode == 'vad_segment': # recover sentence_list |
| | | sentence_list = [] |
| | | for res, vadsegment in zip(restored_data, vadsegments): |
| | | sentence_list.append({"start": vadsegment[0],\ |
| | | if 'timestamp' not in res: |
| | | logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \ |
| | | and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\ |
| | | can predict timestamp, and speaker diarization relies on timestamps.") |
| | | sentence_list.append({"start": vadsegment[0], |
| | | "end": vadsegment[1], |
| | | "sentence": res['raw_text'], |
| | | "sentence": res['text'], |
| | | "timestamp": res['timestamp']}) |
| | | elif self.spk_mode == 'punc_segment': |
| | | sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \ |
| | | result['timestamp'], \ |
| | | result['raw_text']) |
| | | if 'timestamp' not in result: |
| | | logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \ |
| | | and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\ |
| | | can predict timestamp, and speaker diarization relies on timestamps.") |
| | | sentence_list = timestamp_sentence(punc_res[0]['punc_array'], |
| | | result['timestamp'], |
| | | raw_text, |
| | | return_raw_text=return_raw_text) |
| | | distribute_spk(sentence_list, sv_output) |
| | | result['sentence_info'] = sentence_list |
| | | elif kwargs.get("sentence_timestamp", False): |
| | | sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \ |
| | | result['timestamp'], \ |
| | | result['raw_text']) |
| | | sentence_list = timestamp_sentence(punc_res[0]['punc_array'], |
| | | result['timestamp'], |
| | | raw_text, |
| | | return_raw_text=return_raw_text) |
| | | result['sentence_info'] = sentence_list |
| | | del result['spk_embedding'] |
| | | if "spk_embedding" in result: del result['spk_embedding'] |
| | | |
| | | result["key"] = key |
| | | results_ret_list.append(result) |
| | |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | |
| | | import time |
| | | import torch |
| | | import logging |
| | | import torch.nn as nn |
| | | from contextlib import contextmanager |
| | | from typing import Dict, Optional, Tuple |
| | | from distutils.version import LooseVersion |
| | | |
| | | from typing import Dict, List, Optional, Tuple, Union |
| | | |
| | | |
| | | from torch.cuda.amp import autocast |
| | | from funasr.losses.label_smoothing_loss import ( |
| | | LabelSmoothingLoss, # noqa: H301 |
| | | ) |
| | | |
| | | from funasr.models.transformer.utils.nets_utils import get_transducer_task_io |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos |
| | | from funasr.register import tables |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.models.transformer.scorers.ctc import CTCPrefixScorer |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
| | | from funasr.models.transformer.scorers.length_bonus import LengthBonus |
| | | from funasr.models.transformer.utils.nets_utils import get_transducer_task_io |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer |
| | | |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | | else: |
| | | # Nothing to do if torch<1.6.0 |
| | | @contextmanager |
| | | def autocast(enabled=True): |
| | | yield |
| | | |
| | | |
| | | |
| | | class BATModel(nn.Module): |
| | | """BATModel module definition. |
| | | |
| | | Args: |
| | | vocab_size: Size of complete vocabulary (w/ EOS and blank included). |
| | | token_list: List of token |
| | | frontend: Frontend module. |
| | | specaug: SpecAugment module. |
| | | normalize: Normalization module. |
| | | encoder: Encoder module. |
| | | decoder: Decoder module. |
| | | joint_network: Joint Network module. |
| | | transducer_weight: Weight of the Transducer loss. |
| | | fastemit_lambda: FastEmit lambda value. |
| | | auxiliary_ctc_weight: Weight of auxiliary CTC loss. |
| | | auxiliary_ctc_dropout_rate: Dropout rate for auxiliary CTC loss inputs. |
| | | auxiliary_lm_loss_weight: Weight of auxiliary LM loss. |
| | | auxiliary_lm_loss_smoothing: Smoothing rate for LM loss' label smoothing. |
| | | ignore_id: Initial padding ID. |
| | | sym_space: Space symbol. |
| | | sym_blank: Blank Symbol |
| | | report_cer: Whether to report Character Error Rate during validation. |
| | | report_wer: Whether to report Word Error Rate during validation. |
| | | extract_feats_in_collect_stats: Whether to use extract_feats stats collection. |
| | | |
| | | """ |
| | | |
| | | @tables.register("model_classes", "BAT") # TODO: BAT training |
| | | class BAT(torch.nn.Module): |
| | | def __init__( |
| | | self, |
| | | |
| | | cif_weight: float = 1.0, |
| | | frontend: Optional[str] = None, |
| | | frontend_conf: Optional[Dict] = None, |
| | | specaug: Optional[str] = None, |
| | | specaug_conf: Optional[Dict] = None, |
| | | normalize: str = None, |
| | | normalize_conf: Optional[Dict] = None, |
| | | encoder: str = None, |
| | | encoder_conf: Optional[Dict] = None, |
| | | decoder: str = None, |
| | | decoder_conf: Optional[Dict] = None, |
| | | joint_network: str = None, |
| | | joint_network_conf: Optional[Dict] = None, |
| | | transducer_weight: float = 1.0, |
| | | fastemit_lambda: float = 0.0, |
| | | auxiliary_ctc_weight: float = 0.0, |
| | | auxiliary_ctc_dropout_rate: float = 0.0, |
| | | auxiliary_lm_loss_weight: float = 0.0, |
| | | auxiliary_lm_loss_smoothing: float = 0.0, |
| | | input_size: int = 80, |
| | | vocab_size: int = -1, |
| | | ignore_id: int = -1, |
| | | sym_space: str = "<space>", |
| | | sym_blank: str = "<blank>", |
| | | report_cer: bool = True, |
| | | report_wer: bool = True, |
| | | extract_feats_in_collect_stats: bool = True, |
| | | blank_id: int = 0, |
| | | sos: int = 1, |
| | | eos: int = 2, |
| | | lsm_weight: float = 0.0, |
| | | length_normalized_loss: bool = False, |
| | | r_d: int = 5, |
| | | r_u: int = 5, |
| | | # report_cer: bool = True, |
| | | # report_wer: bool = True, |
| | | # sym_space: str = "<space>", |
| | | # sym_blank: str = "<blank>", |
| | | # extract_feats_in_collect_stats: bool = True, |
| | | share_embedding: bool = False, |
| | | # preencoder: Optional[AbsPreEncoder] = None, |
| | | # postencoder: Optional[AbsPostEncoder] = None, |
| | | **kwargs, |
| | | ) -> None: |
| | | """Construct an BATModel object.""" |
| | | ): |
| | | |
| | | super().__init__() |
| | | |
| | | # The following labels ID are reserved: 0 (blank) and vocab_size - 1 (sos/eos) |
| | | self.blank_id = 0 |
| | | self.vocab_size = vocab_size |
| | | self.ignore_id = ignore_id |
| | | self.token_list = token_list.copy() |
| | | if specaug is not None: |
| | | specaug_class = tables.specaug_classes.get(specaug) |
| | | specaug = specaug_class(**specaug_conf) |
| | | if normalize is not None: |
| | | normalize_class = tables.normalize_classes.get(normalize) |
| | | normalize = normalize_class(**normalize_conf) |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(input_size=input_size, **encoder_conf) |
| | | encoder_output_size = encoder.output_size() |
| | | |
| | | self.sym_space = sym_space |
| | | self.sym_blank = sym_blank |
| | | decoder_class = tables.decoder_classes.get(decoder) |
| | | decoder = decoder_class( |
| | | vocab_size=vocab_size, |
| | | **decoder_conf, |
| | | ) |
| | | decoder_output_size = decoder.output_size |
| | | |
| | | self.frontend = frontend |
| | | self.specaug = specaug |
| | | self.normalize = normalize |
| | | |
| | | self.encoder = encoder |
| | | self.decoder = decoder |
| | | self.joint_network = joint_network |
| | | joint_network_class = tables.joint_network_classes.get(joint_network) |
| | | joint_network = joint_network_class( |
| | | vocab_size, |
| | | encoder_output_size, |
| | | decoder_output_size, |
| | | **joint_network_conf, |
| | | ) |
| | | |
| | | self.criterion_transducer = None |
| | | self.error_calculator = None |
| | |
| | | |
| | | self.auxiliary_ctc_weight = auxiliary_ctc_weight |
| | | self.auxiliary_lm_loss_weight = auxiliary_lm_loss_weight |
| | | self.blank_id = blank_id |
| | | self.sos = sos if sos is not None else vocab_size - 1 |
| | | self.eos = eos if eos is not None else vocab_size - 1 |
| | | self.vocab_size = vocab_size |
| | | self.ignore_id = ignore_id |
| | | self.frontend = frontend |
| | | self.specaug = specaug |
| | | self.normalize = normalize |
| | | self.encoder = encoder |
| | | self.decoder = decoder |
| | | self.joint_network = joint_network |
| | | |
| | | self.report_cer = report_cer |
| | | self.report_wer = report_wer |
| | | |
| | | self.extract_feats_in_collect_stats = extract_feats_in_collect_stats |
| | | |
| | | self.criterion_pre = torch.nn.L1Loss() |
| | | self.predictor_weight = predictor_weight |
| | | self.predictor = predictor |
| | | |
| | | self.cif_weight = cif_weight |
| | | if self.cif_weight > 0: |
| | | self.cif_output_layer = torch.nn.Linear(encoder.output_size(), vocab_size) |
| | | self.criterion_cif = LabelSmoothingLoss( |
| | | self.criterion_att = LabelSmoothingLoss( |
| | | size=vocab_size, |
| | | padding_idx=ignore_id, |
| | | smoothing=lsm_weight, |
| | | normalize_length=length_normalized_loss, |
| | | ) |
| | | self.r_d = r_d |
| | | self.r_u = r_u |
| | | |
| | | self.length_normalized_loss = length_normalized_loss |
| | | self.beam_search = None |
| | | self.ctc = None |
| | | self.ctc_weight = 0.0 |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | text_lengths: torch.Tensor, |
| | | **kwargs, |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | """Forward architecture and compute loss(es). |
| | | |
| | | """Encoder + Decoder + Calc loss |
| | | Args: |
| | | speech: Speech sequences. (B, S) |
| | | speech_lengths: Speech sequences lengths. (B,) |
| | | text: Label ID sequences. (B, L) |
| | | text_lengths: Label ID sequences lengths. (B,) |
| | | kwargs: Contains "utts_id". |
| | | |
| | | Return: |
| | | loss: Main loss value. |
| | | stats: Task statistics. |
| | | weight: Task weights. |
| | | |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | text: (Batch, Length) |
| | | text_lengths: (Batch,) |
| | | """ |
| | | assert text_lengths.dim() == 1, text_lengths.shape |
| | | assert ( |
| | | speech.shape[0] |
| | | == speech_lengths.shape[0] |
| | | == text.shape[0] |
| | | == text_lengths.shape[0] |
| | | ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | batch_size = speech.shape[0] |
| | | text = text[:, : text_lengths.max()] |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | if hasattr(self.encoder, 'overlap_chunk_cls') and self.encoder.overlap_chunk_cls is not None: |
| | | encoder_out, encoder_out_lens = self.encoder.overlap_chunk_cls.remove_chunk(encoder_out, encoder_out_lens, |
| | | chunk_outs=None) |
| | | |
| | | encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(encoder_out.device) |
| | | # 2. Transducer-related I/O preparation |
| | | decoder_in, target, t_len, u_len = get_transducer_task_io( |
| | | text, |
| | |
| | | self.decoder.set_device(encoder_out.device) |
| | | decoder_out = self.decoder(decoder_in, u_len) |
| | | |
| | | pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=self.ignore_id) |
| | | loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length), pre_token_length) |
| | | |
| | | if self.cif_weight > 0.0: |
| | | cif_predict = self.cif_output_layer(pre_acoustic_embeds) |
| | | loss_cif = self.criterion_cif(cif_predict, text) |
| | | else: |
| | | loss_cif = 0.0 |
| | | # 4. Joint Network |
| | | joint_out = self.joint_network( |
| | | encoder_out.unsqueeze(2), decoder_out.unsqueeze(1) |
| | | ) |
| | | |
| | | # 5. Losses |
| | | boundary = torch.zeros((encoder_out.size(0), 4), dtype=torch.int64, device=encoder_out.device) |
| | | boundary[:, 2] = u_len.long().detach() |
| | | boundary[:, 3] = t_len.long().detach() |
| | | |
| | | pre_peak_index = torch.floor(pre_peak_index).long() |
| | | s_begin = pre_peak_index - self.r_d |
| | | |
| | | T = encoder_out.size(1) |
| | | B = encoder_out.size(0) |
| | | U = decoder_out.size(1) |
| | | |
| | | mask = torch.arange(0, T, device=encoder_out.device).reshape(1, T).expand(B, T) |
| | | mask = mask <= boundary[:, 3].reshape(B, 1) - 1 |
| | | |
| | | s_begin_padding = boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1 |
| | | # handle the cases where `len(symbols) < s_range` |
| | | s_begin_padding = torch.clamp(s_begin_padding, min=0) |
| | | |
| | | s_begin = torch.where(mask, s_begin, s_begin_padding) |
| | | |
| | | mask2 = s_begin < boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1 |
| | | |
| | | s_begin = torch.where(mask2, s_begin, boundary[:, 2].reshape(B, 1) - (self.r_u+self.r_d) + 1) |
| | | |
| | | s_begin = torch.clamp(s_begin, min=0) |
| | | |
| | | ranges = s_begin.reshape((B, T, 1)).expand((B, T, min(self.r_u+self.r_d, min(u_len)))) + torch.arange(min(self.r_d+self.r_u, min(u_len)), device=encoder_out.device) |
| | | |
| | | import fast_rnnt |
| | | am_pruned, lm_pruned = fast_rnnt.do_rnnt_pruning( |
| | | am=self.joint_network.lin_enc(encoder_out), |
| | | lm=self.joint_network.lin_dec(decoder_out), |
| | | ranges=ranges, |
| | | loss_trans, cer_trans, wer_trans = self._calc_transducer_loss( |
| | | encoder_out, |
| | | joint_out, |
| | | target, |
| | | t_len, |
| | | u_len, |
| | | ) |
| | | |
| | | logits = self.joint_network(am_pruned, lm_pruned, project_input=False) |
| | | |
| | | with torch.cuda.amp.autocast(enabled=False): |
| | | loss_trans = fast_rnnt.rnnt_loss_pruned( |
| | | logits=logits.float(), |
| | | symbols=target.long(), |
| | | ranges=ranges, |
| | | termination_symbol=self.blank_id, |
| | | boundary=boundary, |
| | | reduction="sum", |
| | | ) |
| | | |
| | | cer_trans, wer_trans = None, None |
| | | if not self.training and (self.report_cer or self.report_wer): |
| | | if self.error_calculator is None: |
| | | from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator |
| | | self.error_calculator = ErrorCalculator( |
| | | self.decoder, |
| | | self.joint_network, |
| | | self.token_list, |
| | | self.sym_space, |
| | | self.sym_blank, |
| | | report_cer=self.report_cer, |
| | | report_wer=self.report_wer, |
| | | ) |
| | | cer_trans, wer_trans = self.error_calculator(encoder_out, target, t_len) |
| | | |
| | | loss_ctc, loss_lm = 0.0, 0.0 |
| | | |
| | |
| | | self.transducer_weight * loss_trans |
| | | + self.auxiliary_ctc_weight * loss_ctc |
| | | + self.auxiliary_lm_loss_weight * loss_lm |
| | | + self.predictor_weight * loss_pre |
| | | + self.cif_weight * loss_cif |
| | | ) |
| | | |
| | | stats = dict( |
| | | loss=loss.detach(), |
| | | loss_transducer=loss_trans.detach(), |
| | | loss_pre=loss_pre.detach(), |
| | | loss_cif=loss_cif.detach() if loss_cif > 0.0 else None, |
| | | aux_ctc_loss=loss_ctc.detach() if loss_ctc > 0.0 else None, |
| | | aux_lm_loss=loss_lm.detach() if loss_lm > 0.0 else None, |
| | | cer_transducer=cer_trans, |
| | |
| | | |
| | | return loss, stats, weight |
| | | |
| | | def collect_feats( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | **kwargs, |
| | | ) -> Dict[str, torch.Tensor]: |
| | | """Collect features sequences and features lengths sequences. |
| | | |
| | | Args: |
| | | speech: Speech sequences. (B, S) |
| | | speech_lengths: Speech sequences lengths. (B,) |
| | | text: Label ID sequences. (B, L) |
| | | text_lengths: Label ID sequences lengths. (B,) |
| | | kwargs: Contains "utts_id". |
| | | |
| | | Return: |
| | | {}: "feats": Features sequences. (B, T, D_feats), |
| | | "feats_lengths": Features sequences lengths. (B,) |
| | | |
| | | """ |
| | | if self.extract_feats_in_collect_stats: |
| | | feats, feats_lengths = self._extract_feats(speech, speech_lengths) |
| | | else: |
| | | # Generate dummy stats if extract_feats_in_collect_stats is False |
| | | logging.warning( |
| | | "Generating dummy stats for feats and feats_lengths, " |
| | | "because encoder_conf.extract_feats_in_collect_stats is " |
| | | f"{self.extract_feats_in_collect_stats}" |
| | | ) |
| | | |
| | | feats, feats_lengths = speech, speech_lengths |
| | | |
| | | return {"feats": feats, "feats_lengths": feats_lengths} |
| | | |
| | | def encode( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Encoder speech sequences. |
| | | |
| | | """Frontend + Encoder. Note that this method is used by asr_inference.py |
| | | Args: |
| | | speech: Speech sequences. (B, S) |
| | | speech_lengths: Speech sequences lengths. (B,) |
| | | |
| | | Return: |
| | | encoder_out: Encoder outputs. (B, T, D_enc) |
| | | encoder_out_lens: Encoder outputs lengths. (B,) |
| | | |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | ind: int |
| | | """ |
| | | with autocast(False): |
| | | # 1. Extract feats |
| | | feats, feats_lengths = self._extract_feats(speech, speech_lengths) |
| | | |
| | | # 2. Data augmentation |
| | | # Data augmentation |
| | | if self.specaug is not None and self.training: |
| | | feats, feats_lengths = self.specaug(feats, feats_lengths) |
| | | speech, speech_lengths = self.specaug(speech, speech_lengths) |
| | | |
| | | # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN |
| | | # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN |
| | | if self.normalize is not None: |
| | | feats, feats_lengths = self.normalize(feats, feats_lengths) |
| | | speech, speech_lengths = self.normalize(speech, speech_lengths) |
| | | |
| | | # 4. Forward encoder |
| | | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) |
| | | # Forward encoder |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | | encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | |
| | | assert encoder_out.size(0) == speech.size(0), ( |
| | | encoder_out.size(), |
| | | speech.size(0), |
| | | ) |
| | | assert encoder_out.size(1) <= encoder_out_lens.max(), ( |
| | | encoder_out.size(), |
| | | encoder_out_lens.max(), |
| | | ) |
| | | if intermediate_outs is not None: |
| | | return (encoder_out, intermediate_outs), encoder_out_lens |
| | | |
| | | return encoder_out, encoder_out_lens |
| | | |
| | | def _extract_feats( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Extract features sequences and features sequences lengths. |
| | | def _calc_transducer_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | joint_out: torch.Tensor, |
| | | target: torch.Tensor, |
| | | t_len: torch.Tensor, |
| | | u_len: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, Optional[float], Optional[float]]: |
| | | """Compute Transducer loss. |
| | | |
| | | Args: |
| | | speech: Speech sequences. (B, S) |
| | | speech_lengths: Speech sequences lengths. (B,) |
| | | encoder_out: Encoder output sequences. (B, T, D_enc) |
| | | joint_out: Joint Network output sequences (B, T, U, D_joint) |
| | | target: Target label ID sequences. (B, L) |
| | | t_len: Encoder output sequences lengths. (B,) |
| | | u_len: Target label ID sequences lengths. (B,) |
| | | |
| | | Return: |
| | | feats: Features sequences. (B, T, D_feats) |
| | | feats_lengths: Features sequences lengths. (B,) |
| | | loss_transducer: Transducer loss value. |
| | | cer_transducer: Character error rate for Transducer. |
| | | wer_transducer: Word Error Rate for Transducer. |
| | | |
| | | """ |
| | | assert speech_lengths.dim() == 1, speech_lengths.shape |
| | | if self.criterion_transducer is None: |
| | | try: |
| | | from warp_rnnt import rnnt_loss as RNNTLoss |
| | | self.criterion_transducer = RNNTLoss |
| | | |
| | | # for data-parallel |
| | | speech = speech[:, : speech_lengths.max()] |
| | | except ImportError: |
| | | logging.error( |
| | | "warp-rnnt was not installed." |
| | | "Please consult the installation documentation." |
| | | ) |
| | | exit(1) |
| | | |
| | | if self.frontend is not None: |
| | | feats, feats_lengths = self.frontend(speech, speech_lengths) |
| | | else: |
| | | feats, feats_lengths = speech, speech_lengths |
| | | log_probs = torch.log_softmax(joint_out, dim=-1) |
| | | |
| | | return feats, feats_lengths |
| | | loss_transducer = self.criterion_transducer( |
| | | log_probs, |
| | | target, |
| | | t_len, |
| | | u_len, |
| | | reduction="mean", |
| | | blank=self.blank_id, |
| | | fastemit_lambda=self.fastemit_lambda, |
| | | gather=True, |
| | | ) |
| | | |
| | | if not self.training and (self.report_cer or self.report_wer): |
| | | if self.error_calculator is None: |
| | | from funasr.metrics import ErrorCalculatorTransducer as ErrorCalculator |
| | | |
| | | self.error_calculator = ErrorCalculator( |
| | | self.decoder, |
| | | self.joint_network, |
| | | self.token_list, |
| | | self.sym_space, |
| | | self.sym_blank, |
| | | report_cer=self.report_cer, |
| | | report_wer=self.report_wer, |
| | | ) |
| | | |
| | | cer_transducer, wer_transducer = self.error_calculator(encoder_out, target, t_len) |
| | | |
| | | return loss_transducer, cer_transducer, wer_transducer |
| | | |
| | | return loss_transducer, None, None |
| | | |
| | | def _calc_ctc_loss( |
| | | self, |
| | |
| | | ) |
| | | |
| | | return loss_lm |
| | | |
| | | def init_beam_search(self, |
| | | **kwargs, |
| | | ): |
| | | |
| | | # 1. Build ASR model |
| | | scorers = {} |
| | | |
| | | if self.ctc != None: |
| | | ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos) |
| | | scorers.update( |
| | | ctc=ctc |
| | | ) |
| | | token_list = kwargs.get("token_list") |
| | | scorers.update( |
| | | length_bonus=LengthBonus(len(token_list)), |
| | | ) |
| | | |
| | | # 3. Build ngram model |
| | | # ngram is not supported now |
| | | ngram = None |
| | | scorers["ngram"] = ngram |
| | | |
| | | beam_search = BeamSearchTransducer( |
| | | self.decoder, |
| | | self.joint_network, |
| | | kwargs.get("beam_size", 2), |
| | | nbest=1, |
| | | ) |
| | | # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() |
| | | # for scorer in scorers.values(): |
| | | # if isinstance(scorer, torch.nn.Module): |
| | | # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() |
| | | self.beam_search = beam_search |
| | | |
| | | def inference(self, |
| | | data_in: list, |
| | | data_lengths: list=None, |
| | | key: list=None, |
| | | tokenizer=None, |
| | | **kwargs, |
| | | ): |
| | | |
| | | if kwargs.get("batch_size", 1) > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | |
| | | # init beamsearch |
| | | is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None |
| | | is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None |
| | | # if self.beam_search is None and (is_use_lm or is_use_ctc): |
| | | logging.info("enable beam_search") |
| | | self.init_beam_search(**kwargs) |
| | | self.nbest = kwargs.get("nbest", 1) |
| | | |
| | | meta_data = {} |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=self.frontend.fs, audio_fs=kwargs.get("fs", 16000)) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=self.frontend) |
| | | time3 = time.perf_counter() |
| | | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
| | | meta_data["batch_data_time"] = speech_lengths.sum().item() * self.frontend.frame_shift * self.frontend.lfr_n / 1000 |
| | | |
| | | speech = speech.to(device=kwargs["device"]) |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | |
| | | # Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | if isinstance(encoder_out, tuple): |
| | | encoder_out = encoder_out[0] |
| | | |
| | | # c. Passed the encoder result and the beam search |
| | | nbest_hyps = self.beam_search(encoder_out[0], is_final=True) |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | |
| | | results = [] |
| | | b, n, d = encoder_out.size() |
| | | for i in range(b): |
| | | |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq#[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq#[1:last_pos].tolist() |
| | | |
| | | # remove blank symbol id, which is assumed to be 0 |
| | | token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) |
| | | |
| | | # Change integer-ids to tokens |
| | | token = tokenizer.ids2tokens(token_int) |
| | | text = tokenizer.tokens2text(token) |
| | | |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed} |
| | | results.append(result_i) |
| | | |
| | | if ibest_writer is not None: |
| | | ibest_writer["token"][key[i]] = " ".join(token) |
| | | ibest_writer["text"][key[i]] = text |
| | | ibest_writer["text_postprocessed"][key[i]] = text_postprocessed |
| | | |
| | | return results, meta_data |
| | | |
| | |
| | | nbest_hyps = [Hypothesis(yseq=yseq, score=score)] |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if ibest_writer is None and kwargs.get("output_dir") is not None: |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer[f"{nbest_idx + 1}best_recog"] |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx+1}best_recog"] |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | |
| | | nbest_hyps = [Hypothesis(yseq=yseq, score=score)] |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if ibest_writer is None and kwargs.get("output_dir") is not None: |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer[f"{nbest_idx + 1}best_recog"] |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | |
| | | self.init_cache(cache) |
| | | |
| | | ibest_writer = None |
| | | if ibest_writer is None and kwargs.get("output_dir") is not None: |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer[f"{1}best_recog"] |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{1}best_recog"] |
| | | |
| | | results = [] |
| | | result_i = {"key": key[0], "value": segments} |
| | |
| | | |
| | | results = [] |
| | | ibest_writer = None |
| | | if ibest_writer is None and kwargs.get("output_dir") is not None: |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer["tp_res"] |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer["tp_res"] |
| | | |
| | | for i, (us_alpha, us_peak, token_int) in enumerate(zip(us_alphas, us_peaks, text_token_int_list)): |
| | | token = tokenizer.ids2tokens(token_int) |
| | | timestamp_str, timestamp = ts_prediction_lfr6_standard(us_alpha[:encoder_out_lens[i] * 3], |
| | |
| | | nbest_hyps = [Hypothesis(yseq=yseq, score=score)] |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if ibest_writer is None and kwargs.get("output_dir") is not None: |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer[f"{nbest_idx+1}best_recog"] |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx+1}best_recog"] |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | |
| | | result_i = {"key": key[i], "text": text_postprocessed} |
| | | |
| | | |
| | | if ibest_writer is not None: |
| | | ibest_writer["token"][key[i]] = " ".join(token) |
| | |
| | | self.init_cache(cache, **kwargs) |
| | | |
| | | if kwargs.get("output_dir"): |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer[f"{1}best_recog"] |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{1}best_recog"] |
| | | ibest_writer["token"][key[0]] = " ".join(tokens) |
| | | ibest_writer["text"][key[0]] = text_postprocessed |
| | | |
| | |
| | | nbest_hyps = [Hypothesis(yseq=yseq, score=score)] |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if ibest_writer is None and kwargs.get("output_dir") is not None: |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer[f"{nbest_idx + 1}best_recog"] |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | |
| | | token, timestamp) |
| | | |
| | | result_i = {"key": key[i], "text": text_postprocessed, |
| | | "timestamp": time_stamp_postprocessed, "raw_text": copy.copy(text_postprocessed) |
| | | "timestamp": time_stamp_postprocessed |
| | | } |
| | | |
| | | if ibest_writer is not None: |
| | | ibest_writer["token"][key[i]] = " ".join(token) |
| | | # ibest_writer["raw_text"][key[i]] = text |
| | | ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed |
| | | ibest_writer["text"][key[i]] = text_postprocessed |
| | | else: |
| | |
| | | |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if ibest_writer is None and kwargs.get("output_dir") is not None: |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer[f"{nbest_idx+1}best_recog"] |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | |
| | | |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if ibest_writer is None and kwargs.get("output_dir") is not None: |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer[f"{nbest_idx+1}best_recog"] |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | |
| | | |
| | | return assignment_map |
| | | |
| | | |
| | | def load_pretrained_model( |
| | | path: str, |
| | | model: torch.nn.Module, |
| | |
| | | """ |
| | | |
| | | obj = model |
| | | |
| | | dst_state = obj.state_dict() |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | print(f"ckpt: {path}") |
| | | if oss_bucket is None: |
| | | src_state = torch.load(path, map_location=map_location) |
| | | else: |
| | | buffer = BytesIO(oss_bucket.get_object(path).read()) |
| | | src_state = torch.load(buffer, map_location=map_location) |
| | | src_state = src_state["model"] if "model" in src_state else src_state |
| | | if "state_dict" in src_state: |
| | | src_state = src_state["state_dict"] |
| | | |
| | | if excludes is not None: |
| | | for e in excludes.split(","): |
| | | src_state = {k: v for k, v in src_state.items() if not k.startswith(e)} |
| | | for k in dst_state.keys(): |
| | | if not k.startswith("module.") and "module." + k in src_state.keys(): |
| | | k_ddp = "module." + k |
| | | else: |
| | | k_ddp = k |
| | | if k_ddp in src_state: |
| | | dst_state[k] = src_state[k_ddp] |
| | | else: |
| | | print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}") |
| | | |
| | | dst_state = obj.state_dict() |
| | | src_state = assigment_scope_map(dst_state, src_state, scope_map) |
| | | flag = obj.load_state_dict(dst_state, strict=True) |
| | | # print(flag) |
| | | |
| | | if ignore_init_mismatch: |
| | | src_state = filter_state_dict(dst_state, src_state) |
| | | |
| | | logging.debug("Loaded src_state keys: {}".format(src_state.keys())) |
| | | logging.debug("Loaded dst_state keys: {}".format(dst_state.keys())) |
| | | dst_state.update(src_state) |
| | | obj.load_state_dict(dst_state, strict=True) |
| | | # def load_pretrained_model( |
| | | # path: str, |
| | | # model: torch.nn.Module, |
| | | # ignore_init_mismatch: bool, |
| | | # map_location: str = "cpu", |
| | | # oss_bucket=None, |
| | | # scope_map=None, |
| | | # excludes=None, |
| | | # ): |
| | | # """Load a model state and set it to the model. |
| | | # |
| | | # Args: |
| | | # init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys> |
| | | # |
| | | # Examples: |
| | | # |
| | | # """ |
| | | # |
| | | # obj = model |
| | | # |
| | | # if oss_bucket is None: |
| | | # src_state = torch.load(path, map_location=map_location) |
| | | # else: |
| | | # buffer = BytesIO(oss_bucket.get_object(path).read()) |
| | | # src_state = torch.load(buffer, map_location=map_location) |
| | | # src_state = src_state["model"] if "model" in src_state else src_state |
| | | # |
| | | # if excludes is not None: |
| | | # for e in excludes.split(","): |
| | | # src_state = {k: v for k, v in src_state.items() if not k.startswith(e)} |
| | | # |
| | | # dst_state = obj.state_dict() |
| | | # src_state = assigment_scope_map(dst_state, src_state, scope_map) |
| | | # |
| | | # if ignore_init_mismatch: |
| | | # src_state = filter_state_dict(dst_state, src_state) |
| | | # |
| | | # logging.debug("Loaded src_state keys: {}".format(src_state.keys())) |
| | | # logging.debug("Loaded dst_state keys: {}".format(dst_state.keys())) |
| | | # dst_state.update(src_state) |
| | | # obj.load_state_dict(dst_state, strict=True) |
| | |
| | | import torch |
| | | import logging |
| | | from tqdm import tqdm |
| | | from datetime import datetime |
| | | import torch.distributed as dist |
| | | from contextlib import nullcontext |
| | | # from torch.utils.tensorboard import SummaryWriter |
| | |
| | | filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}') |
| | | torch.save(state, filename) |
| | | |
| | | print(f'Checkpoint saved to {filename}') |
| | | print(f'\nCheckpoint saved to {filename}\n') |
| | | latest = Path(os.path.join(self.output_dir, f'model.pt')) |
| | | try: |
| | | latest.unlink() |
| | | except: |
| | | pass |
| | | torch.save(state, latest) |
| | | |
| | | latest.symlink_to(filename) |
| | | |
| | | def _resume_checkpoint(self, resume_path): |
| | | """ |
| | |
| | | if os.path.isfile(ckpt): |
| | | checkpoint = torch.load(ckpt) |
| | | self.start_epoch = checkpoint['epoch'] + 1 |
| | | self.model.load_state_dict(checkpoint['state_dict']) |
| | | # self.model.load_state_dict(checkpoint['state_dict']) |
| | | src_state = checkpoint['state_dict'] |
| | | dst_state = self.model.state_dict() |
| | | for k in dst_state.keys(): |
| | | if not k.startswith("module.") and "module."+k in src_state.keys(): |
| | | k_ddp = "module."+k |
| | | else: |
| | | k_ddp = k |
| | | if k_ddp in src_state.keys(): |
| | | dst_state[k] = src_state[k_ddp] |
| | | else: |
| | | print(f"Miss key in ckpt: model: {k}, ckpt: {k_ddp}") |
| | | |
| | | self.model.load_state_dict(dst_state) |
| | | self.optim.load_state_dict(checkpoint['optimizer']) |
| | | self.scheduler.load_state_dict(checkpoint['scheduler']) |
| | | print(f"Checkpoint loaded successfully from '{ckpt}'") |
| | |
| | | self._resume_checkpoint(self.output_dir) |
| | | |
| | | for epoch in range(self.start_epoch, self.max_epoch + 1): |
| | | |
| | | time1 = time.perf_counter() |
| | | self._train_epoch(epoch) |
| | | |
| | | |
| | |
| | | |
| | | self.scheduler.step() |
| | | |
| | | time2 = time.perf_counter() |
| | | time_escaped = (time2 - time1)/3600.0 |
| | | print(f"\ntime_escaped_epoch: {time_escaped:.3f} hours, estimated to finish {self.max_epoch} epoch: {(self.max_epoch-epoch)*time_escaped:.3f}\n") |
| | | |
| | | if self.rank == 0: |
| | | average_checkpoints(self.output_dir, self.avg_nbest_model) |
| | |
| | | torch.cuda.memory_reserved()/1024/1024/1024, |
| | | torch.cuda.max_memory_reserved()/1024/1024/1024, |
| | | ) |
| | | lr = self.scheduler.get_last_lr()[0] |
| | | time_now = datetime.now() |
| | | time_now = time_now.strftime("%Y-%m-%d %H:%M:%S") |
| | | description = ( |
| | | f"{time_now}, " |
| | | f"rank: {self.local_rank}, " |
| | | f"epoch: {epoch}/{self.max_epoch}, " |
| | | f"step: {batch_idx}/{len(self.dataloader_train)}, total: {self.batch_total}, " |
| | | f"step: {batch_idx+1}/{len(self.dataloader_train)}, total: {self.batch_total}, " |
| | | f"(loss: {loss.detach().cpu().item():.3f}), " |
| | | f"(lr: {lr:.3e}), " |
| | | f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, " |
| | | f"{speed_stats}, " |
| | | f"{gpu_info}" |
| | |
| | | description = ( |
| | | f"rank: {self.local_rank}, " |
| | | f"validation epoch: {epoch}/{self.max_epoch}, " |
| | | f"step: {batch_idx}/{len(self.dataloader_val)}, " |
| | | f"step: {batch_idx+1}/{len(self.dataloader_val)}, " |
| | | f"(loss: {loss.detach().cpu().item():.3f}), " |
| | | f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}, " |
| | | f"{speed_stats}, " |
| | |
| | | return res_txt, res |
| | | |
| | | |
| | | def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed): |
| | | def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False): |
| | | punc_list = [',', '。', '?', '、'] |
| | | res = [] |
| | | if text_postprocessed is None: |
| | |
| | | |
| | | punc_id = int(punc_id) if punc_id is not None else 1 |
| | | sentence_end = timestamp[1] if timestamp is not None else sentence_end |
| | | |
| | | sentence_text_seg = sentence_text_seg[:-1] if sentence_text_seg[-1] == ' ' else sentence_text_seg |
| | | if punc_id > 1: |
| | | sentence_text += punc_list[punc_id - 2] |
| | | if return_raw_text: |
| | | res.append({ |
| | | 'text': sentence_text, |
| | | "start": sentence_start, |
| | | "end": sentence_end, |
| | | "timestamp": ts_list |
| | | "timestamp": ts_list, |
| | | 'raw_text': sentence_text_seg, |
| | | }) |
| | | else: |
| | | res.append({ |
| | | 'text': sentence_text, |
| | | "start": sentence_start, |
| | | "end": sentence_end, |
| | | "timestamp": ts_list, |
| | | }) |
| | | sentence_text = '' |
| | | sentence_text_seg = '' |