From 2196844d1d6e5b8732c95896bb46f0eacdd9cf9d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 25 Sep 2024 15:10:50 +0800
Subject: [PATCH] Dev kws (#2105)
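
Add keyword spotting (KWS) recipes and models: fsmn_kws and fsmn_kws_mt
(FSMN CTC-based KWS, single-task and a two-output multi-task variant),
sanm_kws (offline SANM KWS) and sanm_kws_streaming (chunk-based streaming
SANM KWS), plus a KWS dataset loader, keyword decoding utilities
(kws_utils.py), DET scoring (compute_det_ctc.py), and kaldi/pytorch
network conversion scripts.

A minimal inference sketch, following the bundled infer.sh (the model id,
test wav and keyword below are the ones used throughout this patch):

    model="iic/speech_charctc_kws_phone-xiaoyun"
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
    python funasr/bin/inference.py \
        +model=${model} \
        +input=${input} \
        +output_dir="./outputs/debug" \
        +device="cpu" \
        ++keywords="小云小云"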
---
examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_250_128_fdim80_t2599.yaml | 95 +
examples/industrial_data_pretraining/fsmn_kws/convert.sh | 26
funasr/models/fsmn_kws_mt/model.py | 353 +++
examples/industrial_data_pretraining/fsmn_kws_mt/path.sh | 5
funasr/models/sanm_kws/model.py | 266 ++
funasr/datasets/audio_datasets/scp2jsonl.py | 15
funasr/models/sanm_kws_streaming/model.py | 442 ++++
examples/industrial_data_pretraining/sanm_kws_streaming/funasr | 1
funasr/models/sanm_kws/export_meta.py | 98 +
examples/industrial_data_pretraining/sanm_kws_streaming/export.sh | 17
examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh | 44
examples/industrial_data_pretraining/fsmn_kws/finetune.sh | 173 +
funasr/models/fsmn_kws_mt/encoder.py | 213 ++
funasr/models/fsmn_kws/__init__.py | 0
funasr/tokenizer/char_tokenizer.py | 2
examples/industrial_data_pretraining/fsmn_kws/path.sh | 5
funasr/train_utils/average_nbest_models.py | 5
examples/industrial_data_pretraining/sanm_kws/export.sh | 17
examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh | 186 ++
funasr/models/fsmn_vad_streaming/encoder.py | 136 -
funasr/utils/type_utils.py | 0
examples/industrial_data_pretraining/fsmn_kws_mt/convert.py | 141 +
examples/industrial_data_pretraining/sanm_kws/funasr | 1
funasr/models/sanm_kws/__init__.py | 0
examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml | 103 +
funasr/utils/compute_det_ctc.py | 286 +++
funasr/models/fsmn_kws_mt/__init__.py | 0
examples/industrial_data_pretraining/fsmn_kws/convert.py | 134 +
examples/industrial_data_pretraining/sanm_kws_streaming/finetune.sh | 258 ++
examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh | 20
funasr/train_utils/trainer.py | 93
examples/industrial_data_pretraining/sanm_kws/infer.sh | 20
funasr/models/transformer/scorers/ctc.py | 1
examples/industrial_data_pretraining/fsmn_kws_mt/funasr | 1
funasr/download/download_model_from_hub.py | 37
examples/industrial_data_pretraining/fsmn_kws/funasr | 1
funasr/datasets/kws_datasets/datasets.py | 132 +
funasr/models/fsmn_kws/model.py | 285 +++
examples/industrial_data_pretraining/sanm_kws/finetune.sh | 172 +
funasr/models/sanm_kws_streaming/export_meta.py | 98 +
funasr/bin/train.py | 3
examples/industrial_data_pretraining/sanm_kws/path.sh | 5
funasr/utils/kws_utils.py | 284 +++
examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_280_200_fdim40_t2602.yaml | 95 +
examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml | 103 +
funasr/datasets/audio_datasets/datasets.py | 33
examples/industrial_data_pretraining/sanm_kws_streaming/path.sh | 5
funasr/datasets/kws_datasets/__init__.py | 0
examples/industrial_data_pretraining/fsmn_kws/infer.sh | 20
examples/industrial_data_pretraining/sanm_kws_streaming/conf/sanm_6e_320_256_fdim40_t2602.yaml | 109 +
funasr/models/fsmn_kws/encoder.py | 534 +++++
examples/industrial_data_pretraining/sanm_kws/infer_from_local.sh | 41
funasr/auto/auto_model.py | 69
funasr/models/sanm_kws_streaming/__init__.py | 0
examples/industrial_data_pretraining/sanm_kws/conf/sanm_6e_320_256_fdim40_t2602.yaml | 94 +
examples/industrial_data_pretraining/sanm_kws_streaming/infer.sh | 34
funasr/train_utils/load_pretrained_model.py | 7
examples/industrial_data_pretraining/fsmn_kws/infer_from_local.sh | 41
examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh | 36
funasr/models/ctc/ctc.py | 29
examples/industrial_data_pretraining/sanm_kws_streaming/infer_from_local.sh | 62
funasr/models/sanm/encoder.py | 8
62 files changed, 5302 insertions(+), 192 deletions(-)
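
Each example directory follows the same recipe layout. For fsmn_kws the
intended flow (a sketch assembled from the finetune.sh and convert.sh added
below; run from the example directory) is roughly:

    # finetune.sh: optionally build jsonl lists (stage 1), finetune from the
    # released basetrain checkpoint (stage 2), then run keyword detection
    # and DET scoring (stage 3)
    bash finetune.sh
    # convert the averaged checkpoint back to a kaldi-format net; nnet-copy
    # comes from the runtime/ dir of the downloaded modelscope model
    python convert.py --config ./conf/fsmn_4e_l10r2_250_128_fdim80_t2599.yaml \
        --network_file exp/finetune_outputs/model.pt.avg10 \
        --model_dir exp/finetune_outputs \
        --model_name "convert.kaldi.txt" \
        --convert_to kaldi
    nnet-copy --binary=true exp/finetune_outputs/convert.kaldi.txt exp/finetune_outputs/convert.kaldi.net
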
diff --git a/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_250_128_fdim80_t2599.yaml b/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_250_128_fdim80_t2599.yaml
new file mode 100644
index 0000000..666da0c
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_250_128_fdim80_t2599.yaml
@@ -0,0 +1,95 @@
+
+# network architecture
+model: FsmnKWS
+model_conf:
+ ctc_weight: 1.0
+
+# encoder related
+encoder: FSMN
+encoder_conf:
+ input_dim: 400
+ input_affine_dim: 140
+ fsmn_layers: 4
+ linear_dim: 250
+ proj_dim: 128
+ lorder: 10
+ rorder: 2
+ lstride: 1
+ rstride: 1
+ output_affine_dim: 140
+ output_dim: 2599
+ use_softmax: false
+
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 5
+ lfr_n: 3
+
+specaug: SpecAugLFR
+specaug_conf:
+ apply_time_warp: false
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ lfr_rate: 3
+ num_freq_mask: 1
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 12
+ num_time_mask: 1
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 100
+ keep_nbest_models: 10
+ avg_nbest_model: 10
+ avg_keep_nbest_models_type: loss
+ validate_interval: 50000
+ save_checkpoint_interval: 50000
+ avg_checkpoint_interval: 1000
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 10000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: EspnetStyleBatchSampler
+ batch_type: length # example or length
+ batch_size: 32000 # if batch_type is example, batch_size is the number of samples; if length, it is source_token_len+target_token_len
+ max_token_length: 1600 # filter out samples whose source_token_len+target_token_len > max_token_length
+ buffer_size: 2048
+ shuffle: true
+ num_workers: 8
+ preprocessor_speech: SpeechPreprocessSpeedPerturb
+ preprocessor_speech_conf:
+ speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+ split_with_space: true
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+ extra_linear: false
+
+normalize: null
diff --git a/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_280_200_fdim40_t2602.yaml b/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_280_200_fdim40_t2602.yaml
new file mode 100644
index 0000000..6ad59de
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_280_200_fdim40_t2602.yaml
@@ -0,0 +1,95 @@
+
+# network architecture
+model: FsmnKWS
+model_conf:
+ ctc_weight: 1.0
+
+# encoder related
+encoder: FSMN
+encoder_conf:
+ input_dim: 360
+ input_affine_dim: 280
+ fsmn_layers: 4
+ linear_dim: 280
+ proj_dim: 200
+ lorder: 10
+ rorder: 2
+ lstride: 1
+ rstride: 1
+ output_affine_dim: 400
+ output_dim: 2602
+ use_softmax: false
+
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 40
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 9
+ lfr_n: 3
+
+specaug: SpecAugLFR
+specaug_conf:
+ apply_time_warp: false
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ lfr_rate: 3
+ num_freq_mask: 1
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 12
+ num_time_mask: 1
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 100
+ keep_nbest_models: 10
+ avg_nbest_model: 10
+ avg_keep_nbest_models_type: loss
+ validate_interval: 50000
+ save_checkpoint_interval: 50000
+ avg_checkpoint_interval: 1000
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 10000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: EspnetStyleBatchSampler
+ batch_type: length # example or length
+ batch_size: 32000 # if batch_type is example, batch_size is the number of samples; if length, it is source_token_len+target_token_len
+ max_token_length: 1600 # filter out samples whose source_token_len+target_token_len > max_token_length
+ buffer_size: 2048
+ shuffle: true
+ num_workers: 8
+ preprocessor_speech: SpeechPreprocessSpeedPerturb
+ preprocessor_speech_conf:
+ speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+ split_with_space: true
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+ extra_linear: false
+
+normalize: null
diff --git a/examples/industrial_data_pretraining/fsmn_kws/convert.py b/examples/industrial_data_pretraining/fsmn_kws/convert.py
new file mode 100644
index 0000000..1609ef4
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/convert.py
@@ -0,0 +1,134 @@
+from __future__ import print_function
+
+import argparse
+import copy
+import logging
+import os
+from shutil import copyfile
+
+import torch
+import yaml
+from typing import Union
+
+
+from funasr.models.fsmn_kws.model import FsmnKWSConvert
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ description=
+ 'load a network and convert it between kaldi and pytorch formats')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument(
+ '--network_file',
+ default='',
+ required=True,
+ help='input network, support kaldi.txt/pytorch.pt')
+ parser.add_argument('--model_dir', required=True, help='save model dir')
+ parser.add_argument('--model_name', required=True, help='save model name')
+ parser.add_argument('--convert_to',
+ default='kaldi',
+ required=True,
+ help='target network type, kaldi/pytorch')
+
+ args = parser.parse_args()
+ return args
+
+
+def convert_to_kaldi(
+ configs,
+ network_file,
+ model_dir,
+ model_name="convert.kaldi.txt"
+):
+ copyfile(network_file, os.path.join(model_dir, 'origin.torch.pt'))
+
+ model = FsmnKWSConvert(
+ vocab_size=configs['encoder_conf']['output_dim'],
+ encoder='FSMNConvert',
+ encoder_conf=configs['encoder_conf'],
+ ctc_conf=configs['ctc_conf'],
+ )
+ print(model)
+ num_params = count_parameters(model)
+ print('the number of model params: {}'.format(num_params))
+
+ states= torch.load(network_file, map_location='cpu')
+ model.load_state_dict(states["state_dict"])
+
+ kaldi_text = os.path.join(model_dir, model_name)
+ with open(kaldi_text, 'w', encoding='utf8') as fout:
+ nnet_desp = model.to_kaldi_net()
+ fout.write(nnet_desp)
+ fout.close()
+
+
+def convert_to_pytorch(
+ configs,
+ network_file,
+ model_dir,
+ model_name="convert.torch.pt"
+):
+ model = FsmnKWSConvert(
+ vocab_size=configs['encoder_conf']['output_dim'],
+ frontend=None,
+ specaug=None,
+ normalize=None,
+ encoder='FSMNConvert',
+ encoder_conf=configs['encoder_conf'],
+ ctc_conf=configs['ctc_conf'],
+ )
+
+ num_params = count_parameters(model)
+ print('the number of model params: {}'.format(num_params))
+
+ copyfile(network_file, os.path.join(model_dir, 'origin.kaldi.txt'))
+ model.to_pytorch_net(network_file)
+
+ save_model_path = os.path.join(model_dir, model_name)
+ torch.save({"model": model.state_dict()}, save_model_path)
+
+ print('convert torch format back to kaldi')
+ kaldi_text = os.path.join(model_dir, 'convert.kaldi.txt')
+ with open(kaldi_text, 'w', encoding='utf8') as fout:
+ nnet_desp = model.to_kaldi_net()
+ fout.write(nnet_desp)
+ fout.close()
+
+ print('Done!')
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+ print(args)
+ with open(args.config, 'r') as fin:
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
+
+ if args.convert_to == 'pytorch':
+ print('convert kaldi net to pytorch...')
+ convert_to_pytorch(
+ configs,
+ args.network_file,
+ args.model_dir,
+ args.model_name
+ )
+ elif args.convert_to == 'kaldi':
+ print('convert pytorch net to kaldi...')
+ convert_to_kaldi(
+ configs,
+ args.network_file,
+ args.model_dir,
+ args.model_name
+ )
+ else:
+ print('unsupported target network type: {}'.format(args.convert_to))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/industrial_data_pretraining/fsmn_kws/convert.sh b/examples/industrial_data_pretraining/fsmn_kws/convert.sh
new file mode 100644
index 0000000..1ec12f0
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/convert.sh
@@ -0,0 +1,26 @@
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models_kws
+mkdir -p ${local_path_root}
+
+local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
+if [ ! -d "$local_path" ]; then
+ git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
+fi
+
+export PATH=${local_path}/runtime:$PATH
+export LD_LIBRARY_PATH=${local_path}/runtime:$LD_LIBRARY_PATH
+
+config=./conf/fsmn_4e_l10r2_250_128_fdim80_t2599.yaml
+torch_nnet=exp/finetune_outputs/model.pt.avg10
+out_dir=exp/finetune_outputs
+
+if [ ! -d "$out_dir" ]; then
+ mkdir -p $out_dir
+fi
+
+python convert.py --config $config --network_file $torch_nnet --model_dir $out_dir --model_name "convert.kaldi.txt" --convert_to kaldi
+
+nnet-copy --binary=true ${out_dir}/convert.kaldi.txt ${out_dir}/convert.kaldi.net
diff --git a/examples/industrial_data_pretraining/fsmn_kws/finetune.sh b/examples/industrial_data_pretraining/fsmn_kws/finetune.sh
new file mode 100755
index 0000000..59a2922
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/finetune.sh
@@ -0,0 +1,173 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+#!/usr/bin/env bash
+
+# Set bash to 'debug' mode; it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands'
+set -e
+set -u
+set -o pipefail
+
+. ./path.sh
+workspace=`pwd`
+
+CUDA_VISIBLE_DEVICES="0,1"
+
+stage=2
+stop_stage=3
+
+inference_device="cuda" #"cpu"
+inference_checkpoint="model.pt.avg10"
+inference_scp="wav.scp"
+inference_batch_size=32
+nj=32
+test_sets="test"
+
+# model_name from model_hub, or model_dir in local path
+
+## option 1, download the model automatically (currently unsupported)
+model_name_or_model_dir="iic/speech_charctc_kws_phone-xiaoyun"
+
+## option 2, download model by git
+local_path_root=${workspace}/modelscope_models
+model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
+if [ ! -d $model_name_or_model_dir ]; then
+ mkdir -p ${model_name_or_model_dir}
+ git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${model_name_or_model_dir}
+fi
+
+config=fsmn_4e_l10r2_250_128_fdim80_t2599.yaml
+token_list=${model_name_or_model_dir}/funasr/tokens_2599.txt
+lexicon_list=${model_name_or_model_dir}/funasr/lexicon.txt
+cmvn_file=${model_name_or_model_dir}/funasr/am.mvn.dim80_l2r2
+init_param="${model_name_or_model_dir}/funasr/basetrain_fsmn_4e_l10r2_250_128_fdim80_t2599.pt"
+
+
+# data prepare
+# data dir, which contains: train.jsonl, val.jsonl
+data_dir=../../data
+
+train_data="${data_dir}/train.jsonl"
+val_data="${data_dir}/val.jsonl"
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Generate audio json list"
+ # generate train.jsonl and val.jsonl from wav.scp and text.txt
+ python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+ ++scp_file_list='['''${data_dir}/train_wav.scp''', '''${data_dir}/train_text.txt''']' \
+ ++data_type_list='["source", "target"]' \
+ ++jsonl_file_out="${train_data}"
+
+ python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+ ++scp_file_list='['''${data_dir}/val_wav.scp''', '''${data_dir}/val_text.txt''']' \
+ ++data_type_list='["source", "target"]' \
+ ++jsonl_file_out="${val_data}"
+fi
+
+# exp output dir
+output_dir="${workspace}/exp/finetune_outputs"
+
+# Training Stage
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: KWS Training"
+
+ mkdir -p ${output_dir}
+ current_time=$(date "+%Y-%m-%d_%H-%M")
+ log_file="${output_dir}/train.log.txt.${current_time}"
+ echo "log_file: ${log_file}"
+ echo "finetune use basetrain model: ${init_param}"
+
+ export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
+ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+ torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+ ../../../funasr/bin/train.py \
+ --config-path "${workspace}/conf" \
+ --config-name "${config}" \
+ ++init_param="${init_param}" \
+ ++disable_update=true \
+ ++train_data_set_list="${train_data}" \
+ ++valid_data_set_list="${val_data}" \
+ ++tokenizer_conf.token_list="${token_list}" \
+ ++tokenizer_conf.seg_dict="${lexicon_list}" \
+ ++frontend_conf.cmvn_file="${cmvn_file}" \
+ ++output_dir="${output_dir}" &> ${log_file}
+fi
+
+
+# Testing Stage
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: Inference"
+ keywords=(小云小云)
+ keywords_string=$(IFS=,; echo "${keywords[*]}")
+ echo "keywords: $keywords_string"
+
+ if [ ${inference_device} == "cuda" ]; then
+ nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+ else
+ inference_batch_size=1
+ CUDA_VISIBLE_DEVICES=""
+ for JOB in $(seq ${nj}); do
+ CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+ done
+ fi
+
+ for dset in ${test_sets}; do
+ inference_dir="${output_dir}/inference-${inference_checkpoint}/${dset}"
+ _logdir="${inference_dir}/logdir"
+ echo "inference_dir: ${inference_dir}"
+
+ mkdir -p "${_logdir}"
+ test_data_dir="${data_dir}/${dset}"
+ key_file=${test_data_dir}/${inference_scp}
+
+ split_scps=
+ for JOB in $(seq "${nj}"); do
+ split_scps+=" ${_logdir}/keys.${JOB}.scp"
+ done
+ $FUNASR_DIR/examples/aishell/paraformer/utils/split_scp.pl "${key_file}" ${split_scps}
+
+ gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+ for JOB in $(seq ${nj}); do
+ {
+ id=$((JOB-1))
+ gpuid=${gpuid_list_array[$id]}
+
+ echo "${output_dir}"
+
+ export CUDA_VISIBLE_DEVICES=${gpuid}
+ python ../../../funasr/bin/inference.py \
+ --config-path="${output_dir}" \
+ --config-name="config.yaml" \
+ ++init_param="${output_dir}/${inference_checkpoint}" \
+ ++tokenizer_conf.token_list="${token_list}" \
+ ++tokenizer_conf.seg_dict="${lexicon_list}" \
+ ++frontend_conf.cmvn_file="${cmvn_file}" \
+ ++keywords="\"$keywords_string"\" \
+ ++input="${_logdir}/keys.${JOB}.scp" \
+ ++output_dir="${inference_dir}/${JOB}" \
+ ++device="${inference_device}" \
+ ++ncpu=1 \
+ ++disable_log=true \
+ ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
+ }&
+
+ done
+ wait
+
+ for f in detect; do
+ if [ -f "${inference_dir}/${JOB}/${f}" ]; then
+ for JOB in $(seq "${nj}"); do
+ cat "${inference_dir}/${JOB}/${f}"
+ done | sort -k1 >"${inference_dir}/${f}"
+ fi
+ done
+
+ python funasr/utils/compute_det_ctc.py \
+ --keywords ${keywords_string} \
+ --test_data ${test_data_dir}/wav.scp \
+ --trans_data ${test_data_dir}/text \
+ --score_file ${inference_dir}/detect \
+ --stats_dir ${inference_dir}
+ done
+
+fi
diff --git a/examples/industrial_data_pretraining/fsmn_kws/funasr b/examples/industrial_data_pretraining/fsmn_kws/funasr
new file mode 120000
index 0000000..39a970f
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/funasr
@@ -0,0 +1 @@
+../../../funasr
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_kws/infer.sh b/examples/industrial_data_pretraining/fsmn_kws/infer.sh
new file mode 100644
index 0000000..6e03b89
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/infer.sh
@@ -0,0 +1,20 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+model="iic/speech_charctc_kws_phone-xiaoyun"
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python funasr/bin/inference.py \
++model=${model} \
++input=${input} \
++output_dir="./outputs/debug" \
++device="cpu" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/fsmn_kws/infer_from_local.sh b/examples/industrial_data_pretraining/fsmn_kws/infer_from_local.sh
new file mode 100644
index 0000000..ca1b7c8
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/infer_from_local.sh
@@ -0,0 +1,41 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method2, inference from local model
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+output_dir="./outputs/debug"
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models
+mkdir -p ${local_path_root}
+local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
+git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
+
+device="cuda:0" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
+
+config="inference_fsmn_4e_l10r2_250_128_fdim80_t2599.yaml"
+tokens="${local_path}/funasr/tokens_2599.txt"
+seg_dict="${local_path}/funasr/lexicon.txt"
+init_param="${local_path}/funasr/finetune_fsmn_4e_l10r2_250_128_fdim80_t2599_xiaoyun_xiaoyun.pt"
+cmvn_file="${local_path}/funasr/am.mvn.dim80_l2r2"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python -m funasr.bin.inference \
+--config-path "${local_path}/funasr" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++tokenizer_conf.token_list="${tokens}" \
+++tokenizer_conf.seg_dict="${seg_dict}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++device="${device}" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/fsmn_kws/path.sh b/examples/industrial_data_pretraining/fsmn_kws/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml b/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml
new file mode 100644
index 0000000..fdf52e4
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml
@@ -0,0 +1,103 @@
+
+# network architecture
+model: FsmnKWSMT
+model_conf:
+ ctc_weight: 1.0
+
+# encoder related
+encoder: FSMNMT
+encoder_conf:
+ input_dim: 400
+ input_affine_dim: 140
+ fsmn_layers: 4
+ linear_dim: 250
+ proj_dim: 128
+ lorder: 10
+ rorder: 2
+ lstride: 1
+ rstride: 1
+ output_affine_dim: 140
+ output_dim: 2599
+ output_dim2: 4
+ use_softmax: false
+
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 5
+ lfr_n: 3
+
+specaug: SpecAugLFR
+specaug_conf:
+ apply_time_warp: false
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ lfr_rate: 3
+ num_freq_mask: 1
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 12
+ num_time_mask: 1
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 100
+ keep_nbest_models: 100
+ avg_nbest_model: 10
+ avg_keep_nbest_models_type: loss
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 10000
+
+dataset: KwsMTDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: EspnetStyleBatchSampler
+ batch_type: length # example or length
+ batch_size: 64000 # if batch_type is example, batch_size is the number of samples; if length, it is source_token_len+target_token_len
+ max_token_length: 1600 # filter out samples whose source_token_len+target_token_len > max_token_length
+ buffer_size: 2048
+ shuffle: true
+ num_workers: 8
+ preprocessor_speech: SpeechPreprocessSpeedPerturb
+ preprocessor_speech_conf:
+ speed_perturb: [0.9, 1.0, 1.1]
+ dataloader: DataloaderMapStyle
+
+tokenizer:
+ - CharTokenizer
+ - CharTokenizer
+
+tokenizer_conf:
+ - unk_symbol: <unk>
+ split_with_space: true
+ token_list: null
+ seg_dict: null
+ - unk_symbol: <unk>
+ split_with_space: true
+ token_list: null
+ seg_dict: null
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin # ctc_type: focalctc, builtin
+ reduce: true
+ ignore_nan_grad: true
+ extra_linear: false
+
+normalize: null
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml b/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml
new file mode 100644
index 0000000..ca9413e
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml
@@ -0,0 +1,103 @@
+
+# network architecture
+model: FsmnKWSMT
+model_conf:
+ ctc_weight: 1.0
+
+# encoder related
+encoder: FSMNMT
+encoder_conf:
+ input_dim: 360
+ input_affine_dim: 280
+ fsmn_layers: 4
+ linear_dim: 280
+ proj_dim: 200
+ lorder: 10
+ rorder: 2
+ lstride: 1
+ rstride: 1
+ output_affine_dim: 400
+ output_dim: 2602
+ output_dim2: 4
+ use_softmax: false
+
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 40
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 9
+ lfr_n: 3
+
+specaug: SpecAugLFR
+specaug_conf:
+ apply_time_warp: false
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ lfr_rate: 3
+ num_freq_mask: 1
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 12
+ num_time_mask: 1
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 100
+ keep_nbest_models: 100
+ avg_nbest_model: 10
+ avg_keep_nbest_models_type: loss
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 10000
+
+dataset: KwsMTDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: EspnetStyleBatchSampler
+ batch_type: length # example or length
+ batch_size: 64000 # if batch_type is example, batch_size is the number of samples; if length, it is source_token_len+target_token_len
+ max_token_length: 1600 # filter out samples whose source_token_len+target_token_len > max_token_length
+ buffer_size: 2048
+ shuffle: true
+ num_workers: 8
+ preprocessor_speech: SpeechPreprocessSpeedPerturb
+ preprocessor_speech_conf:
+ speed_perturb: [0.9, 1.0, 1.1]
+ dataloader: DataloaderMapStyle
+
+tokenizer:
+ - CharTokenizer
+ - CharTokenizer
+
+tokenizer_conf:
+ - unk_symbol: <unk>
+ split_with_space: true
+ token_list: null
+ seg_dict: null
+ - unk_symbol: <unk>
+ split_with_space: true
+ token_list: null
+ seg_dict: null
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin # ctc_type: focalctc, builtin
+ reduce: true
+ ignore_nan_grad: true
+ extra_linear: false
+
+normalize: null
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/convert.py b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.py
new file mode 100644
index 0000000..e63e689
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.py
@@ -0,0 +1,141 @@
+from __future__ import print_function
+
+import argparse
+import copy
+import logging
+import os
+from shutil import copyfile
+
+import torch
+import yaml
+from typing import Union
+from funasr.models.fsmn_kws_mt.encoder import FSMNMTConvert
+from funasr.models.fsmn_kws_mt.model import FsmnKWSMTConvert
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def get_args():
+ parser = argparse.ArgumentParser(
+ description=
+ 'load a network and convert it between kaldi and pytorch formats')
+ parser.add_argument('--config', required=True, help='config file')
+ parser.add_argument(
+ '--network_file',
+ default='',
+ required=True,
+ help='input network, support kaldi.txt/pytorch.pt')
+ parser.add_argument('--model_dir', required=True, help='save model dir')
+ parser.add_argument('--model_name', required=True, help='save model name')
+ parser.add_argument('--model_name2', required=True, help='save model name')
+ parser.add_argument('--convert_to',
+ default='kaldi',
+ required=True,
+ help='target network type, kaldi/pytorch')
+
+ args = parser.parse_args()
+ return args
+
+
+def convert_to_kaldi(
+ configs,
+ network_file,
+ model_dir,
+ model_name="convert.kaldi.txt",
+ model_name2="convert.kaldi2.txt"
+):
+ copyfile(network_file, os.path.join(model_dir, 'origin.torch.pt'))
+
+ model = FsmnKWSMTConvert(
+ vocab_size=configs['encoder_conf']['output_dim'],
+ vocab_size2=configs['encoder_conf']['output_dim2'],
+ encoder='FSMNMTConvert',
+ encoder_conf=configs['encoder_conf'],
+ ctc_conf=configs['ctc_conf'],
+ )
+ print(model)
+ num_params = count_parameters(model)
+ print('the number of model params: {}'.format(num_params))
+
+ states= torch.load(network_file, map_location='cpu')
+ model.load_state_dict(states["state_dict"])
+
+ kaldi_text = os.path.join(model_dir, model_name)
+ with open(kaldi_text, 'w', encoding='utf8') as fout:
+ nnet_desp = model.to_kaldi_net()
+ fout.write(nnet_desp)
+ fout.close()
+
+ kaldi_text2 = os.path.join(model_dir, model_name2)
+ with open(kaldi_text2, 'w', encoding='utf8') as fout:
+ nnet_desp2 = model.to_kaldi_net2()
+ fout.write(nnet_desp2)
+ fout.close()
+
+
+def convert_to_pytorch(
+ configs,
+ network_file,
+ model_dir,
+ model_name="convert.torch.pt"
+):
+ model = FsmnKWSMTConvert(
+ vocab_size=configs['encoder_conf']['output_dim'],
+ vocab_size2=configs['encoder_conf']['output_dim2'],
+ encoder='FSMNMTConvert',
+ encoder_conf=configs['encoder_conf'],
+ ctc_conf=configs['ctc_conf'],
+ )
+
+ num_params = count_parameters(model)
+ print('the number of model params: {}'.format(num_params))
+
+ copyfile(network_file, os.path.join(model_dir, 'origin.kaldi.txt'))
+ model.to_pytorch_net(network_file)
+
+ save_model_path = os.path.join(model_dir, model_name)
+ torch.save({"model": model.state_dict()}, save_model_path)
+
+ print('convert torch format back to kaldi')
+ kaldi_text = os.path.join(model_dir, 'convert.kaldi.txt')
+ with open(kaldi_text, 'w', encoding='utf8') as fout:
+ nnet_desp = model.to_kaldi_net()
+ fout.write(nnet_desp)
+ fout.close()
+
+ print('Done!')
+
+
+def main():
+ args = get_args()
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+ print(args)
+ with open(args.config, 'r') as fin:
+ configs = yaml.load(fin, Loader=yaml.FullLoader)
+
+ if args.convert_to == 'pytorch':
+ print('convert kaldi net to pytorch...')
+ convert_to_pytorch(
+ configs,
+ args.network_file,
+ args.model_dir,
+ args.model_name
+ )
+ elif args.convert_to == 'kaldi':
+ print('convert pytorch net to kaldi...')
+ convert_to_kaldi(
+ configs,
+ args.network_file,
+ args.model_dir,
+ args.model_name,
+ args.model_name2
+ )
+ else:
+ print('unsupported target network type: {}'.format(args.convert_to))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh
new file mode 100644
index 0000000..30e2eed
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh
@@ -0,0 +1,36 @@
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models
+mkdir -p ${local_path_root}
+
+local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
+if [ ! -d "$local_path" ]; then
+ git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
+fi
+
+export PATH=${local_path}/runtime:$PATH
+export LD_LIBRARY_PATH=${local_path}/runtime:$LD_LIBRARY_PATH
+
+# finetune config file
+config=./conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml
+
+# finetune output checkpoint
+torch_nnet=exp/finetune_outputs/model.pt.avg10
+
+out_dir=exp/finetune_outputs
+
+if [ ! -d "$out_dir" ]; then
+ mkdir -p $out_dir
+fi
+
+python convert.py --config $config \
+ --network_file $torch_nnet \
+ --model_dir $out_dir \
+ --model_name "convert.kaldi.txt" \
+ --model_name2 "convert.kaldi2.txt" \
+ --convert_to kaldi
+
+nnet-copy --binary=true ${out_dir}/convert.kaldi.txt ${out_dir}/convert.kaldi.net
+nnet-copy --binary=true ${out_dir}/convert.kaldi2.txt ${out_dir}/convert.kaldi2.net
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh
new file mode 100755
index 0000000..1e87021
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh
@@ -0,0 +1,186 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+#!/usr/bin/env bash
+
+# Set bash to 'debug' mode; it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands'
+set -e
+set -u
+set -o pipefail
+
+. ./path.sh
+workspace=`pwd`
+
+CUDA_VISIBLE_DEVICES="0,1"
+
+stage=2
+stop_stage=3
+
+inference_device="cuda" #"cpu"
+inference_checkpoint="model.pt.avg10"
+inference_scp="wav.scp"
+inference_batch_size=32
+nj=32
+test_sets="test"
+
+# model_name from model_hub, or model_dir in local path
+
+## option 1, download the model automatically (currently unsupported)
+model_name_or_model_dir="iic/speech_charctc_kws_phone-xiaoyun"
+
+## option 2, download model by git
+local_path_root=${workspace}/modelscope_models
+model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
+if [ ! -d $model_name_or_model_dir ]; then
+ mkdir -p ${model_name_or_model_dir}
+ git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${model_name_or_model_dir}
+fi
+
+config=fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml
+token_list=${model_name_or_model_dir}/funasr/tokens_2599.txt
+token_list2=${model_name_or_model_dir}/funasr/tokens_xiaoyun_char.txt
+lexicon_list=${model_name_or_model_dir}/funasr/lexicon.txt
+cmvn_file=${model_name_or_model_dir}/funasr/am.mvn.dim80_l2r2
+init_param="${model_name_or_model_dir}/funasr/basetrain_fsmn_4e_l10r2_250_128_fdim80_t2599.pt"
+
+
+# data prepare
+# data dir, which contains: train.jsonl, val.jsonl
+data_dir=../../data
+
+train_data="${data_dir}/train.jsonl"
+val_data="${data_dir}/val.jsonl"
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Generate audio json list"
+ # generate train.jsonl and val.jsonl from wav.scp and text.txt
+ python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+ ++scp_file_list='['''${data_dir}/train_wav.scp''', '''${data_dir}/train_text.txt''']' \
+ ++data_type_list='["source", "target"]' \
+ ++jsonl_file_out="${train_data}"
+
+ python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+ ++scp_file_list='['''${data_dir}/val_wav.scp''', '''${data_dir}/val_text.txt''']' \
+ ++data_type_list='["source", "target"]' \
+ ++jsonl_file_out="${val_data}"
+fi
+
+# exp output dir
+output_dir="${workspace}/exp/finetune_outputs"
+
+
+# Training Stage
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: KWS Training"
+
+ mkdir -p ${output_dir}
+ current_time=$(date "+%Y-%m-%d_%H-%M")
+ log_file="${output_dir}/train.log.txt.${current_time}"
+ echo "log_file: ${log_file}"
+ echo "finetune use basetrain model: ${init_param}"
+
+ export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
+ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+ torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+ ../../../funasr/bin/train.py \
+ --config-path "${workspace}/conf" \
+ --config-name "${config}" \
+ ++init_param="${init_param}" \
+ ++token_lists='['''${token_list}''', '''${token_list2}''']' \
+ ++seg_dicts='['''${lexicon_list}''', '''${lexicon_list}''']' \
+ ++disable_update=true \
+ ++train_data_set_list="${train_data}" \
+ ++valid_data_set_list="${val_data}" \
+ ++frontend_conf.cmvn_file="${cmvn_file}" \
+ ++output_dir="${output_dir}" &> ${log_file}
+fi
+
+
+# Testing Stage
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: Inference"
+ keywords=(小云小云)
+ keywords_string=$(IFS=,; echo "${keywords[*]}")
+ echo "keywords: $keywords_string"
+
+ if [ ${inference_device} == "cuda" ]; then
+ nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+ else
+ inference_batch_size=1
+ CUDA_VISIBLE_DEVICES=""
+ for JOB in $(seq ${nj}); do
+ CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+ done
+ fi
+
+ for dset in ${test_sets}; do
+ inference_dir="${output_dir}/inference-${inference_checkpoint}/${dset}"
+ _logdir="${inference_dir}/logdir"
+ echo "inference_dir: ${inference_dir}"
+
+ mkdir -p "${_logdir}"
+ test_data_dir="${data_dir}/${dset}"
+ key_file=${test_data_dir}/${inference_scp}
+
+ split_scps=
+ for JOB in $(seq "${nj}"); do
+ split_scps+=" ${_logdir}/keys.${JOB}.scp"
+ done
+ $FUNASR_DIR/examples/aishell/paraformer/utils/split_scp.pl "${key_file}" ${split_scps}
+
+ gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+ for JOB in $(seq ${nj}); do
+ {
+ id=$((JOB-1))
+ gpuid=${gpuid_list_array[$id]}
+
+ echo "${output_dir}"
+
+ export CUDA_VISIBLE_DEVICES=${gpuid}
+ python ../../../funasr/bin/inference.py \
+ --config-path="${output_dir}" \
+ --config-name="config.yaml" \
+ ++init_param="${output_dir}/${inference_checkpoint}" \
+ ++tokenizer_conf.token_list="${token_list}" \
+ ++tokenizer_conf.seg_dict="${lexicon_list}" \
+ ++tokenizer2_conf.token_list="${token_list2}" \
+ ++tokenizer2_conf.seg_dict="${lexicon_list}" \
+ ++frontend_conf.cmvn_file="${cmvn_file}" \
+ ++keywords="\"$keywords_string"\" \
+ ++input="${_logdir}/keys.${JOB}.scp" \
+ ++output_dir="${inference_dir}/${JOB}" \
+ ++device="${inference_device}" \
+ ++ncpu=1 \
+ ++disable_log=true \
+ ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
+ }&
+
+ done
+ wait
+
+ for f in detect detect2; do
+ if [ -f "${inference_dir}/${JOB}/${f}" ]; then
+ for JOB in $(seq "${nj}"); do
+ cat "${inference_dir}/${JOB}/${f}"
+ done | sort -k1 >"${inference_dir}/${f}"
+ fi
+ done
+
+ mkdir -p ${inference_dir}/task1
+ python funasr/utils/compute_det_ctc.py \
+ --keywords ${keywords_string} \
+ --test_data ${test_data_dir}/wav.scp \
+ --trans_data ${test_data_dir}/text \
+ --score_file ${inference_dir}/detect \
+ --stats_dir ${inference_dir}/task1
+
+ mkdir -p ${inference_dir}/task2
+ python funasr/utils/compute_det_ctc.py \
+ --keywords ${keywords_string} \
+ --test_data ${test_data_dir}/wav.scp \
+ --trans_data ${test_data_dir}/text \
+ --score_file ${inference_dir}/detect2 \
+ --stats_dir ${inference_dir}/task2
+ done
+
+fi
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/funasr b/examples/industrial_data_pretraining/fsmn_kws_mt/funasr
new file mode 120000
index 0000000..39a970f
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/funasr
@@ -0,0 +1 @@
+../../../funasr
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh
new file mode 100644
index 0000000..6e03b89
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh
@@ -0,0 +1,20 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+model="iic/speech_charctc_kws_phone-xiaoyun"
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python funasr/bin/inference.py \
++model=${model} \
++input=${input} \
++output_dir="./outputs/debug" \
++device="cpu" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh
new file mode 100644
index 0000000..51d2312
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh
@@ -0,0 +1,44 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method2, inference from local model
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+output_dir="./outputs/debug"
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models
+mkdir -p ${local_path_root}
+local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
+git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
+
+device="cuda:0" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
+
+config="inference_fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml"
+tokens="${local_path}/funasr/tokens_2602.txt"
+tokens2="${local_path}/funasr/tokens_xiaoyun_char.txt"
+seg_dict="${local_path}/funasr/lexicon.txt"
+init_param="${local_path}/funasr/finetune_fsmn_4e_l10r2_280_200_fdim40_t2602_t4_xiaoyun_xiaoyun.pt"
+cmvn_file="${local_path}/funasr/am.mvn.dim40_l4r4"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python -m funasr.bin.inference \
+--config-path "${local_path}/funasr" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++tokenizer_conf.token_list="${tokens}" \
+++tokenizer_conf.seg_dict="${seg_dict}" \
+++tokenizer2_conf.token_list="${tokens2}" \
+++tokenizer2_conf.seg_dict="${seg_dict}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++device="${device}" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/path.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/examples/industrial_data_pretraining/sanm_kws/conf/sanm_6e_320_256_fdim40_t2602.yaml b/examples/industrial_data_pretraining/sanm_kws/conf/sanm_6e_320_256_fdim40_t2602.yaml
new file mode 100644
index 0000000..c4d8c18
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/conf/sanm_6e_320_256_fdim40_t2602.yaml
@@ -0,0 +1,94 @@
+
+# network architecture
+model: SanmKWS
+model_conf:
+ ctc_weight: 1.0
+
+# encoder
+encoder: SANMEncoder
+encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 320 # the number of units of position-wise feed forward
+ num_blocks: 6 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.1
+ input_layer: pe
+ pos_enc_class: SinusoidalPositionEncoder
+ normalize_before: true
+ kernel_size: 11
+ sanm_shfit: 0
+ selfattention_layer_type: sanm
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 40
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 7
+ lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+ apply_time_warp: false
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ lfr_rate: 6
+ num_freq_mask: 1
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 12
+ num_time_mask: 1
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 100
+ keep_nbest_models: 20
+ avg_nbest_model: 10
+ avg_keep_nbest_models_type: loss
+ validate_interval: 50000
+ save_checkpoint_interval: 50000
+ avg_checkpoint_interval: 1000
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 10000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: EspnetStyleBatchSampler
+ batch_type: length # example or length
+ batch_size: 96000 # if batch_type is example, batch_size is the number of samples; if length, it is source_token_len+target_token_len
+ max_token_length: 1600 # filter out samples whose source_token_len+target_token_len > max_token_length
+ buffer_size: 2048
+ shuffle: true
+ num_workers: 8
+ preprocessor_speech: SpeechPreprocessSpeedPerturb
+ preprocessor_speech_conf:
+ speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin # ctc_type: focalctc, builtin
+ reduce: true
+ ignore_nan_grad: true
+normalize: null
diff --git a/examples/industrial_data_pretraining/sanm_kws/export.sh b/examples/industrial_data_pretraining/sanm_kws/export.sh
new file mode 100644
index 0000000..1106292
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/export.sh
@@ -0,0 +1,17 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+config_path="/home/pengteng.spt/source/FunASR_KWS/examples/industrial_data_pretraining/sanm_kws/conf"
+config_path="/home/pengteng.spt/source/FunASR_KWS/examples/industrial_data_pretraining/sanm_kws/exp/20240914_xiaoyun_finetune_sanm_6e_320_256_feats_dim40_char_t2602_offline"
+
+config_file="sanm_6e_320_256_fdim40_t2602.yaml"
+config_file="config.yaml"
+
+model_path="./modelscope_models_kws/speech_charctc_kws_phone-xiaoyun/funasr/finetune_sanm_6e_320_256_fdim40_t2602_online_xiaoyun_commands.pt"
+
+python -m funasr.bin.export \
+ --config-path="${config_path}" \
+ --config-name="${config_file}" \
+ ++init_param=${model_path} \
+ ++type="onnx" \
+ ++quantize=true
diff --git a/examples/industrial_data_pretraining/sanm_kws/finetune.sh b/examples/industrial_data_pretraining/sanm_kws/finetune.sh
new file mode 100755
index 0000000..cf6d488
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/finetune.sh
@@ -0,0 +1,172 @@
+#!/usr/bin/env bash
+
+# Set bash to 'debug' mode; it will exit on:
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands'
+set -e
+set -u
+set -o pipefail
+
+. ./path.sh
+workspace=`pwd`
+
+CUDA_VISIBLE_DEVICES="0,1"
+
+stage=2
+stop_stage=3
+
+inference_device="cpu" #"cpu"
+inference_device="cuda" #"cpu"
+inference_checkpoint="model.pt.avg10"
+inference_scp="wav.scp"
+inference_batch_size=32
+nj=32
+test_sets="test"
+
+# model_name from model_hub, or model_dir in local path
+
+## option 1, download the model automatically (currently unsupported)
+model_name_or_model_dir="iic/speech_sanm_kws_phone-xiaoyun-commands-offline"
+
+## option 2, download model by git
+local_path_root=${workspace}/modelscope_models
+model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
+if [ ! -d $model_name_or_model_dir ]; then
+ mkdir -p ${model_name_or_model_dir}
+ git clone https://www.modelscope.cn/iic/speech_sanm_kws_phone-xiaoyun-commands-offline.git ${model_name_or_model_dir}
+fi
+
+config=sanm_6e_320_256_fdim40_t2602.yaml
+token_list=${model_name_or_model_dir}/tokens_2602.txt
+lexicon_list=${model_name_or_model_dir}/lexicon.txt
+cmvn_file=${model_name_or_model_dir}/am.mvn.dim40_l3r3
+init_param="${model_name_or_model_dir}/basetrain_sanm_6e_320_256_fdim40_t2602_offline.pt"
+
+
+# data prepare
+# data dir, which contains: train.jsonl, val.jsonl
+data_dir=../../data
+
+train_data="${data_dir}/train.jsonl"
+val_data="${data_dir}/val.jsonl"
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Generate audio json list"
+ # generate train.jsonl and val.jsonl from wav.scp and text.txt
+ python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+ ++scp_file_list='['''${data_dir}/train_wav.scp''', '''${data_dir}/train_text.txt''']' \
+ ++data_type_list='["source", "target"]' \
+ ++jsonl_file_out="${train_data}"
+
+ python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+ ++scp_file_list='['''${data_dir}/val_wav.scp''', '''${data_dir}/val_text.txt''']' \
+ ++data_type_list='["source", "target"]' \
+ ++jsonl_file_out="${val_data}"
+fi
+
+# exp output dir
+output_dir="${workspace}/exp/finetune_outputs"
+
+# Training Stage
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: KWS Training"
+
+ mkdir -p ${output_dir}
+ current_time=$(date "+%Y-%m-%d_%H-%M")
+ log_file="${output_dir}/train.log.txt.${current_time}"
+ echo "log_file: ${log_file}"
+ echo "finetune use basetrain model: ${init_param}"
+
+ export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
+ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+ torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+ ../../../funasr/bin/train.py \
+ --config-path "${workspace}/conf" \
+ --config-name "${config}" \
+ ++init_param="${init_param}" \
+ ++disable_update=true \
+ ++train_data_set_list="${train_data}" \
+ ++valid_data_set_list="${val_data}" \
+ ++tokenizer_conf.token_list="${token_list}" \
+ ++tokenizer_conf.seg_dict="${lexicon_list}" \
+ ++frontend_conf.cmvn_file="${cmvn_file}" \
+ ++output_dir="${output_dir}" &> ${log_file}
+fi
+
+
+# Testing Stage
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: Inference"
+ keywords=(小云小云)
+ keywords_string=$(IFS=,; echo "${keywords[*]}")
+ echo "keywords: $keywords_string"
+
+ if [ ${inference_device} == "cuda" ]; then
+ nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+ else
+ inference_batch_size=1
+ CUDA_VISIBLE_DEVICES=""
+ for JOB in $(seq ${nj}); do
+ CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+ done
+ fi
+
+ for dset in ${test_sets}; do
+ inference_dir="${output_dir}/inference-${inference_checkpoint}/${dset}"
+ _logdir="${inference_dir}/logdir"
+ echo "inference_dir: ${inference_dir}"
+
+ mkdir -p "${_logdir}"
+ test_data_dir="${data_dir}/${dset}"
+ key_file=${test_data_dir}/${inference_scp}
+
+ split_scps=
+ for JOB in $(seq "${nj}"); do
+ split_scps+=" ${_logdir}/keys.${JOB}.scp"
+ done
+ $FUNASR_DIR/examples/aishell/paraformer/utils/split_scp.pl "${key_file}" ${split_scps}
+
+ gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+ for JOB in $(seq ${nj}); do
+ {
+ id=$((JOB-1))
+ gpuid=${gpuid_list_array[$id]}
+
+ echo "${output_dir}"
+
+ export CUDA_VISIBLE_DEVICES=${gpuid}
+ python ../../../funasr/bin/inference.py \
+ --config-path="${output_dir}" \
+ --config-name="config.yaml" \
+ ++init_param="${output_dir}/${inference_checkpoint}" \
+ ++tokenizer_conf.token_list="${token_list}" \
+ ++tokenizer_conf.seg_dict="${lexicon_list}" \
+ ++frontend_conf.cmvn_file="${cmvn_file}" \
+ ++keywords="\"$keywords_string"\" \
+ ++input="${_logdir}/keys.${JOB}.scp" \
+ ++output_dir="${inference_dir}/${JOB}" \
+ ++device="${inference_device}" \
+ ++ncpu=1 \
+ ++disable_log=true \
+ ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
+ # ++batch_size="${inference_batch_size}"
+ }&
+
+ done
+ wait
+
+ for f in detect score; do
+ if [ -f "${inference_dir}/${JOB}/${f}" ]; then
+ for JOB in $(seq "${nj}"); do
+ cat "${inference_dir}/${JOB}/${f}"
+ done | sort -k1 >"${inference_dir}/${f}"
+ fi
+ done
+
+ python funasr/utils/compute_det_ctc.py \
+ --keywords ${keywords_string} \
+ --test_data ${test_data_dir}/wav.scp \
+ --trans_data ${test_data_dir}/text \
+ --score_file ${inference_dir}/detect \
+ --stats_dir ${inference_dir}
+ done
+
+fi
diff --git a/examples/industrial_data_pretraining/sanm_kws/funasr b/examples/industrial_data_pretraining/sanm_kws/funasr
new file mode 120000
index 0000000..39a970f
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/funasr
@@ -0,0 +1 @@
+../../../funasr
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/sanm_kws/infer.sh b/examples/industrial_data_pretraining/sanm_kws/infer.sh
new file mode 100644
index 0000000..74455bb
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/infer.sh
@@ -0,0 +1,20 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+model="iic/speech_sanm_kws_phone-xiaoyun-commands-offline"
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python funasr/bin/inference.py \
++model=${model} \
++input=${input} \
++output_dir="./outputs/debug" \
++device="cpu" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/sanm_kws/infer_from_local.sh b/examples/industrial_data_pretraining/sanm_kws/infer_from_local.sh
new file mode 100644
index 0000000..6be8bab
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/infer_from_local.sh
@@ -0,0 +1,41 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method2, inference from local model
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+output_dir="./outputs/debug"
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models
+mkdir -p ${local_path_root}
+local_path=${local_path_root}/speech_sanm_kws_phone-xiaoyun-commands-offline
+git clone https://www.modelscope.cn/iic/speech_sanm_kws_phone-xiaoyun-commands-offline.git ${local_path}
+
+device="cpu" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
+
+config="inference_sanm_6e_320_256_fdim40_t2602_offline.yaml"
+tokens="${local_path}/tokens_2602.txt"
+seg_dict="${local_path}/lexicon.txt"
+init_param="${local_path}/finetune_sanm_6e_320_256_fdim40_t2602_offline_xiaoyun_commands.pt"
+cmvn_file="${local_path}/am.mvn.dim40_l3r3"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python -m funasr.bin.inference \
+--config-path "${local_path}/" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++tokenizer_conf.token_list="${tokens}" \
+++tokenizer_conf.seg_dict="${seg_dict}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++device="${device}" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/sanm_kws/path.sh b/examples/industrial_data_pretraining/sanm_kws/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/conf/sanm_6e_320_256_fdim40_t2602.yaml b/examples/industrial_data_pretraining/sanm_kws_streaming/conf/sanm_6e_320_256_fdim40_t2602.yaml
new file mode 100644
index 0000000..664997c
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/conf/sanm_6e_320_256_fdim40_t2602.yaml
@@ -0,0 +1,109 @@
+
+# network architecture
+model: SanmKWSStreaming
+model_conf:
+ ctc_weight: 1.0
+
+# encoder
+encoder: SANMEncoderChunkOpt
+encoder_conf:
+ output_size: 256 # dimension of attention
+ attention_heads: 4
+ linear_units: 320 # the number of units of position-wise feed forward
+ num_blocks: 6 # the number of encoder blocks
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.1
+ input_layer: pe_online
+ pos_enc_class: SinusoidalPositionEncoder
+ normalize_before: true
+ kernel_size: 11
+ sanm_shfit: 0
+ selfattention_layer_type: sanm
+ chunk_size:
+ - 16
+ - 20
+ stride:
+ - 8
+ - 10
+ pad_left:
+ - 4
+ - 5
+ encoder_att_look_back_factor:
+ - 0
+ - 0
+ decoder_att_look_back_factor:
+ - 0
+ - 0
+
+# frontend related
+frontend: WavFrontendOnline
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 40
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 7
+ lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+ apply_time_warp: false
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ lfr_rate: 6
+ num_freq_mask: 1
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 12
+ num_time_mask: 1
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 100
+ keep_nbest_models: 20
+ avg_nbest_model: 10
+ avg_keep_nbest_models_type: loss
+ validate_interval: 50000
+ save_checkpoint_interval: 50000
+ avg_checkpoint_interval: 1000
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: EspnetStyleBatchSampler
+ batch_type: length # example or length
+  batch_size: 64000 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
+  max_token_length: 1600 # filter out samples whose source_token_len+target_token_len > max_token_length
+ buffer_size: 2048
+ shuffle: true
+ num_workers: 8
+ preprocessor_speech: SpeechPreprocessSpeedPerturb
+ preprocessor_speech_conf:
+ speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin # ctc_type: focalctc, builtin
+ reduce: true
+ ignore_nan_grad: true
+normalize: null
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/export.sh b/examples/industrial_data_pretraining/sanm_kws_streaming/export.sh
new file mode 100644
index 0000000..eef6875
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/export.sh
@@ -0,0 +1,17 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+config_path="/home/pengteng.spt/source/FunASR_KWS/examples/industrial_data_pretraining/sanm_kws_streaming/conf"
+config_path="/home/pengteng.spt/source/FunASR_KWS/examples/industrial_data_pretraining/sanm_kws_streaming/exp/20240618_xiaoyun_finetune_sanm_6e_320_256_feats_dim40_char_t2602_online_6"
+
+config_file="sanm_6e_320_256_fdim40_t2602.yaml"
+config_file="config.yaml"
+
+model_path="./modelscope_models_kws/speech_charctc_kws_phone-xiaoyun/funasr/finetune_sanm_6e_320_256_fdim40_t2602_online_xiaoyun_commands.pt"
+
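+# ++type="onnx" exports the model to ONNX; ++quantize=true should additionally write a
+# quantized ONNX copy alongside it (see funasr.bin.export for the exact outputs)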
+python -m funasr.bin.export \
+ --config-path="${config_path}" \
+ --config-name="${config_file}" \
+ ++init_param=${model_path} \
+ ++type="onnx" \
+ ++quantize=true
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/finetune.sh b/examples/industrial_data_pretraining/sanm_kws_streaming/finetune.sh
new file mode 100755
index 0000000..8914786
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/finetune.sh
@@ -0,0 +1,258 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+#!/usr/bin/env bash
+
+# Set bash to 'strict' mode; it will exit on:
+# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline'
+set -e
+set -u
+set -o pipefail
+
+. ./path.sh
+workspace=`pwd`
+
+CUDA_VISIBLE_DEVICES="0,1"
+
+stage=2
+stop_stage=4
+
+inference_device="cpu" # "cuda" or "cpu"
+inference_checkpoint="model.pt.avg10"
+inference_scp="wav.scp"
+inference_batch_size=32
+nj=32
+test_sets="test"
+
+# model_name from model_hub, or model_dir in local path
+
+## option 1, download model automatically (currently unsupported)
+model_name_or_model_dir="iic/speech_sanm_kws_phone-xiaoyun-commands-online"
+
+## option 2, download model by git
+local_path_root=${workspace}/modelscope_models
+model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
+if [ ! -d $model_name_or_model_dir ]; then
+ mkdir -p ${model_name_or_model_dir}
+ git clone https://www.modelscope.cn/iic/speech_sanm_kws_phone-xiaoyun-commands-online.git ${model_name_or_model_dir}
+fi
+
+config=sanm_6e_320_256_fdim40_t2602.yaml
+token_list=${model_name_or_model_dir}/tokens_2602.txt
+lexicon_list=${model_name_or_model_dir}/lexicon.txt
+cmvn_file=${model_name_or_model_dir}/am.mvn.dim40_l3r3
+init_param="${model_name_or_model_dir}/basetrain_sanm_6e_320_256_fdim40_t2602_online.pt"
+
+
+# data prepare
+# data dir, which contains: train.json, val.json
+data_dir=../../data
+
+train_data="${data_dir}/train.jsonl"
+val_data="${data_dir}/val.jsonl"
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+ echo "stage 1: Generate audio json list"
+ # generate train.jsonl and val.jsonl from wav.scp and text.txt
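+  # each jsonl line holds: key, source (wav path), source_len (frames), target (transcript),
+  # target_len -- see funasr/datasets/audio_datasets/scp2jsonl.py in this patch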
+ python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+ ++scp_file_list='['''${data_dir}/train_wav.scp''', '''${data_dir}/train_text.txt''']' \
+ ++data_type_list='["source", "target"]' \
+ ++jsonl_file_out="${train_data}"
+
+ python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+ ++scp_file_list='['''${data_dir}/val_wav.scp''', '''${data_dir}/val_text.txt''']' \
+ ++data_type_list='["source", "target"]' \
+ ++jsonl_file_out="${val_data}"
+fi
+
+# exp output dir
+output_dir="${workspace}/exp/finetune_outputs"
+
+# Training Stage
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+ echo "stage 2: KWS Training"
+
+ mkdir -p ${output_dir}
+ current_time=$(date "+%Y-%m-%d_%H-%M")
+ log_file="${output_dir}/train.log.txt.${current_time}"
+ echo "log_file: ${log_file}"
+ echo "finetune use basetrain model: ${init_param}"
+
+ export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
+ gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+ torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+ ../../../funasr/bin/train.py \
+ --config-path "${workspace}/conf" \
+ --config-name "${config}" \
+ ++init_param="${init_param}" \
+ ++disable_update=true \
+ ++train_data_set_list="${train_data}" \
+ ++valid_data_set_list="${val_data}" \
+ ++tokenizer_conf.token_list="${token_list}" \
+ ++tokenizer_conf.seg_dict="${lexicon_list}" \
+ ++frontend_conf.cmvn_file="${cmvn_file}" \
+ ++output_dir="${output_dir}" &> ${log_file}
+fi
+
+
+# Testing Stage
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+ echo "stage 3: Inference chunk_size: [4, 8, 4]"
+  keywords=(小云小云)
+ keywords_string=$(IFS=,; echo "${keywords[*]}")
+ echo "keywords: $keywords_string"
+
+ if [ ${inference_device} == "cuda" ]; then
+ nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+ else
+ inference_batch_size=1
+ CUDA_VISIBLE_DEVICES=""
+ for JOB in $(seq ${nj}); do
+ CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+ done
+ fi
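+  # on CPU, each of the nj parallel jobs gets device id "-1" (no GPU) and decodes with batch_size=1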
+
+ for dset in ${test_sets}; do
+ inference_dir="${output_dir}/inference-${inference_checkpoint}/${dset}/chunk-4-8-4_elb-0_dlb_0"
+ _logdir="${inference_dir}/logdir"
+ echo "inference_dir: ${inference_dir}"
+
+ mkdir -p "${_logdir}"
+ test_data_dir="${data_dir}/${dset}"
+ key_file=${test_data_dir}/${inference_scp}
+
+ split_scps=
+ for JOB in $(seq "${nj}"); do
+ split_scps+=" ${_logdir}/keys.${JOB}.scp"
+ done
+ $FUNASR_DIR/examples/aishell/paraformer/utils/split_scp.pl "${key_file}" ${split_scps}
+
+ gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+ for JOB in $(seq ${nj}); do
+ {
+ id=$((JOB-1))
+ gpuid=${gpuid_list_array[$id]}
+
+ echo "${output_dir}"
+
+ export CUDA_VISIBLE_DEVICES=${gpuid}
+ python ../../../funasr/bin/inference.py \
+ --config-path="${output_dir}" \
+ --config-name="config.yaml" \
+ ++init_param="${output_dir}/${inference_checkpoint}" \
+ ++tokenizer_conf.token_list="${token_list}" \
+ ++tokenizer_conf.seg_dict="${lexicon_list}" \
+ ++frontend_conf.cmvn_file="${cmvn_file}" \
+ ++keywords="\"$keywords_string"\" \
+ ++input="${_logdir}/keys.${JOB}.scp" \
+ ++output_dir="${inference_dir}/${JOB}" \
+ ++chunk_size='[4, 8, 4]' \
+ ++encoder_chunk_look_back=0 \
+ ++decoder_chunk_look_back=0 \
+ ++device="${inference_device}" \
+ ++ncpu=1 \
+ ++disable_log=true \
+ ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
+ }&
+
+ done
+ wait
+
+ for f in detect score; do
+ if [ -f "${inference_dir}/${JOB}/${f}" ]; then
+ for JOB in $(seq "${nj}"); do
+ cat "${inference_dir}/${JOB}/${f}"
+ done | sort -k1 >"${inference_dir}/${f}"
+ fi
+ done
+
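+    # score the merged detect results against the reference transcripts and write
+    # per-keyword detection (DET) statistics into stats_dir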
+ python funasr/utils/compute_det_ctc.py \
+ --keywords ${keywords_string} \
+ --test_data ${test_data_dir}/wav.scp \
+ --trans_data ${test_data_dir}/text \
+ --score_file ${inference_dir}/detect \
+ --stats_dir ${inference_dir}
+ done
+
+fi
+
+
+# Testing Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+ echo "stage 4: Inference chunk_size: [5, 10, 5]"
+  keywords=(小云小云)
+ keywords_string=$(IFS=,; echo "${keywords[*]}")
+ echo "keywords: $keywords_string"
+
+ if [ ${inference_device} == "cuda" ]; then
+ nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+ else
+ inference_batch_size=1
+ CUDA_VISIBLE_DEVICES=""
+ for JOB in $(seq ${nj}); do
+ CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+ done
+ fi
+
+ for dset in ${test_sets}; do
+ inference_dir="${output_dir}/inference-${inference_checkpoint}/${dset}/chunk-5-10-5_elb-0_dlb_0"
+ _logdir="${inference_dir}/logdir"
+ echo "inference_dir: ${inference_dir}"
+
+ mkdir -p "${_logdir}"
+ test_data_dir="${data_dir}/${dset}"
+ key_file=${test_data_dir}/${inference_scp}
+
+ split_scps=
+ for JOB in $(seq "${nj}"); do
+ split_scps+=" ${_logdir}/keys.${JOB}.scp"
+ done
+ $FUNASR_DIR/examples/aishell/paraformer/utils/split_scp.pl "${key_file}" ${split_scps}
+
+ gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+ for JOB in $(seq ${nj}); do
+ {
+ id=$((JOB-1))
+ gpuid=${gpuid_list_array[$id]}
+
+ echo "${output_dir}"
+
+ export CUDA_VISIBLE_DEVICES=${gpuid}
+ python ../../../funasr/bin/inference.py \
+ --config-path="${output_dir}" \
+ --config-name="config.yaml" \
+ ++init_param="${output_dir}/${inference_checkpoint}" \
+ ++tokenizer_conf.token_list="${token_list}" \
+ ++tokenizer_conf.seg_dict="${lexicon_list}" \
+ ++frontend_conf.cmvn_file="${cmvn_file}" \
+ ++keywords="\"$keywords_string"\" \
+ ++input="${_logdir}/keys.${JOB}.scp" \
+ ++output_dir="${inference_dir}/${JOB}" \
+ ++chunk_size='[5, 10, 5]' \
+ ++encoder_chunk_look_back=0 \
+ ++decoder_chunk_look_back=0 \
+ ++device="${inference_device}" \
+ ++ncpu=1 \
+ ++disable_log=true \
+ ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
+ }&
+
+ done
+ wait
+
+ for f in detect; do
+ if [ -f "${inference_dir}/${JOB}/${f}" ]; then
+ for JOB in $(seq "${nj}"); do
+ cat "${inference_dir}/${JOB}/${f}"
+ done | sort -k1 >"${inference_dir}/${f}"
+ fi
+ done
+
+ python funasr/utils/compute_det_ctc.py \
+ --keywords ${keywords_string} \
+ --test_data ${test_data_dir}/wav.scp \
+ --trans_data ${test_data_dir}/text \
+ --score_file ${inference_dir}/detect \
+ --stats_dir ${inference_dir}
+ done
+
+fi
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/funasr b/examples/industrial_data_pretraining/sanm_kws_streaming/funasr
new file mode 120000
index 0000000..39a970f
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/funasr
@@ -0,0 +1 @@
+../../../funasr
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/infer.sh b/examples/industrial_data_pretraining/sanm_kws_streaming/infer.sh
new file mode 100644
index 0000000..ce34a9b
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/infer.sh
@@ -0,0 +1,34 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+model="iic/speech_sanm_kws_phone-xiaoyun-commands-online"
+
+# for more input types, please refer to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
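+# the two runs below differ only in chunk_size; both streaming setups are supported because
+# the model is trained with two chunk/stride configurations (see the training yaml)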
+python funasr/bin/inference.py \
++model=${model} \
++input=${input} \
++output_dir="./outputs/debug" \
+++chunk_size='[4, 8, 4]' \
+++encoder_chunk_look_back=0 \
+++decoder_chunk_look_back=0 \
++device="cpu" \
+++keywords="\"$keywords_string"\"
+
+
+python funasr/bin/inference.py \
++model=${model} \
++input=${input} \
++output_dir="./outputs/debug" \
+++chunk_size='[5, 10, 5]' \
+++encoder_chunk_look_back=0 \
+++decoder_chunk_look_back=0 \
++device="cpu" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/infer_from_local.sh b/examples/industrial_data_pretraining/sanm_kws_streaming/infer_from_local.sh
new file mode 100644
index 0000000..a1f0639
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/infer_from_local.sh
@@ -0,0 +1,62 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+# method2, inference from local model
+
+# for more input types, please refer to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+output_dir="./outputs/debug"
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models
+mkdir -p ${local_path_root}
+local_path=${local_path_root}/speech_sanm_kws_phone-xiaoyun-commands-online
+git clone https://www.modelscope.cn/iic/speech_sanm_kws_phone-xiaoyun-commands-online.git ${local_path}
+
+device="cpu" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
+
+config="inference_sanm_6e_320_256_fdim40_t2602_online.yaml"
+tokens="${local_path}/tokens_2602.txt"
+seg_dict="${local_path}/lexicon.txt"
+init_param="${local_path}/finetune_sanm_6e_320_256_fdim40_t2602_online_xiaoyun_commands.pt"
+cmvn_file="${local_path}/am.mvn.dim40_l3r3"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+echo "inference sanm streaming with chunk_size=[4, 8, 4]"
+python -m funasr.bin.inference \
+--config-path "${local_path}/" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++tokenizer_conf.token_list="${tokens}" \
+++tokenizer_conf.seg_dict="${seg_dict}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++chunk_size='[4, 8, 4]' \
+++encoder_chunk_look_back=0 \
+++decoder_chunk_look_back=0 \
+++device="${device}" \
+++keywords="\"$keywords_string"\"
+
+
+echo "inference sanm streaming with chunk_size=[5, 10, 5]"
+python -m funasr.bin.inference \
+--config-path "${local_path}/" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++tokenizer_conf.token_list="${tokens}" \
+++tokenizer_conf.seg_dict="${seg_dict}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++chunk_size='[5, 10, 5]' \
+++encoder_chunk_look_back=0 \
+++decoder_chunk_look_back=0 \
+++device="${device}" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/path.sh b/examples/industrial_data_pretraining/sanm_kws_streaming/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index ca1f202..9f5f4fb 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -14,6 +14,7 @@
import numpy as np
from tqdm import tqdm
+from omegaconf import DictConfig, ListConfig
from funasr.utils.misc import deep_update
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
@@ -187,21 +188,59 @@
# build tokenizer
tokenizer = kwargs.get("tokenizer", None)
- if tokenizer is not None:
- tokenizer_class = tables.tokenizer_classes.get(tokenizer)
- tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
- kwargs["token_list"] = (
- tokenizer.token_list if hasattr(tokenizer, "token_list") else None
- )
- kwargs["token_list"] = (
- tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
- )
- vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
- if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
- vocab_size = tokenizer.get_vocab_size()
- else:
- vocab_size = -1
kwargs["tokenizer"] = tokenizer
+ kwargs["vocab_size"] = -1
+
+ if tokenizer is not None:
+ tokenizers = (
+ tokenizer.split(",") if isinstance(tokenizer, str) else tokenizer
+        )  # note: tokenizers is a list from here on
+ tokenizers_conf = kwargs.get("tokenizer_conf", {})
+ tokenizers_build = []
+ vocab_sizes = []
+ token_lists = []
+ ### === only for kws ===
+ token_list_files = kwargs.get("token_lists", [])
+ seg_dicts = kwargs.get("seg_dicts", [])
+ ### === only for kws ===
+
+ if not isinstance(tokenizers_conf, (list, tuple, ListConfig)):
+ tokenizers_conf = [tokenizers_conf] * len(tokenizers)
+
+ for i, tokenizer in enumerate(tokenizers):
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+ tokenizer_conf = tokenizers_conf[i]
+
+ ### === only for kws ===
+ if len(token_list_files) > 1:
+ tokenizer_conf.token_list = token_list_files[i]
+ if len(seg_dicts) > 1:
+ tokenizer_conf.seg_dict = seg_dicts[i]
+ ### === only for kws ===
+
+ tokenizer = tokenizer_class(**tokenizer_conf)
+ tokenizers_build.append(tokenizer)
+ token_list = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+ token_list = (
+ tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else token_list
+ )
+ vocab_size = -1
+ if token_list is not None:
+ vocab_size = len(token_list)
+
+ if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
+ vocab_size = tokenizer.get_vocab_size()
+ token_lists.append(token_list)
+ vocab_sizes.append(vocab_size)
+
+ if len(tokenizers_build) <= 1:
+ tokenizers_build = tokenizers_build[0]
+ token_lists = token_lists[0]
+ vocab_sizes = vocab_sizes[0]
+
+ kwargs["tokenizer"] = tokenizers_build
+ kwargs["vocab_size"] = vocab_sizes
+ kwargs["token_list"] = token_lists
# build frontend
frontend = kwargs.get("frontend", None)
@@ -219,7 +258,7 @@
model_conf = {}
deep_update(model_conf, kwargs.get("model_conf", {}))
deep_update(model_conf, kwargs)
- model = model_class(**model_conf, vocab_size=vocab_size)
+ model = model_class(**model_conf)
# init_param
init_param = kwargs.get("init_param", None)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 2729b80..fcd763f 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -20,6 +20,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms.join import Join
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from tensorboardX import SummaryWriter
from funasr.train_utils.average_nbest_models import average_checkpoints
from funasr.register import tables
@@ -191,8 +192,6 @@
tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
os.makedirs(tensorboard_dir, exist_ok=True)
try:
- from tensorboardX import SummaryWriter
-
writer = SummaryWriter(tensorboard_dir) # if trainer.rank == 0 else None
except:
writer = None
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index 2aafde3..68b2d3c 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -1,6 +1,7 @@
import torch
import random
+
from funasr.register import tables
from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
@@ -17,6 +18,7 @@
index_ds: str = None,
frontend=None,
tokenizer=None,
+ is_training: bool = True,
int_pad_value: int = -1,
float_pad_value: float = 0.0,
**kwargs,
@@ -24,18 +26,23 @@
super().__init__()
index_ds_class = tables.index_ds_classes.get(index_ds)
self.index_ds = index_ds_class(path, **kwargs)
- preprocessor_speech = kwargs.get("preprocessor_speech", None)
- if preprocessor_speech:
- preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
- preprocessor_speech = preprocessor_speech_class(
- **kwargs.get("preprocessor_speech_conf")
- )
- self.preprocessor_speech = preprocessor_speech
- preprocessor_text = kwargs.get("preprocessor_text", None)
- if preprocessor_text:
- preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
- preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
- self.preprocessor_text = preprocessor_text
+
+ self.preprocessor_speech = None
+ self.preprocessor_text = None
+
+ if is_training:
+ preprocessor_speech = kwargs.get("preprocessor_speech", None)
+ if preprocessor_speech:
+ preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+ preprocessor_speech = preprocessor_speech_class(
+ **kwargs.get("preprocessor_speech_conf")
+ )
+ self.preprocessor_speech = preprocessor_speech
+ preprocessor_text = kwargs.get("preprocessor_text", None)
+ if preprocessor_text:
+ preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+ preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+ self.preprocessor_text = preprocessor_text
self.frontend = frontend
self.fs = 16000 if frontend is None else frontend.fs
@@ -64,6 +71,7 @@
data_src = load_audio_text_image_video(source, fs=self.fs)
if self.preprocessor_speech:
data_src = self.preprocessor_speech(data_src, fs=self.fs)
+
speech, speech_lengths = extract_fbank(
data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
) # speech: [b, T, d]
@@ -71,6 +79,7 @@
target = item["target"]
if self.preprocessor_text:
target = self.preprocessor_text(target)
+
if self.tokenizer:
ids = self.tokenizer.encode(target)
text = torch.tensor(ids, dtype=torch.int64)
diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
index f4c9d74..48c64d2 100644
--- a/funasr/datasets/audio_datasets/scp2jsonl.py
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -58,7 +58,8 @@
for key in json_dict[data_type_list[0]].keys():
jsonl_line = {"key": key}
for data_file in data_type_list:
- jsonl_line.update(json_dict[data_file][key])
+ if key in json_dict[data_file]:
+ jsonl_line.update(json_dict[data_file][key])
jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
f.write(jsonl_line + "\n")
f.flush()
@@ -81,10 +82,14 @@
key = lines[0]
line = lines[1] if len(lines) > 1 else ""
line = line.strip()
- if os.path.exists(line):
- waveform, _ = librosa.load(line, sr=16000)
- sample_num = len(waveform)
- context_len = int(sample_num / 16000 * 1000 / 10)
+ if data_type == "source":
+ if os.path.exists(line):
+ waveform, _ = librosa.load(line, sr=16000)
+ sample_num = len(waveform)
+ context_len = int(sample_num * 1000 / 16000 / 10)
+ else:
+ print("source file not found: {}".format(line))
+ continue
else:
context_len = len(line.split()) if " " in line else len(line)
res[key] = {data_type: line, f"{data_type}_len": context_len}
diff --git a/funasr/datasets/kws_datasets/__init__.py b/funasr/datasets/kws_datasets/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/kws_datasets/__init__.py
diff --git a/funasr/datasets/kws_datasets/datasets.py b/funasr/datasets/kws_datasets/datasets.py
new file mode 100644
index 0000000..4679295
--- /dev/null
+++ b/funasr/datasets/kws_datasets/datasets.py
@@ -0,0 +1,132 @@
+import torch
+import random
+
+
+from funasr.register import tables
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
+
+
+@tables.register("dataset_classes", "KwsMTDataset")
+class KwsMTDataset(torch.utils.data.Dataset):
+ """
+ KwsMTDataset, support multi tokenizers
+ """
+ def __init__(self,
+ path,
+ index_ds: str = None,
+ frontend=None,
+ tokenizer=None,
+ is_training: bool = True,
+ int_pad_value: int = -1,
+ float_pad_value: float = 0.0,
+ **kwargs,
+ ):
+ super().__init__()
+ index_ds_class = tables.index_ds_classes.get(index_ds)
+ self.index_ds = index_ds_class(path, **kwargs)
+
+ self.preprocessor_speech = None
+ self.preprocessor_text = None
+
+ if is_training:
+ preprocessor_speech = kwargs.get("preprocessor_speech", None)
+ if preprocessor_speech:
+ preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+ preprocessor_speech = preprocessor_speech_class(
+ **kwargs.get("preprocessor_speech_conf")
+ )
+ self.preprocessor_speech = preprocessor_speech
+
+ preprocessor_text = kwargs.get("preprocessor_text", None)
+ if preprocessor_text:
+ preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+ preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+ self.preprocessor_text = preprocessor_text
+
+ self.frontend = frontend
+ self.fs = 16000 if frontend is None else frontend.fs
+ self.data_type = "sound"
+ print(tokenizer)
+ self.tokenizer = tokenizer
+
+ self.int_pad_value = int_pad_value
+ self.float_pad_value = float_pad_value
+
+ def get_source_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_source_len(item)
+
+ def get_target_len(self, index):
+ item = self.index_ds[index]
+ return self.index_ds.get_target_len(item)
+
+ def __len__(self):
+ return len(self.index_ds)
+
+ def __getitem__(self, index):
+ item = self.index_ds[index]
+ source = item["source"]
+ data_src = load_audio_text_image_video(source, fs=self.fs)
+ if self.preprocessor_speech:
+ data_src = self.preprocessor_speech(data_src, fs=self.fs)
+ speech, speech_lengths = extract_fbank(
+ data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
+ ) # speech: [b, T, d]
+
+ target = item["target"]
+ if self.preprocessor_text:
+ target = self.preprocessor_text(target)
+
+ if self.tokenizer[0]:
+ ids = self.tokenizer[0].encode(target)
+ text = torch.tensor(ids, dtype=torch.int64)
+ # print("target: ", target, ", ids: ", str(ids))
+ else:
+ ids = target
+ text = ids
+
+ if self.tokenizer[1]:
+ ids2 = self.tokenizer[1].encode(target)
+ text2 = torch.tensor(ids2, dtype=torch.int64)
+ # print("target: ", target, ", ids2: ", str(ids2))
+ else:
+ ids2 = target
+ text2 = ids2
+
+ ids_lengths = len(ids)
+ text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
+
+ ids2_lengths = len(ids2)
+ text2_lengths = torch.tensor([ids2_lengths], dtype=torch.int32)
+
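+        # one label sequence per tokenizer (e.g. a char-level and a lexicon/phone-level target),
+        # so the multi-task model can compute a separate CTC loss on each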
+ return {"speech": speech[0, :, :],
+ "speech_lengths": speech_lengths,
+ "text": text,
+ "text_lengths": text_lengths,
+ "text2": text2,
+ "text2_lengths": text2_lengths,
+ }
+
+
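+    # pad every tensor field in the batch: integer tensors (labels) with int_pad_value (-1,
+    # the ignore index) and float tensors (features) with float_pad_value (0.0)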
+ def collator(self, samples: list=None):
+ outputs = {}
+ for sample in samples:
+ for key in sample.keys():
+ if key not in outputs:
+ outputs[key] = []
+ outputs[key].append(sample[key])
+
+ for key, data_list in outputs.items():
+ if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype in (torch.int64, torch.int32):
+                    pad_value = self.int_pad_value
+ else:
+ pad_value = self.float_pad_value
+
+ outputs[key] = torch.nn.utils.rnn.pad_sequence(
+ data_list, batch_first=True, padding_value=pad_value
+ )
+ return outputs
diff --git a/funasr/download/download_model_from_hub.py b/funasr/download/download_model_from_hub.py
index df4f33d..8e51144 100644
--- a/funasr/download/download_model_from_hub.py
+++ b/funasr/download/download_model_from_hub.py
@@ -51,7 +51,8 @@
cfg = {}
if "file_path_metas" in conf_json:
add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
- cfg.update(kwargs)
+ # cfg.update(kwargs)
+ cfg = OmegaConf.merge(cfg, kwargs)
if "config" in cfg:
config = OmegaConf.load(cfg["config"])
kwargs = OmegaConf.merge(config, cfg)
@@ -159,15 +160,41 @@
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):
if isinstance(file_path_metas, dict):
+ if isinstance(cfg, list):
+ cfg.append({})
for k, v in file_path_metas.items():
if isinstance(v, str):
p = os.path.join(model_or_path, v)
if os.path.exists(p):
- cfg[k] = p
+ if isinstance(cfg, dict):
+ cfg[k] = p
+ elif isinstance(cfg, list):
+ # if len(cfg) == 0:
+ # cfg.append({})
+ cfg[-1][k] = p
+
elif isinstance(v, dict):
- if k not in cfg:
- cfg[k] = {}
- add_file_root_path(model_or_path, v, cfg[k])
+ if isinstance(cfg, dict):
+ if k not in cfg:
+ cfg[k] = {}
+ add_file_root_path(model_or_path, v, cfg[k])
+ # elif isinstance(cfg, list):
+ # cfg.append({})
+ # add_file_root_path(model_or_path, v, cfg)
+ elif isinstance(v, (list, tuple)):
+ for i, vv in enumerate(v):
+ if k not in cfg:
+ cfg[k] = []
+ if isinstance(vv, str):
+                    p = os.path.join(model_or_path, vv)
+ file_path_metas[i] = p
+ if os.path.exists(p):
+ if isinstance(cfg[k], dict):
+ cfg[k] = p
+ elif isinstance(cfg[k], list):
+ cfg[k].append(p)
+ elif isinstance(vv, dict):
+ add_file_root_path(model_or_path, vv, cfg[k])
return cfg
diff --git a/funasr/models/ctc/ctc.py b/funasr/models/ctc/ctc.py
index bdfb3a6..8eb64d1 100644
--- a/funasr/models/ctc/ctc.py
+++ b/funasr/models/ctc/ctc.py
@@ -23,11 +23,17 @@
ctc_type: str = "builtin",
reduce: bool = True,
ignore_nan_grad: bool = True,
+ extra_linear: bool = True,
):
super().__init__()
eprojs = encoder_output_size
self.dropout_rate = dropout_rate
- self.ctc_lo = torch.nn.Linear(eprojs, odim)
+
+ if extra_linear:
+ self.ctc_lo = torch.nn.Linear(eprojs, odim)
+ else:
+ self.ctc_lo = None
+
self.ctc_type = ctc_type
self.ignore_nan_grad = ignore_nan_grad
@@ -130,7 +136,10 @@
ys_lens: batch of lengths of character sequence (B)
"""
# hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
- ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+ if self.ctc_lo is not None:
+ ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+ else:
+ ys_hat = hs_pad
if self.ctc_type == "gtnctc":
# gtn expects list form for ys
@@ -141,6 +150,7 @@
# (B, L) -> (BxL,)
ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])
+ hlens = hlens.to(hs_pad.device)
loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
device=hs_pad.device, dtype=hs_pad.dtype
)
@@ -155,7 +165,10 @@
Returns:
torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
"""
- return F.softmax(self.ctc_lo(hs_pad), dim=2)
+ if self.ctc_lo is not None:
+ return F.softmax(self.ctc_lo(hs_pad), dim=2)
+ else:
+ return F.softmax(hs_pad, dim=2)
def log_softmax(self, hs_pad):
"""log_softmax of frame activations
@@ -165,7 +178,10 @@
Returns:
torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
"""
- return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
+ if self.ctc_lo is not None:
+ return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
+ else:
+ return F.log_softmax(hs_pad, dim=2)
def argmax(self, hs_pad):
"""argmax of frame activations
@@ -175,4 +191,7 @@
Returns:
torch.Tensor: argmax applied 2d tensor (B, Tmax)
"""
- return torch.argmax(self.ctc_lo(hs_pad), dim=2)
+ if self.ctc_lo is not None:
+ return torch.argmax(self.ctc_lo(hs_pad), dim=2)
+ else:
+ return torch.argmax(hs_pad, dim=2)
diff --git a/funasr/models/fsmn_kws/__init__.py b/funasr/models/fsmn_kws/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/fsmn_kws/__init__.py
diff --git a/funasr/models/fsmn_kws/encoder.py b/funasr/models/fsmn_kws/encoder.py
new file mode 100755
index 0000000..2c31687
--- /dev/null
+++ b/funasr/models/fsmn_kws/encoder.py
@@ -0,0 +1,534 @@
+from typing import Tuple, Dict
+import copy
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from funasr.register import tables
+
+
+def toKaldiMatrix(np_mat):
+ np.set_printoptions(threshold=np.inf, linewidth=np.nan)
+ out_str = str(np_mat)
+ out_str = out_str.replace('[', '')
+ out_str = out_str.replace(']', '')
+ return '[ %s ]\n' % out_str
+
+
+class LinearTransform(nn.Module):
+
+ def __init__(self, input_dim, output_dim):
+ super(LinearTransform, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.linear = nn.Linear(input_dim, output_dim, bias=False)
+
+ def forward(self, input):
+ output = self.linear(input)
+
+ return output
+
+ def to_kaldi_net(self):
+ re_str = ''
+ re_str += '<LinearTransform> %d %d\n' % (self.output_dim,
+ self.input_dim)
+ re_str += '<LearnRateCoef> 1\n'
+
+ linear_weights = self.state_dict()['linear.weight']
+ x = linear_weights.squeeze().numpy()
+ re_str += toKaldiMatrix(x)
+
+ return re_str
+
+ def to_pytorch_net(self, fread):
+ linear_line = fread.readline()
+ linear_split = linear_line.strip().split()
+ assert len(linear_split) == 3
+ assert linear_split[0] == '<LinearTransform>'
+ self.output_dim = int(linear_split[1])
+ self.input_dim = int(linear_split[2])
+
+ learn_rate_line = fread.readline()
+ assert learn_rate_line.find('LearnRateCoef') != -1
+
+ self.linear.reset_parameters()
+
+ linear_weights = self.state_dict()['linear.weight']
+ #print(linear_weights.shape)
+ new_weights = torch.zeros((self.output_dim, self.input_dim),
+ dtype=torch.float32)
+ for i in range(self.output_dim):
+ line = fread.readline()
+ splits = line.strip().strip('\[\]').strip().split()
+ assert len(splits) == self.input_dim
+ cols = torch.tensor([float(item) for item in splits],
+ dtype=torch.float32)
+ new_weights[i, :] = cols
+
+ self.linear.weight.data = new_weights
+
+
+class AffineTransform(nn.Module):
+
+ def __init__(self, input_dim, output_dim):
+ super(AffineTransform, self).__init__()
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.linear = nn.Linear(input_dim, output_dim)
+
+ def forward(self, input):
+ output = self.linear(input)
+
+ return output
+
+ def to_kaldi_net(self):
+ re_str = ''
+ re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
+ self.input_dim)
+ re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
+
+ linear_weights = self.state_dict()['linear.weight']
+ x = linear_weights.squeeze().numpy()
+ re_str += toKaldiMatrix(x)
+
+ linear_bias = self.state_dict()['linear.bias']
+ x = linear_bias.squeeze().numpy()
+ re_str += toKaldiMatrix(x)
+
+ return re_str
+
+ def to_pytorch_net(self, fread):
+ affine_line = fread.readline()
+ affine_split = affine_line.strip().split()
+ assert len(affine_split) == 3
+ assert affine_split[0] == '<AffineTransform>'
+ self.output_dim = int(affine_split[1])
+ self.input_dim = int(affine_split[2])
+ print('AffineTransform output/input dim: %d %d' %
+ (self.output_dim, self.input_dim))
+
+ learn_rate_line = fread.readline()
+ assert learn_rate_line.find('LearnRateCoef') != -1
+
+ #linear_weights = self.state_dict()['linear.weight']
+ #print(linear_weights.shape)
+ self.linear.reset_parameters()
+
+ new_weights = torch.zeros((self.output_dim, self.input_dim),
+ dtype=torch.float32)
+ for i in range(self.output_dim):
+ line = fread.readline()
+ splits = line.strip().strip('\[\]').strip().split()
+ assert len(splits) == self.input_dim
+ cols = torch.tensor([float(item) for item in splits],
+ dtype=torch.float32)
+ new_weights[i, :] = cols
+
+ self.linear.weight.data = new_weights
+
+ linear_bias = self.state_dict()['linear.bias']
+ #print(linear_bias.shape)
+ bias_line = fread.readline()
+ splits = bias_line.strip().strip('\[\]').strip().split()
+ assert len(splits) == self.output_dim
+ new_bias = torch.tensor([float(item) for item in splits],
+ dtype=torch.float32)
+
+ self.linear.bias.data = new_bias
+
+
+class RectifiedLinear(nn.Module):
+
+ def __init__(self, input_dim, output_dim):
+ super(RectifiedLinear, self).__init__()
+ self.dim = input_dim
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(0.1)
+
+ def forward(self, input):
+ out = self.relu(input)
+ return out
+
+ def to_kaldi_net(self):
+ re_str = ''
+ re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
+ return re_str
+
+ def to_pytorch_net(self, fread):
+ line = fread.readline()
+ splits = line.strip().split()
+ assert len(splits) == 3
+ assert splits[0] == '<RectifiedLinear>'
+ assert int(splits[1]) == int(splits[2])
+ assert int(splits[1]) == self.dim
+ self.dim = int(splits[1])
+
+
+class FSMNBlock(nn.Module):
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ lorder=None,
+ rorder=None,
+ lstride=1,
+ rstride=1,
+ ):
+ super(FSMNBlock, self).__init__()
+
+ self.dim = input_dim
+
+ if lorder is None:
+ return
+
+ self.lorder = lorder
+ self.rorder = rorder
+ self.lstride = lstride
+ self.rstride = rstride
+
+ self.conv_left = nn.Conv2d(
+ self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False
+ )
+
+ if self.rorder > 0:
+ self.conv_right = nn.Conv2d(
+ self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False
+ )
+ else:
+ self.conv_right = None
+
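+    # In streaming mode `cache` carries the previous chunk's left-context frames: they are
+    # prepended before the left convolution and the last (lorder - 1) * lstride frames are
+    # returned as the cache for the next chunk; with cache=None the input is zero-padded instead.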
+ def forward(self, input: torch.Tensor, cache: torch.Tensor = None):
+ x = torch.unsqueeze(input, 1)
+ x_per = x.permute(0, 3, 2, 1) # B D T C
+
+ if cache is not None:
+ cache = cache.to(x_per.device)
+ y_left = torch.cat((cache, x_per), dim=2)
+ cache = y_left[:, :, -(self.lorder - 1) * self.lstride :, :]
+ else:
+ y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
+
+ y_left = self.conv_left(y_left)
+ out = x_per + y_left
+
+ if self.conv_right is not None:
+            # right (future) context: right-pad rorder*rstride frames and shift forward by rstride
+            # so the right convolution covers only future frames
+ y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
+ y_right = y_right[:, :, self.rstride :, :]
+ y_right = self.conv_right(y_right)
+ out += y_right
+
+ out_per = out.permute(0, 3, 2, 1)
+ output = out_per.squeeze(1)
+
+ return output, cache
+
+ def to_kaldi_net(self):
+ re_str = ''
+ re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
+ re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % (
+ 1, self.lorder, self.rorder, self.lstride, self.rstride)
+
+ #print(self.conv_left.weight,self.conv_right.weight)
+ lfiters = self.state_dict()['conv_left.weight']
+ x = np.flipud(lfiters.squeeze().numpy().T)
+ re_str += toKaldiMatrix(x)
+
+ if self.conv_right is not None:
+ rfiters = self.state_dict()['conv_right.weight']
+ x = (rfiters.squeeze().numpy().T)
+ re_str += toKaldiMatrix(x)
+
+ return re_str
+
+ def to_pytorch_net(self, fread):
+ fsmn_line = fread.readline()
+ fsmn_split = fsmn_line.strip().split()
+ assert len(fsmn_split) == 3
+ assert fsmn_split[0] == '<Fsmn>'
+ self.dim = int(fsmn_split[1])
+
+ params_line = fread.readline()
+ params_split = params_line.strip().strip('\[\]').strip().split()
+ assert len(params_split) == 12
+ assert params_split[0] == '<LearnRateCoef>'
+ assert params_split[2] == '<LOrder>'
+ self.lorder = int(params_split[3])
+ assert params_split[4] == '<ROrder>'
+ self.rorder = int(params_split[5])
+ assert params_split[6] == '<LStride>'
+ self.lstride = int(params_split[7])
+ assert params_split[8] == '<RStride>'
+ self.rstride = int(params_split[9])
+ assert params_split[10] == '<MaxNorm>'
+
+ #lfilters = self.state_dict()['conv_left.weight']
+ #print(lfilters.shape)
+ print('read conv_left weight')
+ new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1),
+ dtype=torch.float32)
+ for i in range(self.lorder):
+ print('read conv_left weight -- %d' % i)
+ line = fread.readline()
+ splits = line.strip().strip('\[\]').strip().split()
+ assert len(splits) == self.dim
+ cols = torch.tensor([float(item) for item in splits],
+ dtype=torch.float32)
+ new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols
+
+ new_lfilters = torch.transpose(new_lfilters, 0, 2)
+ #print(new_lfilters.shape)
+
+ self.conv_left.reset_parameters()
+ self.conv_left.weight.data = new_lfilters
+ #print(self.conv_left.weight.shape)
+
+ if self.rorder > 0:
+ #rfilters = self.state_dict()['conv_right.weight']
+ #print(rfilters.shape)
+ print('read conv_right weight')
+ new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1),
+ dtype=torch.float32)
+ line = fread.readline()
+ for i in range(self.rorder):
+ print('read conv_right weight -- %d' % i)
+ line = fread.readline()
+ splits = line.strip().strip('\[\]').strip().split()
+ assert len(splits) == self.dim
+ cols = torch.tensor([float(item) for item in splits],
+ dtype=torch.float32)
+ new_rfilters[i, 0, :, 0] = cols
+
+ new_rfilters = torch.transpose(new_rfilters, 0, 2)
+ #print(new_rfilters.shape)
+ self.conv_right.reset_parameters()
+ self.conv_right.weight.data = new_rfilters
+ #print(self.conv_right.weight.shape)
+
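+# One DFSMN layer: linear bottleneck projection -> FSMN memory block (with residual add)
+# -> affine back to linear_dim -> ReLU; stack_layer keys this layer's entry in the streaming cache.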
+class BasicBlock(nn.Module):
+ def __init__(
+ self,
+ linear_dim: int,
+ proj_dim: int,
+ lorder: int,
+ rorder: int,
+ lstride: int,
+ rstride: int,
+ stack_layer: int,
+ ):
+ super(BasicBlock, self).__init__()
+ self.lorder = lorder
+ self.rorder = rorder
+ self.lstride = lstride
+ self.rstride = rstride
+ self.stack_layer = stack_layer
+ self.linear = LinearTransform(linear_dim, proj_dim)
+ self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
+ self.affine = AffineTransform(proj_dim, linear_dim)
+ self.relu = RectifiedLinear(linear_dim, linear_dim)
+
+ def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor] = None):
+ x1 = self.linear(input) # B T D
+
+ if cache is not None:
+ cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+ if cache_layer_name not in cache:
+ cache[cache_layer_name] = torch.zeros(
+ x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1
+ )
+ x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
+ else:
+ x2, _ = self.fsmn_block(x1, None)
+ x3 = self.affine(x2)
+ x4 = self.relu(x3)
+ return x4
+
+ def to_kaldi_net(self):
+ re_str = ''
+ re_str += self.linear.to_kaldi_net()
+ re_str += self.fsmn_block.to_kaldi_net()
+ re_str += self.affine.to_kaldi_net()
+ re_str += self.relu.to_kaldi_net()
+
+ return re_str
+
+ def to_pytorch_net(self, fread):
+ self.linear.to_pytorch_net(fread)
+ self.fsmn_block.to_pytorch_net(fread)
+ self.affine.to_pytorch_net(fread)
+ self.relu.to_pytorch_net(fread)
+
+
+class BasicBlock_export(nn.Module):
+ def __init__(
+ self,
+ model,
+ ):
+ super(BasicBlock_export, self).__init__()
+ self.linear = model.linear
+ self.fsmn_block = model.fsmn_block
+ self.affine = model.affine
+ self.relu = model.relu
+
+ def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
+ x = self.linear(input) # B T D
+ # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+ # if cache_layer_name not in in_cache:
+ # in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+ x, out_cache = self.fsmn_block(x, in_cache)
+ x = self.affine(x)
+ x = self.relu(x)
+ return x, out_cache
+
+
+class FsmnStack(nn.Sequential):
+ def __init__(self, *args):
+ super(FsmnStack, self).__init__(*args)
+
+ def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
+ x = input
+ for module in self._modules.values():
+ x = module(x, cache)
+ return x
+
+ def to_kaldi_net(self):
+ re_str = ''
+ for module in self._modules.values():
+ re_str += module.to_kaldi_net()
+
+ return re_str
+
+ def to_pytorch_net(self, fread):
+ for module in self._modules.values():
+ module.to_pytorch_net(fread)
+
+
+"""
+FSMN net for keyword spotting
+input_dim: input dimension
+linear_dim: fsmn input dimension
+proj_dim: fsmn projection dimension
+lorder: fsmn left order
+rorder: fsmn right order
+num_syn: output dimension
+fsmn_layers: no. of sequential fsmn layers
+"""
+
+
+@tables.register("encoder_classes", "FSMNConvert")
+class FSMNConvert(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ input_affine_dim: int,
+ fsmn_layers: int,
+ linear_dim: int,
+ proj_dim: int,
+ lorder: int,
+ rorder: int,
+ lstride: int,
+ rstride: int,
+ output_affine_dim: int,
+ output_dim: int,
+ use_softmax: bool = True,
+ ):
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.input_affine_dim = input_affine_dim
+ self.fsmn_layers = fsmn_layers
+ self.linear_dim = linear_dim
+ self.proj_dim = proj_dim
+ self.output_affine_dim = output_affine_dim
+ self.output_dim = output_dim
+
+ self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+ self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+ self.relu = RectifiedLinear(linear_dim, linear_dim)
+ self.fsmn = FsmnStack(
+ *[
+ BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i)
+ for i in range(fsmn_layers)
+ ]
+ )
+ self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+ self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+
+ self.use_softmax = use_softmax
+ if self.use_softmax:
+ self.softmax = nn.Softmax(dim=-1)
+
+ def output_size(self) -> int:
+ return self.output_dim
+
+ def forward(
+ self,
+ input: torch.Tensor,
+ cache: Dict[str, torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+ """
+ Args:
+ input (torch.Tensor): Input tensor (B, T, D)
+ cache: when cache is not None, the forward is in streaming. The type of cache is a dict, egs,
+ {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
+ """
+
+ x1 = self.in_linear1(input)
+ x2 = self.in_linear2(x1)
+ x3 = self.relu(x2)
+ x4 = self.fsmn(x3, cache) # self.cache will update automatically in self.fsmn
+ x5 = self.out_linear1(x4)
+ x6 = self.out_linear2(x5)
+
+ if self.use_softmax:
+ x7 = self.softmax(x6)
+ return x7
+
+ return x6
+
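+    # to_kaldi_net() writes the whole net in Kaldi nnet1 text format; to_pytorch_net()
+    # loads weights back from such a file (the layer order must match exactly)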
+ def to_kaldi_net(self):
+ re_str = ''
+ re_str += '<Nnet>\n'
+ re_str += self.in_linear1.to_kaldi_net()
+ re_str += self.in_linear2.to_kaldi_net()
+ re_str += self.relu.to_kaldi_net()
+
+ for fsmn in self.fsmn:
+ re_str += fsmn.to_kaldi_net()
+
+ re_str += self.out_linear1.to_kaldi_net()
+ re_str += self.out_linear2.to_kaldi_net()
+ re_str += '<Softmax> %d %d\n' % (self.output_dim, self.output_dim)
+ re_str += '</Nnet>\n'
+
+ return re_str
+
+ def to_pytorch_net(self, kaldi_file):
+ with open(kaldi_file, 'r', encoding='utf8') as fread:
+ nnet_start_line = fread.readline()
+ assert nnet_start_line.strip() == '<Nnet>'
+
+ self.in_linear1.to_pytorch_net(fread)
+ self.in_linear2.to_pytorch_net(fread)
+ self.relu.to_pytorch_net(fread)
+
+ for fsmn in self.fsmn:
+ fsmn.to_pytorch_net(fread)
+
+ self.out_linear1.to_pytorch_net(fread)
+ self.out_linear2.to_pytorch_net(fread)
+
+ softmax_line = fread.readline()
+ softmax_split = softmax_line.strip().split()
+ assert softmax_split[0].strip() == '<Softmax>'
+ assert int(softmax_split[1]) == self.output_dim
+ assert int(softmax_split[2]) == self.output_dim
+
+ nnet_end_line = fread.readline()
+ assert nnet_end_line.strip() == '</Nnet>'
+ fread.close()
diff --git a/funasr/models/fsmn_kws/model.py b/funasr/models/fsmn_kws/model.py
new file mode 100644
index 0000000..066730f
--- /dev/null
+++ b/funasr/models/fsmn_kws/model.py
@@ -0,0 +1,285 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import time
+import torch
+import logging
+from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
+
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+
+@tables.register("model_classes", "FsmnKWS")
+class FsmnKWS(torch.nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+
+ def __init__(
+ self,
+ specaug: Optional[str] = None,
+ specaug_conf: Optional[Dict] = None,
+ normalize: str = None,
+ normalize_conf: Optional[Dict] = None,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ ctc: str = None,
+ ctc_conf: Optional[Dict] = None,
+ ctc_weight: float = 1.0,
+ input_size: int = 360,
+ vocab_size: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ **kwargs,
+ ):
+ super().__init__()
+
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get(specaug)
+ specaug = specaug_class(**specaug_conf)
+
+ if normalize is not None:
+ normalize_class = tables.normalize_classes.get(normalize)
+ normalize = normalize_class(**normalize_conf)
+
+ encoder_class = tables.encoder_classes.get(encoder)
+ encoder = encoder_class(**encoder_conf)
+ encoder_output_size = encoder.output_size()
+
+ if ctc_conf is None:
+ ctc_conf = {}
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+ )
+
+ self.blank_id = blank_id
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.ctc_weight = ctc_weight
+
+ # self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ self.encoder = encoder
+ self.ctc = ctc
+
+ self.error_calculator = None
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+ batch_size = speech.shape[0]
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats = dict()
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+ loss = self.ctc_weight * loss_ctc
+
+ stats["cer"] = cer_ctc
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ encoder_out = self.encoder(speech)
+ encoder_out_lens = speech_lengths
+
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ return encoder_out, encoder_out_lens
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+
+ return loss_ctc, cer_ctc
+
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list=None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+ keywords = kwargs.get("keywords")
+ from funasr.utils.kws_utils import KwsCtcPrefixDecoder
+ self.kws_decoder = KwsCtcPrefixDecoder(
+ ctc=self.ctc,
+ keywords=keywords,
+ token_list=tokenizer.token_list,
+ seg_dict=tokenizer.seg_dict,
+ )
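+        # the decoder searches only the requested keywords over the CTC posteriors,
+        # so no full ASR decoding pass is needed for detection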
+
+ meta_data = {}
+ if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is not None:
+ speech_lengths = speech_lengths.squeeze(-1)
+ else:
+ speech_lengths = speech.shape[1]
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), data_type=kwargs.get("data_type", "sound"), tokenizer=tokenizer)
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ results = []
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+
+ for i in range(encoder_out.size(0)):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ detect_result = self.kws_decoder.decode(x)
+ is_deted, det_keyword, det_score = detect_result[0], detect_result[1], detect_result[2]
+
+ if is_deted:
+ self.writer["detect"][key[i]] = "detected " + det_keyword + " " + str(det_score)
+ det_info = "detected " + det_keyword + " " + str(det_score)
+ else:
+ self.writer["detect"][key[i]] = "rejected"
+ det_info = "rejected"
+
+ result_i = {"key": key[i], "text": det_info}
+ results.append(result_i)
+
+ return results, meta_data
+
+
+@tables.register("model_classes", "FsmnKWSConvert")
+class FsmnKWSConvert(torch.nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+
+ def __init__(
+ self,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ ctc: str = None,
+ ctc_conf: Optional[Dict] = None,
+ ctc_weight: float = 1.0,
+ input_size: int = 360,
+ vocab_size: int = -1,
+ blank_id: int = 0,
+ **kwargs,
+ ):
+ super().__init__()
+
+ encoder_class = tables.encoder_classes.get(encoder)
+ encoder = encoder_class(**encoder_conf)
+ encoder_output_size = encoder.output_size()
+
+ if ctc_conf is None:
+ ctc_conf = {}
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+ )
+
+ self.blank_id = blank_id
+ self.vocab_size = vocab_size
+ self.ctc_weight = ctc_weight
+ self.encoder = encoder
+ self.ctc = ctc
+
+ self.error_calculator = None
+
+ def to_kaldi_net(self):
+ return self.encoder.to_kaldi_net()
+
+
+ def to_pytorch_net(self, kaldi_file):
+ return self.encoder.to_pytorch_net(kaldi_file)
diff --git a/funasr/models/fsmn_kws_mt/__init__.py b/funasr/models/fsmn_kws_mt/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/fsmn_kws_mt/__init__.py
diff --git a/funasr/models/fsmn_kws_mt/encoder.py b/funasr/models/fsmn_kws_mt/encoder.py
new file mode 100755
index 0000000..4b36d83
--- /dev/null
+++ b/funasr/models/fsmn_kws_mt/encoder.py
@@ -0,0 +1,213 @@
+from typing import Tuple, Dict
+import copy
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from funasr.models.fsmn_kws.encoder import (toKaldiMatrix, LinearTransform, AffineTransform, RectifiedLinear, FSMNBlock, FsmnStack, BasicBlock)
+
+
+from funasr.register import tables
+
+
+'''
+FSMN net for keyword spotting
+input_dim: input dimension
+linear_dim: fsmn input dimension
+proj_dim: fsmn projection dimension
+lorder: fsmn left order
+rorder: fsmn right order
+num_syn: output dimension
+fsmn_layers: no. of sequential fsmn layers
+'''
+
+@tables.register("encoder_classes", "FSMNMT")
+class FSMNMT(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ input_affine_dim: int,
+ fsmn_layers: int,
+ linear_dim: int,
+ proj_dim: int,
+ lorder: int,
+ rorder: int,
+ lstride: int,
+ rstride: int,
+ output_affine_dim: int,
+ output_dim: int,
+ output_dim2: int,
+ use_softmax: bool = True,
+ ):
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.input_affine_dim = input_affine_dim
+ self.fsmn_layers = fsmn_layers
+ self.linear_dim = linear_dim
+ self.proj_dim = proj_dim
+ self.output_affine_dim = output_affine_dim
+ self.output_dim = output_dim
+ self.output_dim2 = output_dim2
+
+ self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+ self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+ self.relu = RectifiedLinear(linear_dim, linear_dim)
+ self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+ range(fsmn_layers)])
+ self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+ self.out_linear1_2 = AffineTransform(linear_dim, output_affine_dim)
+ self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+ self.out_linear2_2 = AffineTransform(output_affine_dim, output_dim2)
+
+ self.use_softmax = use_softmax
+ if self.use_softmax:
+ self.softmax = nn.Softmax(dim=-1)
+
+ def output_size(self) -> int:
+ return self.output_dim
+
+ def output_size2(self) -> int:
+ return self.output_dim2
+
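+    # two output heads share the FSMN trunk: out_linear1/out_linear2 produce output_dim
+    # posteriors and out_linear1_2/out_linear2_2 produce output_dim2 posteriors;
+    # forward() returns both, one per training target/tokenizer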
+ def forward(
+ self,
+ input: torch.Tensor,
+ cache: Dict[str, torch.Tensor] = None
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+ """
+ Args:
+ input (torch.Tensor): Input tensor (B, T, D)
+ cache: when cache is not None, the forward pass runs in streaming mode. cache is a dict, e.g.
+ {'cache_layer_1': torch.Tensor(B, T1, D)}, where T1 equals self.lorder; pass {} for the first frame
+ """
+
+ x1 = self.in_linear1(input)
+ x2 = self.in_linear2(x1)
+ x3 = self.relu(x2)
+ x4 = self.fsmn(x3, cache) # self.cache will update automatically in self.fsmn
+ x5 = self.out_linear1(x4)
+ x6 = self.out_linear2(x5)
+
+ x5_2 = self.out_linear1_2(x4)
+ x6_2 = self.out_linear2_2(x5_2)
+
+ if self.use_softmax:
+ x7 = self.softmax(x6)
+ x7_2 = self.softmax(x6_2)
+ return x7, x7_2
+
+ return x6, x6_2
+
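+ # Editorial usage sketch (not part of the original patch; all dimension values are
+ # illustrative): the two-branch forward returns one output per head, e.g.
+ #   encoder = FSMNMT(input_dim=40, input_affine_dim=140, fsmn_layers=4,
+ #                    linear_dim=250, proj_dim=128, lorder=10, rorder=2,
+ #                    lstride=1, rstride=1, output_affine_dim=140,
+ #                    output_dim=2599, output_dim2=4)
+ #   feats = torch.randn(1, 100, 40)    # (B, T, D)
+ #   post1, post2 = encoder(feats)      # (B, T, 2599) and (B, T, 4) softmax posteriors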
+
+@tables.register("encoder_classes", "FSMNMTConvert")
+class FSMNMTConvert(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ input_affine_dim: int,
+ fsmn_layers: int,
+ linear_dim: int,
+ proj_dim: int,
+ lorder: int,
+ rorder: int,
+ lstride: int,
+ rstride: int,
+ output_affine_dim: int,
+ output_dim: int,
+ output_dim2: int,
+ use_softmax: bool = True,
+ ):
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.input_affine_dim = input_affine_dim
+ self.fsmn_layers = fsmn_layers
+ self.linear_dim = linear_dim
+ self.proj_dim = proj_dim
+ self.output_affine_dim = output_affine_dim
+ self.output_dim = output_dim
+ self.output_dim2 = output_dim2
+
+ self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+ self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+ self.relu = RectifiedLinear(linear_dim, linear_dim)
+ self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+ range(fsmn_layers)])
+ self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+ self.out_linear1_2 = AffineTransform(linear_dim, output_affine_dim)
+ self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+ self.out_linear2_2 = AffineTransform(output_affine_dim, output_dim2)
+
+ self.use_softmax = use_softmax
+ if self.use_softmax:
+ self.softmax = nn.Softmax(dim=-1)
+
+ def output_size(self) -> int:
+ return self.output_dim
+
+ def output_size2(self) -> int:
+ return self.output_dim2
+
+ def to_kaldi_net(self):
+ re_str = ''
+ re_str += '<Nnet>\n'
+ re_str += self.in_linear1.to_kaldi_net()
+ re_str += self.in_linear2.to_kaldi_net()
+ re_str += self.relu.to_kaldi_net()
+
+ for fsmn in self.fsmn:
+ re_str += fsmn.to_kaldi_net()
+
+ re_str += self.out_linear1.to_kaldi_net()
+ re_str += self.out_linear2.to_kaldi_net()
+ re_str += '<Softmax> %d %d\n' % (self.output_dim, self.output_dim)
+ re_str += '</Nnet>\n'
+
+ return re_str
+
+ def to_kaldi_net2(self):
+ re_str = ''
+ re_str += '<Nnet>\n'
+ re_str += self.in_linear1.to_kaldi_net()
+ re_str += self.in_linear2.to_kaldi_net()
+ re_str += self.relu.to_kaldi_net()
+
+ for fsmn in self.fsmn:
+ re_str += fsmn.to_kaldi_net()
+
+ re_str += self.out_linear1_2.to_kaldi_net()
+ re_str += self.out_linear2_2.to_kaldi_net()
+ re_str += '<Softmax> %d %d\n' % (self.output_dim2, self.output_dim2)
+ re_str += '</Nnet>\n'
+
+ return re_str
+
+ def to_pytorch_net(self, kaldi_file):
+ with open(kaldi_file, 'r', encoding='utf8') as fread:
+ nnet_start_line = fread.readline()
+ assert nnet_start_line.strip() == '<Nnet>'
+
+ self.in_linear1.to_pytorch_net(fread)
+ self.in_linear2.to_pytorch_net(fread)
+ self.relu.to_pytorch_net(fread)
+
+ for fsmn in self.fsmn:
+ fsmn.to_pytorch_net(fread)
+
+ self.out_linear1.to_pytorch_net(fread)
+ self.out_linear2.to_pytorch_net(fread)
+
+ softmax_line = fread.readline()
+ softmax_split = softmax_line.strip().split()
+ assert softmax_split[0].strip() == '<Softmax>'
+ assert int(softmax_split[1]) == self.output_dim
+ assert int(softmax_split[2]) == self.output_dim
+
+ nnet_end_line = fread.readline()
+ assert nnet_end_line.strip() == '</Nnet>'
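+
+ # Editorial note (illustration, not emitted verbatim by this patch): to_kaldi_net()
+ # serializes the first output branch as a Kaldi-style text nnet of the form
+ #   <Nnet>
+ #     ... input affine / relu / fsmn components ...
+ #     <Softmax> <output_dim> <output_dim>
+ #   </Nnet>
+ # while to_kaldi_net2() writes the same structure for the second branch, ending in
+ #   <Softmax> <output_dim2> <output_dim2>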
diff --git a/funasr/models/fsmn_kws_mt/model.py b/funasr/models/fsmn_kws_mt/model.py
new file mode 100644
index 0000000..c4645bb
--- /dev/null
+++ b/funasr/models/fsmn_kws_mt/model.py
@@ -0,0 +1,353 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import time
+import torch
+import logging
+from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
+
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+
+@tables.register("model_classes", "FsmnKWSMT")
+class FsmnKWSMT(torch.nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+
+ def __init__(
+ self,
+ specaug: Optional[str] = None,
+ specaug_conf: Optional[Dict] = None,
+ normalize: str = None,
+ normalize_conf: Optional[Dict] = None,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ ctc_conf: Optional[Dict] = None,
+ input_size: int = 360,
+ vocab_size: int = -1,
+ vocab_size2: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ **kwargs,
+ ):
+ super().__init__()
+
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get(specaug)
+ specaug = specaug_class(**specaug_conf)
+
+ if normalize is not None:
+ normalize_class = tables.normalize_classes.get(normalize)
+ normalize = normalize_class(**normalize_conf)
+
+ encoder_class = tables.encoder_classes.get(encoder)
+ encoder = encoder_class(**encoder_conf)
+ encoder_output_size = encoder.output_size()
+ encoder_output_size2 = encoder.output_size2()
+
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+ )
+ ctc2 = CTC(
+ odim=vocab_size2, encoder_output_size=encoder_output_size2, **ctc_conf
+ )
+
+ self.blank_id = blank_id
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+
+ # self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ self.encoder = encoder
+ self.ctc = ctc
+ self.ctc2 = ctc2
+
+ self.error_calculator = None
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ text2: torch.Tensor,
+ text2_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ text2: (Batch, Length)
+ text2_lengths: (Batch,)
+ """
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+ batch_size = speech.shape[0]
+
+ # Encoder
+ encoder_out, encoder_out2, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+ loss_ctc2, cer_ctc2 = self._calc_ctc_loss(
+ encoder_out2, encoder_out_lens, text2, text2_lengths
+ )
+
+ # Collect CTC branch stats
+ stats = dict()
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+ stats["loss_ctc2"] = loss_ctc2.detach() if loss_ctc2 is not None else None
+ stats["cer_ctc2"] = cer_ctc2
+
+ loss = 0.5 * loss_ctc + 0.5 * loss_ctc2
+
+ stats["cer"] = cer_ctc
+ stats["cer2"] = cer_ctc2
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
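+
+ # Editorial note (assumption): each training sample is expected to provide two label
+ # streams, mapped to (text, text_lengths) and (text2, text2_lengths) by the data
+ # pipeline; the two CTC losses above are combined with fixed, equal 0.5 weights.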
+
+ def encode(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ encoder_out, encoder_out2 = self.encoder(speech)
+ encoder_out_lens = speech_lengths
+
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ if isinstance(encoder_out2, tuple):
+ encoder_out2 = encoder_out2[0]
+
+ return encoder_out, encoder_out2, encoder_out_lens
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
+
+ def _calc_ctc2_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc2.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
+
+
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list=None,
+ tokenizer=None,
+ tokenizer2=None,
+ frontend=None,
+ **kwargs,
+ ):
+ keywords = kwargs.get("keywords")
+ from funasr.utils.kws_utils import KwsCtcPrefixDecoder
+ self.kws_decoder = KwsCtcPrefixDecoder(
+ ctc=self.ctc,
+ keywords=keywords,
+ token_list=tokenizer.token_list,
+ seg_dict=tokenizer.seg_dict,
+ )
+ self.kws_decoder2 = KwsCtcPrefixDecoder(
+ ctc=self.ctc2,
+ keywords=keywords,
+ token_list=tokenizer2.token_list,
+ seg_dict=tokenizer2.seg_dict,
+ )
+
+ meta_data = {}
+ if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is not None:
+ speech_lengths = speech_lengths.squeeze(-1)
+ else:
+ speech_lengths = speech.shape[1]
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(
+ data_in,
+ fs=frontend.fs,
+ audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer
+ )
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(
+ audio_sample_list,
+ data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend
+ )
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ encoder_out, encoder_out2, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ if isinstance(encoder_out2, tuple):
+ encoder_out2 = encoder_out2[0]
+
+ results = []
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+
+ for i in range(encoder_out.size(0)):
+ x = encoder_out[i, :encoder_out_lens[i], :]
+ detect_result = self.kws_decoder.decode(x)
+ is_deted, det_keyword, det_score = detect_result[0], detect_result[1], detect_result[2]
+
+ if is_deted:
+ det_info = "detected " + det_keyword + " " + str(det_score)
+ else:
+ det_info = "rejected"
+ # only persist results when an output_dir (and hence self.writer) exists
+ if hasattr(self, "writer"):
+ self.writer["detect"][key[i]] = det_info
+
+ x2 = encoder_out2[i, :encoder_out_lens[i], :]
+ detect_result2 = self.kws_decoder2.decode(x2)
+ is_deted2, det_keyword2, det_score2 = detect_result2[0], detect_result2[1], detect_result2[2]
+
+ if is_deted2:
+ det_info2 = "detected " + det_keyword2 + " " + str(det_score2)
+ else:
+ det_info2 = "rejected"
+ if hasattr(self, "writer"):
+ self.writer["detect2"][key[i]] = det_info2
+
+ result_i = {"key": key[i], "text": det_info, "text2": det_info2}
+ results.append(result_i)
+
+ return results, meta_data
+
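+ # Editorial usage sketch (assumption; model path and keyword string are illustrative):
+ # inference is typically driven through AutoModel, forwarding the keyword list that
+ # KwsCtcPrefixDecoder consumes, e.g.
+ #   from funasr import AutoModel
+ #   model = AutoModel(model="./exp/fsmn_kws_mt")          # hypothetical local model dir
+ #   res = model.generate(input="test.wav", keywords="xiao yun xiao yun")
+ # Each result dict carries "text" / "text2" set to "detected <kw> <score>" or "rejected".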
+
+@tables.register("model_classes", "FsmnKWSMTConvert")
+class FsmnKWSMTConvert(torch.nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+ https://arxiv.org/abs/1803.05030
+ """
+
+ def __init__(
+ self,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ ctc: str = None,
+ ctc_conf: Optional[Dict] = None,
+ ctc_weight: float = 1.0,
+ input_size: int = 360,
+ vocab_size: int = -1,
+ vocab_size2: int = -1,
+ blank_id: int = 0,
+ **kwargs,
+ ):
+ super().__init__()
+
+ encoder_class = tables.encoder_classes.get(encoder)
+ encoder = encoder_class(**encoder_conf)
+ encoder_output_size = encoder.output_size()
+
+ if ctc_conf is None:
+ ctc_conf = {}
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+ )
+
+ self.blank_id = blank_id
+ self.vocab_size = vocab_size
+ self.ctc_weight = ctc_weight
+ self.encoder = encoder
+ self.ctc = ctc
+
+ self.error_calculator = None
+
+ def to_kaldi_net(self):
+ return self.encoder.to_kaldi_net()
+
+ def to_kaldi_net2(self):
+ return self.encoder.to_kaldi_net2()
+
+ def to_pytorch_net(self, kaldi_file):
+ return self.encoder.to_pytorch_net(kaldi_file)
diff --git a/funasr/models/fsmn_vad_streaming/encoder.py b/funasr/models/fsmn_vad_streaming/encoder.py
index 6668c5d..14c2f5f 100755
--- a/funasr/models/fsmn_vad_streaming/encoder.py
+++ b/funasr/models/fsmn_vad_streaming/encoder.py
@@ -85,13 +85,17 @@
else:
self.conv_right = None
- def forward(self, input: torch.Tensor, cache: torch.Tensor):
+ def forward(self, input: torch.Tensor, cache: torch.Tensor = None):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1) # B D T C
- cache = cache.to(x_per.device)
- y_left = torch.cat((cache, x_per), dim=2)
- cache = y_left[:, :, -(self.lorder - 1) * self.lstride :, :]
+ if cache is not None:
+ cache = cache.to(x_per.device)
+ y_left = torch.cat((cache, x_per), dim=2)
+ cache = y_left[:, :, -(self.lorder - 1) * self.lstride :, :]
+ else:
+ y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
+
y_left = self.conv_left(y_left)
out = x_per + y_left
@@ -130,14 +134,18 @@
self.affine = AffineTransform(proj_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
- def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
+ def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor] = None):
x1 = self.linear(input) # B T D
- cache_layer_name = "cache_layer_{}".format(self.stack_layer)
- if cache_layer_name not in cache:
- cache[cache_layer_name] = torch.zeros(
- x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1
- )
- x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
+
+ if cache is not None:
+ cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+ if cache_layer_name not in cache:
+ cache[cache_layer_name] = torch.zeros(
+ x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1
+ )
+ x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
+ else:
+ x2, _ = self.fsmn_block(x1, None)
x3 = self.affine(x2)
x4 = self.relu(x3)
return x4
@@ -203,6 +211,7 @@
rstride: int,
output_affine_dim: int,
output_dim: int,
+ use_softmax: bool = True,
):
super().__init__()
@@ -225,13 +234,21 @@
)
self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
- self.softmax = nn.Softmax(dim=-1)
+
+ self.use_softmax = use_softmax
+ if self.use_softmax:
+ self.softmax = nn.Softmax(dim=-1)
def fuse_modules(self):
pass
+ def output_size(self) -> int:
+ return self.output_dim
+
def forward(
- self, input: torch.Tensor, cache: Dict[str, torch.Tensor]
+ self,
+ input: torch.Tensor,
+ cache: Dict[str, torch.Tensor] = None
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
@@ -246,9 +263,12 @@
x4 = self.fsmn(x3, cache) # self.cache will update automatically in self.fsmn
x5 = self.out_linear1(x4)
x6 = self.out_linear2(x5)
- x7 = self.softmax(x6)
- return x7
+ if self.use_softmax:
+ x7 = self.softmax(x6)
+ return x7
+
+ return x6
@tables.register("encoder_classes", "FSMNExport")
@@ -276,6 +296,7 @@
# self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
# self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
# self.softmax = nn.Softmax(dim=-1)
+
self.in_linear1 = model.in_linear1
self.in_linear2 = model.in_linear2
self.relu = model.relu
@@ -317,88 +338,3 @@
x = self.softmax(x)
return x, out_caches
-
-
-"""
-one deep fsmn layer
-dimproj: projection dimension, input and output dimension of memory blocks
-dimlinear: dimension of mapping layer
-lorder: left order
-rorder: right order
-lstride: left stride
-rstride: right stride
-"""
-
-
-@tables.register("encoder_classes", "DFSMN")
-class DFSMN(nn.Module):
-
- def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
- super(DFSMN, self).__init__()
-
- self.lorder = lorder
- self.rorder = rorder
- self.lstride = lstride
- self.rstride = rstride
-
- self.expand = AffineTransform(dimproj, dimlinear)
- self.shrink = LinearTransform(dimlinear, dimproj)
-
- self.conv_left = nn.Conv2d(
- dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False
- )
-
- if rorder > 0:
- self.conv_right = nn.Conv2d(
- dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False
- )
- else:
- self.conv_right = None
-
- def forward(self, input):
- f1 = F.relu(self.expand(input))
- p1 = self.shrink(f1)
-
- x = torch.unsqueeze(p1, 1)
- x_per = x.permute(0, 3, 2, 1)
-
- y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
-
- if self.conv_right is not None:
- y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
- y_right = y_right[:, :, self.rstride :, :]
- out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
- else:
- out = x_per + self.conv_left(y_left)
-
- out1 = out.permute(0, 3, 2, 1)
- output = input + out1.squeeze(1)
-
- return output
-
-
-"""
-build stacked dfsmn layers
-"""
-
-
-def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
- repeats = [
- nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) for i in range(fsmn_layers)
- ]
-
- return nn.Sequential(*repeats)
-
-
-if __name__ == "__main__":
- fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
- print(fsmn)
-
- num_params = sum(p.numel() for p in fsmn.parameters())
- print("the number of model params: {}".format(num_params))
- x = torch.zeros(128, 200, 400) # batch-size * time * dim
- y, _ = fsmn(x) # batch-size * time * dim
- print("input shape: {}".format(x.shape))
- print("output shape: {}".format(y.shape))
-
- print(fsmn.to_kaldi_net())
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index dc30a94..0d39ca7 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -523,6 +523,7 @@
feats_dim=560,
model_name="encoder",
onnx: bool = True,
+ ctc_linear: nn.Module = None,
):
super().__init__()
self.embed = model.embed
@@ -553,6 +554,8 @@
self.num_heads = model.encoders[0].self_attn.h
self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
+ self.ctc_linear = ctc_linear
+
def prepare_mask(self, mask):
mask_3d_btd = mask[:, :, None]
if len(mask.shape) == 2:
@@ -566,6 +569,7 @@
def forward(self, speech: torch.Tensor, speech_lengths: torch.Tensor, online: bool = False):
if not online:
speech = speech * self._output_size**0.5
+
mask = self.make_pad_mask(speech_lengths)
mask = self.prepare_mask(mask)
if self.embed is None:
@@ -581,6 +585,10 @@
xs_pad = self.model.after_norm(xs_pad)
+ if self.ctc_linear is not None:
+ xs_pad = self.ctc_linear(xs_pad)
+ xs_pad = F.softmax(xs_pad, dim=2)
+
return xs_pad, speech_lengths
def get_output_size(self):
diff --git a/funasr/models/sanm_kws/__init__.py b/funasr/models/sanm_kws/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/sanm_kws/__init__.py
diff --git a/funasr/models/sanm_kws/export_meta.py b/funasr/models/sanm_kws/export_meta.py
new file mode 100644
index 0000000..a91a22c
--- /dev/null
+++ b/funasr/models/sanm_kws/export_meta.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import types
+import copy
+import torch
+from funasr.register import tables
+
+
+def export_rebuild_model(model, **kwargs):
+ # self.device = kwargs.get("device")
+ is_onnx = kwargs.get("type", "onnx") == "onnx"
+ encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+
+ if hasattr(model, "ctc"):
+ model.encoder = encoder_class(
+ model.encoder,
+ onnx=is_onnx,
+ feats_dim=kwargs.get("input_size", 560),
+ ctc_linear=model.ctc.ctc_lo
+ )
+ else:
+ # export requires the CTC projection; fail loudly if it is missing
+ raise ValueError(f"model has no 'ctc' module, cannot rebuild encoder for export: {type(model).__name__}")
+
+ # from funasr.utils.torch_function import sequence_mask
+ # model.make_pad_mask = sequence_mask(max_seq_len=None, flip=False)
+
+ encoder_model = copy.copy(model)
+
+ # encoder
+ encoder_model.forward = types.MethodType(export_encoder_forward, encoder_model)
+ encoder_model.export_dummy_inputs = types.MethodType(export_encoder_dummy_inputs, encoder_model)
+ encoder_model.export_input_names = types.MethodType(export_encoder_input_names, encoder_model)
+ encoder_model.export_output_names = types.MethodType(export_encoder_output_names, encoder_model)
+ encoder_model.export_dynamic_axes = types.MethodType(export_encoder_dynamic_axes, encoder_model)
+ encoder_model.export_name = types.MethodType(export_encoder_name, encoder_model)
+
+ return encoder_model
+
+
+def export_encoder_forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+):
+ # a. To device
+ batch = {
+ "speech": speech,
+ "speech_lengths": speech_lengths,
+ "online": True
+ }
+ # batch = to_device(batch, device=self.device)
+
+ encoder_out, encoder_out_len = self.encoder(**batch)
+ # mask = self.make_pad_mask(encoder_out_len)[:, None, :]
+ # alphas, _ = self.predictor.forward_cnn(enc, mask)
+
+ # return encoder_out, encoder_out_len, alphas
+ return encoder_out, encoder_out_len
+
+
+def export_encoder_dummy_inputs(self):
+ speech = torch.randn(2, 30, 280)
+ speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+ return (speech, speech_lengths)
+
+
+def export_encoder_input_names(self):
+ return ["speech", "speech_lengths"]
+
+
+def export_encoder_output_names(self):
+ # return ["encoder_out", "encoder_out_len", "alphas"]
+ return ["encoder_out", "encoder_out_len"]
+
+
+def export_encoder_dynamic_axes(self):
+ return {
+ "speech": {
+ 0: "batch_size", 1: "feats_length"
+ },
+ "speech_lengths": {
+ 0: "batch_size",
+ },
+ "encoder_out": {
+ 0: "batch_size", 1: "feats_length"
+ },
+ "encoder_out_len": {
+ 0: "batch_size",
+ },
+ }
+
+
+def export_encoder_name(self):
+ return "encoder.onnx"
diff --git a/funasr/models/sanm_kws/model.py b/funasr/models/sanm_kws/model.py
new file mode 100644
index 0000000..d8d8f95
--- /dev/null
+++ b/funasr/models/sanm_kws/model.py
@@ -0,0 +1,266 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import time
+import torch
+import logging
+from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
+
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import to_device
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+
+@tables.register("model_classes", "SanmKWS")
+class SanmKWS(torch.nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ specaug: Optional[str] = None,
+ specaug_conf: Optional[Dict] = None,
+ normalize: str = None,
+ normalize_conf: Optional[Dict] = None,
+ encoder: str = None,
+ encoder_conf: Optional[Dict] = None,
+ ctc: str = None,
+ ctc_conf: Optional[Dict] = None,
+ ctc_weight: float = 1.0,
+ input_size: int = 360,
+ vocab_size: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ **kwargs,
+ ):
+
+ super().__init__()
+
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get(specaug)
+ specaug = specaug_class(**specaug_conf)
+
+ if normalize is not None:
+ normalize_class = tables.normalize_classes.get(normalize)
+ normalize = normalize_class(**normalize_conf)
+
+ encoder_class = tables.encoder_classes.get(encoder)
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
+
+ if ctc_conf is None:
+ ctc_conf = {}
+ ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
+
+ # note that eos is the same as sos (equivalent ID)
+ self.blank_id = blank_id
+ self.sos = sos if sos is not None else vocab_size - 1
+ self.eos = eos if eos is not None else vocab_size - 1
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.ctc_weight = ctc_weight
+ # self.token_list = token_list.copy()
+ #
+ # self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ self.encoder = encoder
+
+ self.ctc = ctc
+ self.error_calculator = None
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+ batch_size = speech.shape[0]
+
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ # decoder: CTC branch
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats = dict()
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+ stats["cer"] = cer_ctc
+
+ loss = loss_ctc
+
+ stats["loss"] = torch.clone(loss.detach())
+ stats["batch_size"] = batch_size
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def encode(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ return encoder_out, encoder_out_lens
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
+
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+ keywords = kwargs.get("keywords")
+ from funasr.utils.kws_utils import KwsCtcPrefixDecoder
+ self.kws_decoder = KwsCtcPrefixDecoder(
+ ctc=self.ctc,
+ keywords=keywords,
+ token_list=tokenizer.token_list,
+ seg_dict=tokenizer.seg_dict,
+ )
+
+ meta_data = {}
+ if (
+ isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+ ): # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is not None:
+ speech_lengths = speech_lengths.squeeze(-1)
+ else:
+ speech_lengths = speech.shape[1]
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio_text_image_video(
+ data_in,
+ fs=frontend.fs,
+ audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer,
+ )
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(
+ audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+ )
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = (
+ speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+ )
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+ # Encoder
+ if kwargs.get("fp16", False):
+ speech = speech.half()
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ results = []
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+
+ for i in range(encoder_out.size(0)):
+ x = encoder_out[i, : encoder_out_lens[i], :]
+ detect_result = self.kws_decoder.decode(x)
+ is_deted, det_keyword, det_score = detect_result[0], detect_result[1], detect_result[2]
+
+ if is_deted:
+ det_info = "detected " + det_keyword + " " + str(det_score)
+ else:
+ det_info = "rejected"
+ # only persist results when an output_dir (and hence self.writer) exists
+ if hasattr(self, "writer"):
+ self.writer["detect"][key[i]] = det_info
+
+ result_i = {"key": key[i], "text": det_info}
+ results.append(result_i)
+
+ return results, meta_data
+
+ def export(self, **kwargs):
+ from .export_meta import export_rebuild_model
+
+ if "max_seq_len" not in kwargs:
+ kwargs["max_seq_len"] = 512
+ models = export_rebuild_model(model=self, **kwargs)
+ return models
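+
+ # Editorial usage sketch (assumption; the local model dir is hypothetical): export is
+ # normally triggered through AutoModel, which ends up in SanmKWS.export() above, e.g.
+ #   from funasr import AutoModel
+ #   model = AutoModel(model="./exp/sanm_kws")
+ #   model.export(type="onnx", quantize=False)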
diff --git a/funasr/models/sanm_kws_streaming/__init__.py b/funasr/models/sanm_kws_streaming/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/sanm_kws_streaming/__init__.py
diff --git a/funasr/models/sanm_kws_streaming/export_meta.py b/funasr/models/sanm_kws_streaming/export_meta.py
new file mode 100644
index 0000000..a91a22c
--- /dev/null
+++ b/funasr/models/sanm_kws_streaming/export_meta.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import types
+import copy
+import torch
+from funasr.register import tables
+
+
+def export_rebuild_model(model, **kwargs):
+ # self.device = kwargs.get("device")
+ is_onnx = kwargs.get("type", "onnx") == "onnx"
+ encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+
+ if hasattr(model, "ctc"):
+ model.encoder = encoder_class(
+ model.encoder,
+ onnx=is_onnx,
+ feats_dim=kwargs.get("input_size", 560),
+ ctc_linear=model.ctc.ctc_lo
+ )
+ else:
+ # export requires the CTC projection; fail loudly if it is missing
+ raise ValueError(f"model has no 'ctc' module, cannot rebuild encoder for export: {type(model).__name__}")
+
+ # from funasr.utils.torch_function import sequence_mask
+ # model.make_pad_mask = sequence_mask(max_seq_len=None, flip=False)
+
+ encoder_model = copy.copy(model)
+
+ # encoder
+ encoder_model.forward = types.MethodType(export_encoder_forward, encoder_model)
+ encoder_model.export_dummy_inputs = types.MethodType(export_encoder_dummy_inputs, encoder_model)
+ encoder_model.export_input_names = types.MethodType(export_encoder_input_names, encoder_model)
+ encoder_model.export_output_names = types.MethodType(export_encoder_output_names, encoder_model)
+ encoder_model.export_dynamic_axes = types.MethodType(export_encoder_dynamic_axes, encoder_model)
+ encoder_model.export_name = types.MethodType(export_encoder_name, encoder_model)
+
+ return encoder_model
+
+
+def export_encoder_forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+):
+ # a. To device
+ batch = {
+ "speech": speech,
+ "speech_lengths": speech_lengths,
+ "online": True
+ }
+ # batch = to_device(batch, device=self.device)
+
+ encoder_out, encoder_out_len = self.encoder(**batch)
+ # mask = self.make_pad_mask(encoder_out_len)[:, None, :]
+ # alphas, _ = self.predictor.forward_cnn(enc, mask)
+
+ # return encoder_out, encoder_out_len, alphas
+ return encoder_out, encoder_out_len
+
+
+def export_encoder_dummy_inputs(self):
+ speech = torch.randn(2, 30, 280)
+ speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+ return (speech, speech_lengths)
+
+
+def export_encoder_input_names(self):
+ return ["speech", "speech_lengths"]
+
+
+def export_encoder_output_names(self):
+ # return ["encoder_out", "encoder_out_len", "alphas"]
+ return ["encoder_out", "encoder_out_len"]
+
+
+def export_encoder_dynamic_axes(self):
+ return {
+ "speech": {
+ 0: "batch_size", 1: "feats_length"
+ },
+ "speech_lengths": {
+ 0: "batch_size",
+ },
+ "encoder_out": {
+ 0: "batch_size", 1: "feats_length"
+ },
+ "encoder_out_len": {
+ 0: "batch_size",
+ },
+ }
+
+
+def export_encoder_name(self):
+ return "encoder.onnx"
diff --git a/funasr/models/sanm_kws_streaming/model.py b/funasr/models/sanm_kws_streaming/model.py
new file mode 100644
index 0000000..c459e7c
--- /dev/null
+++ b/funasr/models/sanm_kws_streaming/model.py
@@ -0,0 +1,442 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import time
+import torch
+import logging
+from typing import Dict, Tuple
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.sanm_kws.model import SanmKWS
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+@tables.register("model_classes", "SanmKWSStreaming")
+class SanmKWSStreaming(SanmKWS):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+ decoding_ind = kwargs.get("decoding_ind")
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ # Encoder
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+ else:
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+ # decoder: CTC branch
+ if hasattr(self.encoder, "overlap_chunk_cls"):
+ encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
+ encoder_out, encoder_out_lens, chunk_outs=None
+ )
+ else:
+ encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
+
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats = dict()
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+ loss = loss_ctc
+
+ stats["cer"] = cer_ctc
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def encode_chunk(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ cache: dict = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ ind: int
+ """
+ with autocast(False):
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+ # Forward encoder
+ encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
+ speech, speech_lengths, cache=cache["encoder"]
+ )
+
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ return encoder_out, torch.tensor([encoder_out.size(1)])
+
+ def init_cache(self, cache: dict = {}, **kwargs):
+ chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+ encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
+ decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
+ batch_size = 1
+
+ enc_output_size = kwargs["encoder_conf"]["output_size"]
+ feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
+ cache_encoder = {
+ "start_idx": 0,
+ "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+ "cif_alphas": torch.zeros((batch_size, 1)),
+ "encoder_out": None,
+ "encoder_out_lens": None,
+ "chunk_size": chunk_size,
+ "encoder_chunk_look_back": encoder_chunk_look_back,
+ "last_chunk": False,
+ "opt": None,
+ "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+ "tail_chunk": False,
+ }
+ cache["encoder"] = cache_encoder
+
+ cache_decoder = {
+ "decode_fsmn": None,
+ "decoder_chunk_look_back": decoder_chunk_look_back,
+ "opt": None,
+ "chunk_size": chunk_size,
+ }
+ cache["decoder"] = cache_decoder
+ cache["frontend"] = {}
+ cache["prev_samples"] = torch.empty(0)
+
+ return cache
+
+ def generate_chunk(
+ self,
+ speech,
+ speech_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+ cache = kwargs.get("cache", {})
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+ # Encoder
+ is_final = kwargs.get("is_final", False)
+ encoder_out, encoder_out_lens = self.encode_chunk(
+ speech, speech_lengths, cache=cache, is_final=is_final
+ )
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
+ chunk_size = cache["encoder"]["chunk_size"]
+ real_start_pos = chunk_size[0]
+
+ if encoder_out_lens[0] > chunk_size[0] + chunk_size[1] + chunk_size[2]:
+ assert False, print("impossible case 1 !")
+ if encoder_out_lens[0] == chunk_size[0] + chunk_size[1] + chunk_size[2]:
+ real_end_pos = chunk_size[0] + chunk_size[1]
+ elif encoder_out_lens[0] > chunk_size[0] + chunk_size[1]:
+ real_end_pos = chunk_size[0] + chunk_size[1]
+ elif encoder_out_lens[0] > chunk_size[0]:
+ real_end_pos = encoder_out_lens[0]
+ else:
+ assert False, print("impossible case 2 !")
+
+ encoder_out_accum = cache["encoder"]["encoder_out"]
+ if encoder_out_accum is not None:
+ encoder_out_accum = torch.cat((encoder_out_accum, encoder_out[:, real_start_pos:real_end_pos, :]), dim=1)
+ else:
+ encoder_out_accum = encoder_out[:, real_start_pos:real_end_pos, :]
+ cache["encoder"]["encoder_out"] = encoder_out_accum
+
+ if cache["encoder"]["encoder_out_lens"] is not None:
+ cache["encoder"]["encoder_out_lens"][0] += real_end_pos - real_start_pos
+ else:
+ cache["encoder"]["encoder_out_lens"] = encoder_out_lens
+ cache["encoder"]["encoder_out_lens"][0] = real_end_pos - real_start_pos
+
+ if is_final:
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+
+ results = []
+ for i in range(encoder_out_accum.size(0)):
+ x = encoder_out_accum[i, : cache["encoder"]["encoder_out_lens"][i], :]
+ detect_result = self.kws_decoder.decode(x)
+ is_deted, det_keyword, det_score = detect_result[0], detect_result[1], detect_result[2]
+
+ if is_deted:
+ det_info = "detected " + det_keyword + " " + str(det_score)
+ else:
+ det_info = "rejected"
+ # only persist results when an output_dir (and hence self.writer) exists
+ if hasattr(self, "writer"):
+ self.writer["detect"][key[i]] = det_info
+
+ result_i = {"key": key[i], "text": det_info}
+ results.append(result_i)
+
+ return results
+ else:
+ return None
+
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ cache: dict = {},
+ **kwargs,
+ ):
+ keywords = kwargs.get("keywords")
+ from funasr.utils.kws_utils import KwsCtcPrefixDecoder
+ self.kws_decoder = KwsCtcPrefixDecoder(
+ ctc=self.ctc,
+ keywords=keywords,
+ token_list=tokenizer.token_list,
+ seg_dict=tokenizer.seg_dict,
+ )
+
+ meta_data = {}
+ chunk_size = kwargs["chunk_size"]
+ chunk_stride_samples = int(chunk_size[1] * 960) # 960 samples = 60 ms at 16 kHz, e.g. 600 ms for chunk_size[1] = 10
+ first_chunk_padding_samples = int(chunk_size[2] * 960) # e.g. 300 ms for chunk_size[2] = 5
+
+ if len(cache) == 0:
+ self.init_cache(cache, **kwargs)
+
+ time1 = time.perf_counter()
+ cfg = {"is_final": kwargs.get("is_final", False)}
+ audio_sample_list = load_audio_text_image_video(
+ data_in,
+ fs=frontend.fs,
+ audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer,
+ cache=cfg,
+ )
+ _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
+
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ assert len(audio_sample_list) == 1, "batch_size must be set 1"
+
+ audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+
+ if len(audio_sample) < first_chunk_padding_samples:
+ print("key: {}, audio is too short for inference {}".format(key, len(audio_sample)))
+
+ audio_sample_pre = audio_sample[0 : first_chunk_padding_samples]
+ feat_pre, feat_pre_lengths = extract_fbank(
+ [audio_sample_pre],
+ data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend,
+ cache=cache["frontend"],
+ is_final=False,
+ )
+
+ audio_sample = audio_sample[first_chunk_padding_samples:]
+ audio_chunks = int(len(audio_sample) // chunk_stride_samples)
+
+ for i in range(audio_chunks):
+ if i == 0:
+ cache["encoder"]["feats"][:, chunk_size[2] :, :] = feat_pre
+
+ kwargs["is_final"] = False
+ audio_sample_i = audio_sample[i * chunk_stride_samples : (i + 1) * chunk_stride_samples]
+
+ if kwargs["is_final"] and len(audio_sample_i) < 960:
+ cache["encoder"]["tail_chunk"] = True
+ speech = cache["encoder"]["feats"]
+ speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(
+ speech.device
+ )
+ else:
+ # extract fbank feats
+ speech, speech_lengths = extract_fbank(
+ [audio_sample_i],
+ data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend,
+ cache=cache["frontend"],
+ is_final=kwargs["is_final"],
+ )
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = (
+ speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+ )
+
+ results_chunk_i = self.generate_chunk(
+ speech,
+ speech_lengths,
+ key=key,
+ tokenizer=tokenizer,
+ cache=cache,
+ frontend=frontend,
+ **kwargs,
+ )
+
+ # results_chunk_i must be None when is_final=False
+ assert results_chunk_i is None
+
+ # process tail samples
+ tail_audio_sample = audio_sample[audio_chunks * chunk_stride_samples:]
+ if len(tail_audio_sample) < 960:
+ kwargs["is_final"] = _is_final
+ cache["encoder"]["tail_chunk"] = True
+ speech = cache["encoder"]["feats"]
+ speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(
+ speech.device
+ )
+ results_chunk_tail = self.generate_chunk(
+ speech,
+ speech_lengths,
+ key=key,
+ tokenizer=tokenizer,
+ cache=cache,
+ frontend=frontend,
+ **kwargs,
+ )
+ elif len(tail_audio_sample) <= first_chunk_padding_samples:
+ kwargs["is_final"] = _is_final
+ # extract fbank feats
+ # cache["encoder"]["tail_chunk"] = True # cannot be true
+ speech, speech_lengths = extract_fbank(
+ [ tail_audio_sample ],
+ data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend,
+ cache=cache["frontend"],
+ is_final=kwargs["is_final"],
+ )
+ results_chunk_tail = self.generate_chunk(
+ speech,
+ speech_lengths,
+ key=key,
+ tokenizer=tokenizer,
+ cache=cache,
+ frontend=frontend,
+ **kwargs,
+ )
+ elif first_chunk_padding_samples < len(tail_audio_sample) < chunk_stride_samples:
+ kwargs["is_final"] = False
+ # extract fbank feats
+ speech, speech_lengths = extract_fbank(
+ [ tail_audio_sample ],
+ data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend,
+ cache=cache["frontend"],
+ is_final=kwargs["is_final"],
+ )
+ results_chunk = self.generate_chunk(
+ speech,
+ speech_lengths,
+ key=key,
+ tokenizer=tokenizer,
+ cache=cache,
+ frontend=frontend,
+ **kwargs,
+ )
+ # results_chunk must be None when is_final=False
+ assert results_chunk is None
+
+ # push tail chunk
+ kwargs["is_final"] = _is_final
+ cache["encoder"]["tail_chunk"] = True
+ speech = cache["encoder"]["feats"]
+ speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(
+ speech.device
+ )
+ results_chunk_tail = self.generate_chunk(
+ speech,
+ speech_lengths,
+ key=key,
+ tokenizer=tokenizer,
+ cache=cache,
+ frontend=frontend,
+ **kwargs,
+ )
+
+ result = results_chunk_tail
+
+ if _is_final:
+ self.init_cache(cache, **kwargs)
+
+ if kwargs.get("output_dir"):
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+
+ return result, meta_data
+
+ def export(self, **kwargs):
+ from .export_meta import export_rebuild_model
+
+ models = export_rebuild_model(model=self, **kwargs)
+ return models
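+
+ # Editorial streaming usage sketch (assumption; parameter values are illustrative):
+ # audio is pushed chunk by chunk with a persistent cache, as in other FunASR
+ # streaming models; a detect/reject result is only produced on the final chunk.
+ #   chunk_size = [0, 10, 5]   # 1 unit = 960 samples = 60 ms: 600 ms stride, 300 ms first-chunk padding
+ #   cache = {}
+ #   for i, wav_chunk in enumerate(chunks):
+ #       res = model.generate(input=wav_chunk, cache=cache, chunk_size=chunk_size,
+ #                            is_final=(i == len(chunks) - 1), keywords="xiao yun xiao yun")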
diff --git a/funasr/models/transformer/scorers/ctc.py b/funasr/models/transformer/scorers/ctc.py
index 73a14bd..eb19d18 100644
--- a/funasr/models/transformer/scorers/ctc.py
+++ b/funasr/models/transformer/scorers/ctc.py
@@ -7,7 +7,6 @@
from funasr.models.transformer.scorers.ctc_prefix_score import CTCPrefixScoreTH
from funasr.models.transformer.scorers.scorer_interface import BatchPartialScorerInterface
-
class CTCPrefixScorer(BatchPartialScorerInterface):
"""Decoder interface wrapper for CTCPrefixScore."""
diff --git a/funasr/tokenizer/char_tokenizer.py b/funasr/tokenizer/char_tokenizer.py
index 805ecd0..7b517da 100644
--- a/funasr/tokenizer/char_tokenizer.py
+++ b/funasr/tokenizer/char_tokenizer.py
@@ -50,9 +50,7 @@
)
def text2tokens(self, line: Union[str, list]) -> List[str]:
-
# if self.split_with_space:
-
if self.seg_dict is not None:
tokens = line.strip().split(" ")
tokens = seg_tokenize(tokens, self.seg_dict)
diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index 67f1e55..873f419 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -30,8 +30,8 @@
map_location="cpu",
)
avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
- val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
- sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
+ val_step_or_epoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_epoch"]
+ sorted_items = sorted(val_step_or_epoch.items(), key=lambda x: x[1], reverse=True)
sorted_items = (
sorted_items[:last_n] if avg_keep_nbest_models_type == "acc" else sorted_items[-last_n:]
)
@@ -53,6 +53,7 @@
checkpoint_files.sort(key=lambda x: int(re.search(r"(\d+)", x).group()), reverse=True)
# Get the last 'last_n' checkpoint paths
checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
+
return checkpoint_paths
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 8ed613c..da2eed5 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -8,6 +8,7 @@
import torch.nn
import torch.optim
import pdb
+import copy
def load_pretrained_model(
@@ -35,11 +36,12 @@
logging.info(f"ckpt: {path}")
if oss_bucket is None:
- src_state = torch.load(path, map_location=map_location)
+ ori_state = torch.load(path, map_location=map_location)
else:
buffer = BytesIO(oss_bucket.get_object(path).read())
- src_state = torch.load(buffer, map_location=map_location)
+ ori_state = torch.load(buffer, map_location=map_location)
+ src_state = copy.deepcopy(ori_state)
src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
src_state = src_state["model"] if "model" in src_state else src_state
@@ -94,7 +96,6 @@
)
else:
dst_state[k] = src_state[k_src]
-
else:
print(f"Warning, miss key in ckpt: {k}, {path}")
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 665a7af..5fe34b9 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -115,8 +115,8 @@
self.saved_ckpts = {}
self.step_or_epoch = -1
self.best_step_or_epoch = ""
- self.val_acc_step_or_eoch = {}
- self.val_loss_step_or_eoch = {}
+ self.val_acc_step_or_epoch = {}
+ self.val_loss_step_or_epoch = {}
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0
@@ -161,12 +161,14 @@
# self.step_or_epoch += 1
state = {
"epoch": epoch,
+ 'step': step,
+ 'total_step': self.batch_total,
"state_dict": model.state_dict(),
"optimizer": optim.state_dict(),
"scheduler": scheduler.state_dict(),
"saved_ckpts": self.saved_ckpts,
- "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
- "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+ "val_acc_step_or_epoch": self.val_acc_step_or_epoch,
+ "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
@@ -183,6 +185,7 @@
if scaler:
state["scaler_state"] = scaler.state_dict()
+
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
if step is None:
@@ -191,47 +194,48 @@
ckpt_name = f"model.pt.ep{epoch}.{step}"
filename = os.path.join(self.output_dir, ckpt_name)
torch.save(state, filename)
+ logging.info(f'Checkpoint saved to {filename}')
- logging.info(f"\nCheckpoint saved to {filename}\n")
- latest = Path(os.path.join(self.output_dir, f"model.pt"))
+ latest = Path(os.path.join(self.output_dir, f'model.pt'))
torch.save(state, latest)
+
if self.best_step_or_epoch == "":
self.best_step_or_epoch = ckpt_name
if self.avg_keep_nbest_models_type == "acc":
if (
- self.val_acc_step_or_eoch[ckpt_name]
- >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
+ self.val_acc_step_or_epoch[ckpt_name]
+ >= self.val_acc_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
- f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
- f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
elif self.avg_keep_nbest_models_type == "loss":
if (
- self.val_loss_step_or_eoch[ckpt_name]
- <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
+ self.val_loss_step_or_epoch[ckpt_name]
+ <= self.val_loss_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
- f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
- f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
else:
print("Undo")
self.saved_ckpts[ckpt_name] = getattr(
- self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
+ self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
)[ckpt_name]
if self.keep_nbest_models > 0:
if len(self.saved_ckpts) > self.keep_nbest_models:
@@ -278,6 +282,7 @@
k_ddp = k.replace("module.", "", 1)
else:
k_ddp = k
+
if k_ddp in src_state.keys():
dst_state[k] = src_state[k_ddp]
else:
@@ -290,14 +295,14 @@
scaler.load_state_dict(checkpoint["scaler_state"])
self.saved_ckpts = checkpoint["saved_ckpts"]
- self.val_acc_step_or_eoch = (
- checkpoint["val_acc_step_or_eoch"]
- if "val_acc_step_or_eoch" in checkpoint
+ self.val_acc_step_or_epoch = (
+ checkpoint["val_acc_step_or_epoch"]
+ if "val_acc_step_or_epoch" in checkpoint
else {}
)
- self.val_loss_step_or_eoch = (
- checkpoint["val_loss_step_or_eoch"]
- if "val_loss_step_or_eoch" in checkpoint
+ self.val_loss_step_or_epoch = (
+ checkpoint["val_loss_step_or_epoch"]
+ if "val_loss_step_or_epoch" in checkpoint
else {}
)
self.best_step_or_epoch = (
@@ -327,6 +332,7 @@
if self.use_ddp or self.use_fsdp:
dist.barrier()
+
def train_epoch(
self,
@@ -559,12 +565,14 @@
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1 - time5:0.3f}"
batch = to_device(batch, self.device)
+
time2 = time.perf_counter()
retval = model(**batch)
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
+
if self.use_ddp or self.use_fsdp:
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
@@ -577,28 +585,33 @@
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
+
# Scale the loss since we're not updating for every mini-batch
loss = loss
time4 = time.perf_counter()
- self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
- batch_idx + 1
- )
- if "acc" in stats:
- self.val_acc_avg = (
- self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
- ) / (batch_idx + 1)
- if self.use_ddp or self.use_fsdp:
- val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
- self.device
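+ # skip batches with a non-finite (NaN/Inf) loss so they do not corrupt the running validation averages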
+ if torch.isfinite(loss):
+ self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
+ batch_idx + 1
)
- val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
- self.device
- )
- dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
- dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
- self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
- self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+
+ if "acc" in stats:
+ self.val_acc_avg = (
+ self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
+ ) / (batch_idx + 1)
+
+ if self.use_ddp or self.use_fsdp:
+ val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
+ self.device
+ )
+ val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
+ self.device
+ )
+ dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+ self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+ self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+
time5 = time.perf_counter()
batch_num_epoch = 1
if hasattr(dataloader_val, "__len__"):
@@ -624,8 +637,8 @@
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
- self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
- self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+ self.val_acc_step_or_epoch[ckpt_name] = self.val_acc_avg
+ self.val_loss_step_or_epoch[ckpt_name] = self.val_loss_avg
model.train()
if self.use_ddp or self.use_fsdp:
diff --git a/funasr/utils/compute_det_ctc.py b/funasr/utils/compute_det_ctc.py
new file mode 100644
index 0000000..47c8182
--- /dev/null
+++ b/funasr/utils/compute_det_ctc.py
@@ -0,0 +1,286 @@
+""" This implementation is adapted from https://github.com/wenet-e2e/wekws/blob/main/wekws/bin/compute_det.py."""
+
+import os
+import json
+import logging
+import argparse
+import threading
+
+import kaldiio
+import torch
+from funasr.utils.kws_utils import split_mixed_label
+
+
+class thread_wrapper(threading.Thread):
+ def __init__(self, func, args=()):
+ super(thread_wrapper, self).__init__()
+ self.func = func
+ self.args = args
+ self.result = []
+
+ def run(self):
+ self.result = self.func(*self.args)
+
+ def get_result(self):
+ try:
+ return self.result
+ except Exception:
+ return None
+
+
+def space_mixed_label(input_str):
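+ # e.g. 'hello你好' -> 'hello 你 好': alphabetic runs stay whole, other characters are split one by one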
+ splits = split_mixed_label(input_str)
+ return ' '.join(splits)
+
+
+def read_lists(list_file):
+ lists = []
+ with open(list_file, 'r', encoding='utf8') as fin:
+ for line in fin:
+ if line.strip() != '':
+ lists.append(line.strip())
+ return lists
+
+
+def make_pair(wav_lists, trans_lists):
+ logging.info('make pair for wav-trans list')
+
+ trans_table = {}
+ for line in trans_lists:
+ arr = line.strip().replace('\t', ' ').split()
+ if len(arr) < 2:
+ logging.debug('invalid line in trans file: {}'.format(
+ line.strip()))
+ continue
+
+ trans_table[arr[0]] = line.replace(arr[0], '', 1).strip()
+
+ lists = []
+ for line in wav_lists:
+ arr = line.strip().replace('\t', ' ').split()
+ if len(arr) == 2 and arr[0] in trans_table:
+ lists.append(
+ dict(key=arr[0],
+ txt=trans_table[arr[0]],
+ wav=arr[1],
+ sample_rate=16000))
+ else:
+ logging.debug("can't find corresponding trans for key: {}".format(
+ arr[0]))
+ continue
+
+ return lists
+
+
+def count_duration(tid, data_lists):
+ results = []
+
+ for obj in data_lists:
+ assert 'key' in obj
+ assert 'wav' in obj
+ assert 'txt' in obj
+ key = obj['key']
+ wav_file = obj['wav']
+ txt = obj['txt']
+
+ try:
+ rate, waveform = kaldiio.load_mat(wav_file)
+ waveform = torch.tensor(waveform, dtype=torch.float32)
+ waveform = waveform.unsqueeze(0)
+ frames = len(waveform[0])
+ duration = frames / float(rate)
+ except Exception:
+ logging.info(f'load file failed: {wav_file}')
+ duration = 0.0
+
+ obj['duration'] = duration
+ results.append(obj)
+
+ return results
+
+
+def load_data_and_score(keywords_list, data_file, trans_file, score_file):
+ # score_table: {uttid: [keywordlist]}
+ score_table = {}
+ with open(score_file, 'r', encoding='utf8') as fin:
+ # read score file and store in table
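+ # expected line format: '<utt_key> detected <keyword> <confidence>' when a keyword fired;
+ # any other tag in the second field marks the utterance as not detected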
+ for line in fin:
+ arr = line.strip().split()
+ key = arr[0]
+ is_detected = arr[1]
+ if is_detected == 'detected':
+ if key not in score_table:
+ score_table.update(
+ {key: {
+ 'kw': space_mixed_label(arr[2]),
+ 'confi': float(arr[3])
+ }})
+ else:
+ if key not in score_table:
+ score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})
+
+ wav_lists = read_lists(data_file)
+ trans_lists = read_lists(trans_file)
+ data_lists = make_pair(wav_lists, trans_lists)
+ logging.info(f'original list samples: {len(data_lists)}')
+
+ # count duration for each wave
+ num_workers = 8
+ start = 0
+ step = int(len(data_lists) / num_workers)
+ tasks = []
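+ # split the list into num_workers contiguous chunks; the last thread takes the remainder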
+ for idx in range(num_workers):
+ if idx != num_workers - 1:
+ task = thread_wrapper(count_duration,
+ (idx, data_lists[start:start + step]))
+ else:
+ task = thread_wrapper(count_duration, (idx, data_lists[start:]))
+ task.start()
+ tasks.append(task)
+ start += step
+
+ duration_lists = []
+ for task in tasks:
+ task.join()
+ duration_lists += task.get_result()
+ logging.info(f'list samples after duration counting: {len(duration_lists)}')
+
+ # build empty structure for keyword-filler infos
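+ # per keyword: 'keyword_table' maps positive utterances (transcript contains the keyword) to the
+ # detection confidence (-1.0 if missed), 'filler_table' maps negative utterances to the confidence
+ # of a false trigger (-1.0 if none); the *_duration fields accumulate audio seconds for each side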
+ keyword_filler_table = {}
+ for keyword in keywords_list:
+ keyword = space_mixed_label(keyword)
+ keyword_filler_table[keyword] = {}
+ keyword_filler_table[keyword]['keyword_table'] = {}
+ keyword_filler_table[keyword]['keyword_duration'] = 0.0
+ keyword_filler_table[keyword]['filler_table'] = {}
+ keyword_filler_table[keyword]['filler_duration'] = 0.0
+
+ for obj in duration_lists:
+ assert 'key' in obj
+ assert 'wav' in obj
+ assert 'txt' in obj
+ assert 'duration' in obj
+
+ key = obj['key']
+ wav_file = obj['wav']
+ txt = obj['txt']
+ txt = space_mixed_label(txt)
+ txt_regstr_lrblk = ' ' + txt + ' '
+ duration = obj['duration']
+ assert key in score_table
+
+ for keyword in keywords_list:
+ keyword = space_mixed_label(keyword)
+ keyword_regstr_lrblk = ' ' + keyword + ' '
+ if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
+ if keyword == score_table[key]['kw']:
+ keyword_filler_table[keyword]['keyword_table'].update(
+ {key: score_table[key]['confi']})
+ else:
+ # utterance not detected as this keyword (missed or detected as another keyword)
+ keyword_filler_table[keyword]['keyword_table'].update(
+ {key: -1.0})
+ keyword_filler_table[keyword]['keyword_duration'] += duration
+ else:
+ if keyword == score_table[key]['kw']:
+ keyword_filler_table[keyword]['filler_table'].update(
+ {key: score_table[key]['confi']})
+ else:
+ # utterance not detected as this keyword, so it is not a false alarm for this keyword
+ keyword_filler_table[keyword]['filler_table'].update(
+ {key: -1.0})
+ keyword_filler_table[keyword]['filler_duration'] += duration
+
+ return keyword_filler_table
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='compute det curve')
+ parser.add_argument('--keywords',
+ type=str,
+ required=True,
+ help='preset keyword str, input all keywords')
+ parser.add_argument('--test_data', required=True, help='test data file')
+ parser.add_argument('--trans_data',
+ required=True,
+ default='',
+ help='transcription of test data')
+ parser.add_argument('--score_file', required=True, help='score file')
+ parser.add_argument('--step',
+ type=float,
+ default=0.001,
+ help='threshold step')
+ parser.add_argument('--stats_dir',
+ required=True,
+ help='to save det stats files')
+ args = parser.parse_args()
+
+ root_logger = logging.getLogger()
+ handlers = root_logger.handlers[:]
+ for handler in handlers:
+ root_logger.removeHandler(handler)
+ handler.close()
+
+ logging.basicConfig(level=logging.DEBUG,
+ format='%(asctime)s %(levelname)s %(message)s')
+
+ keywords_list = args.keywords.strip().split(',')
+ keyword_filler_table = load_data_and_score(keywords_list, args.test_data,
+ args.trans_data,
+ args.score_file)
+
+ stats_files = {}
+ for keyword in keywords_list:
+ keyword = space_mixed_label(keyword)
+ keyword_dur = keyword_filler_table[keyword]['keyword_duration']
+ keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
+ filler_dur = keyword_filler_table[keyword]['filler_duration']
+ filler_num = len(keyword_filler_table[keyword]['filler_table'])
+ if keyword_num <= 0:
+ print('Can\'t compute det for {} without positive sample'.format(keyword))
+ continue
+ if filler_num <= 0:
+ print('Can\'t compute det for {} without negative sample'.format(keyword))
+ continue
+
+ logging.info('Computing det for {}'.format(keyword))
+ logging.info(' Keyword duration: {} Hours, wave number: {}'.format(
+ keyword_dur / 3600.0, keyword_num))
+ logging.info(' Filler duration: {} Hours'.format(filler_dur / 3600.0))
+
+ stats_file = os.path.join(args.stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
+ with open(stats_file, 'w', encoding='utf8') as fout:
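+ # sweep the detection threshold from 0.0 to 1.0 in steps of args.step;
+ # each output row: threshold  true_detect_rate  false_alarm_rate  false_alarms_per_hour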
+ threshold = 0.0
+ while threshold <= 1.0:
+ num_false_reject = 0
+ num_true_detect = 0
+ # traverse all entries in keyword_table
+ for key, confi in keyword_filler_table[keyword][
+ 'keyword_table'].items():
+ if confi < threshold:
+ num_false_reject += 1
+ else:
+ num_true_detect += 1
+
+ num_false_alarm = 0
+ # traverse all entries in filler_table
+ for key, confi in keyword_filler_table[keyword][
+ 'filler_table'].items():
+ if confi >= threshold:
+ num_false_alarm += 1
+ # print(f'false alarm: {keyword}, {key}, {confi}')
+
+ # false_reject_rate = num_false_reject / keyword_num
+ true_detect_rate = num_true_detect / keyword_num
+
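+ # clamp to a tiny positive value so the false-alarm figures below are never exactly zero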
+ num_false_alarm = max(num_false_alarm, 1e-6)
+ false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
+ false_alarm_rate = num_false_alarm / filler_num
+
+ fout.write('{:.3f} {:.6f} {:.6f} {:.6f}\n'.format(
+ threshold, true_detect_rate, false_alarm_rate,
+ false_alarm_per_hour))
+ threshold += args.step
+
+ stats_files[keyword] = stats_file
diff --git a/funasr/utils/kws_utils.py b/funasr/utils/kws_utils.py
new file mode 100644
index 0000000..4935040
--- /dev/null
+++ b/funasr/utils/kws_utils.py
@@ -0,0 +1,284 @@
+import re
+import logging
+
+import torch
+import math
+from collections import defaultdict
+from typing import List, Optional, Tuple
+
+
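+# ASCII and full-width punctuation stripped from label text in query_token_set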
+symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@，。?★、…【】《》？“”‘’！[\\]^_`{|}~\s]+'
+
+
+def split_mixed_label(input_str):
+ tokens = []
+ s = input_str.lower()
+ while len(s) > 0:
+ match = re.match(r'[A-Za-z!?,<>()\']+', s)
+ if match is not None:
+ word = match.group(0)
+ else:
+ word = s[0:1]
+ tokens.append(word)
+ s = s.replace(word, '', 1).strip(' ')
+ return tokens
+
+
+def query_token_set(txt, symbol_table, lexicon_table):
+ tokens_str = tuple()
+ tokens_idx = tuple()
+
+ if txt in symbol_table:
+ tokens_str = tokens_str + (txt, )
+ tokens_idx = tokens_idx + (symbol_table[txt], )
+ return tokens_str, tokens_idx
+
+ parts = split_mixed_label(txt)
+ for part in parts:
+ if part == '!sil' or part == '(sil)' or part == '<sil>':
+ tokens_str = tokens_str + ('!sil', )
+ elif part == '<blank>':
+ tokens_str = tokens_str + ('<blank>', )
+ elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
+ tokens_str = tokens_str + ('<unk>', )
+ elif part in symbol_table:
+ tokens_str = tokens_str + (part, )
+ elif part in lexicon_table:
+ for ch in lexicon_table[part]:
+ tokens_str = tokens_str + (ch, )
+ else:
+ part = re.sub(symbol_str, '', part)
+ for ch in part:
+ tokens_str = tokens_str + (ch, )
+
+ for ch in tokens_str:
+ if ch in symbol_table:
+ tokens_idx = tokens_idx + (symbol_table[ch], )
+ elif ch == '!sil':
+ if 'sil' in symbol_table:
+ tokens_idx = tokens_idx + (symbol_table['sil'], )
+ else:
+ tokens_idx = tokens_idx + (symbol_table['<blank>'], )
+ elif ch == '<unk>':
+ if '<unk>' in symbol_table:
+ tokens_idx = tokens_idx + (symbol_table['<unk>'], )
+ else:
+ tokens_idx = tokens_idx + (symbol_table['<blank>'], )
+ else:
+ if '<unk>' in symbol_table:
+ tokens_idx = tokens_idx + (symbol_table['<unk>'], )
+ logging.info(f'\'{ch}\' is not in token set, replaced with <unk>')
+ else:
+ tokens_idx = tokens_idx + (symbol_table['<blank>'], )
+ logging.info(f'\'{ch}\' is not in token set, replaced with <blank>')
+
+ return tokens_str, tokens_idx
+
+
+class KwsCtcPrefixDecoder:
+ """CTC prefix beam search decoder for keyword spotting."""
+
+ def __init__(
+ self,
+ ctc: torch.nn.Module,
+ keywords: str,
+ token_list: list,
+ seg_dict: dict,
+ ):
+ """Initialize class.
+
+ Args:
+ ctc (torch.nn.Module): The CTC implementation.
+ For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
+
+ """
+ self.ctc = ctc
+ self.token_list = token_list
+
+ token_table = {token: idx for idx, token in enumerate(token_list)}
+
+ self.keywords_idxset = {0}
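+ # blank id 0 plus every token id that appears in any keyword; beam_search only keeps emissions from this set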
+ self.keywords_token = {}
+ self.keywords_str = keywords
+ keywords_list = self.keywords_str.strip().replace(' ', '').split(',')
+ for keyword in keywords_list:
+ strs, indexs = query_token_set(keyword, token_table, seg_dict)
+ self.keywords_token[keyword] = {}
+ self.keywords_token[keyword]['token_id'] = indexs
+ self.keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) for i in indexs)
+ self.keywords_idxset.update(indexs)
+
+ def beam_search(
+ self,
+ logits: torch.Tensor,
+ logits_lengths: torch.Tensor,
+ keywords_tokenset: set = None,
+ score_beam_size: int = 3,
+ path_beam_size: int = 20,
+ ) -> List[Tuple]:
+ """CTC prefix beam search inner implementation.
+
+ Args:
+ logits (torch.Tensor): (max_len, vocab_size), per-frame CTC posteriors
+ logits_lengths (torch.Tensor): (1, )
+ keywords_tokenset (set): token set for filtering score
+ score_beam_size (int): beam size for score
+ path_beam_size (int): beam size for path
+
+ Returns:
+ List[Tuple]: n-best hypotheses, each a (prefix token tuple, score, node list)
+ """
+
+ maxlen = logits.size(0)
+ ctc_probs = logits
+ cur_hyps = [(tuple(), (1.0, 0.0, []))]
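+ # each hypothesis is (prefix, (prob of ending in blank, prob of ending in non-blank, nodes)),
+ # where nodes records token / frame / prob of the best-scoring emission for every token in the prefix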
+
+ # CTC beam search step by step
+ for t in range(0, maxlen):
+ probs = ctc_probs[t] # (vocab_size,)
+ # key: prefix, value: (pb, pnb, nodes), default (0.0, 0.0, [])
+ next_hyps = defaultdict(lambda: (0.0, 0.0, []))
+
+ # 2.1 First beam prune: select topk best
+ top_k_probs, top_k_index = probs.topk(
+ score_beam_size) # (score_beam_size,)
+
+ # filter prob score that is too small
+ filter_probs = []
+ filter_index = []
+ for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
+ if keywords_tokenset is not None:
+ if prob > 0.05 and idx in keywords_tokenset:
+ filter_probs.append(prob)
+ filter_index.append(idx)
+ else:
+ if prob > 0.05:
+ filter_probs.append(prob)
+ filter_index.append(idx)
+
+ if len(filter_index) == 0:
+ continue
+
+ for s in filter_index:
+ ps = probs[s].item()
+ # print(f'frame:{t}, token:{s}, score:{ps}')
+
+ for prefix, (pb, pnb, cur_nodes) in cur_hyps:
+ last = prefix[-1] if len(prefix) > 0 else None
+ if s == 0: # blank
+ n_pb, n_pnb, nodes = next_hyps[prefix]
+ n_pb = n_pb + pb * ps + pnb * ps
+ nodes = cur_nodes.copy()
+ next_hyps[prefix] = (n_pb, n_pnb, nodes)
+ elif s == last:
+ if not math.isclose(pnb, 0.0, abs_tol=0.000001):
+ # Update *ss -> *s;
+ n_pb, n_pnb, nodes = next_hyps[prefix]
+ n_pnb = n_pnb + pnb * ps
+ nodes = cur_nodes.copy()
+ if ps > nodes[-1]['prob']: # update frame and prob
+ nodes[-1]['prob'] = ps
+ nodes[-1]['frame'] = t
+ next_hyps[prefix] = (n_pb, n_pnb, nodes)
+
+ if not math.isclose(pb, 0.0, abs_tol=0.000001):
+ # Update *s-s -> *ss, - is for blank
+ n_prefix = prefix + (s, )
+ n_pb, n_pnb, nodes = next_hyps[n_prefix]
+ n_pnb = n_pnb + pb * ps
+ nodes = cur_nodes.copy()
+ nodes.append(dict(token=s, frame=t,
+ prob=ps)) # to record token prob
+ next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
+ else:
+ n_prefix = prefix + (s, )
+ n_pb, n_pnb, nodes = next_hyps[n_prefix]
+ if nodes:
+ if ps > nodes[-1]['prob']: # update frame and prob
+ nodes[-1]['prob'] = ps
+ nodes[-1]['frame'] = t
+ else:
+ nodes = cur_nodes.copy()
+ nodes.append(dict(token=s, frame=t,
+ prob=ps)) # to record token prob
+ n_pnb = n_pnb + pb * ps + pnb * ps
+ next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
+
+ # 2.2 Second beam prune
+ next_hyps = sorted(next_hyps.items(),
+ key=lambda x: (x[1][0] + x[1][1]),
+ reverse=True)
+
+ cur_hyps = next_hyps[:path_beam_size]
+
+ hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
+ return hyps
+
+
+ def is_sublist(self, main_list, check_list):
+ if len(main_list) < len(check_list):
+ return -1
+
+ if len(main_list) == len(check_list):
+ return 0 if main_list == check_list else -1
+
+ for i in range(len(main_list) - len(check_list) + 1):
+ if main_list[i] == check_list[0]:
+ for j in range(len(check_list)):
+ if main_list[i + j] != check_list[j]:
+ break
+ else:
+ return i
+ else:
+ return -1
+
+
+ def _decode_inside(
+ self,
+ logits: torch.Tensor,
+ logits_lengths: torch.Tensor,
+ ):
+ hyps = self.beam_search(logits, logits_lengths, self.keywords_idxset)
+
+ hit_keyword = None
+ hit_score = 1.0
+ # start = 0; end = 0
+ for one_hyp in hyps:
+ prefix_ids = one_hyp[0]
+ # path_score = one_hyp[1]
+ prefix_nodes = one_hyp[2]
+ assert len(prefix_ids) == len(prefix_nodes)
+ for word in self.keywords_token.keys():
+ lab = self.keywords_token[word]['token_id']
+ offset = self.is_sublist(prefix_ids, lab)
+ if offset != -1:
+ hit_keyword = word
+ for idx in range(offset, offset + len(lab)):
+ hit_score *= prefix_nodes[idx]['prob']
+ break
+ if hit_keyword is not None:
+ hit_score = math.sqrt(hit_score)
+ break
+
+ if hit_keyword is not None:
+ return True, hit_keyword, hit_score
+ else:
+ return False, None, None
+
+
+ def decode(self, x: torch.Tensor):
+ """Run CTC keyword decoding on the encoder output.
+
+ Args:
+ x (torch.Tensor): The encoded feature tensor, shape (T, D)
+
+ Returns: (detected, keyword, score); keyword and score are None when nothing is detected
+
+ """
+
+ raw_logp = self.ctc.softmax(x.unsqueeze(0)).detach().squeeze(0).cpu()
+ xlen = torch.tensor([raw_logp.size(0)])
+
+ return self._decode_inside(raw_logp, xlen)
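+
+
+# A minimal usage sketch (names below are illustrative, not part of this module):
+# decoder = KwsCtcPrefixDecoder(ctc=model.ctc, keywords='小云小云', token_list=tokenizer.token_list, seg_dict=seg_dict)
+# detected, keyword, score = decoder.decode(encoder_out[0]) # encoder_out[0] is the (T, D) encoder output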
diff --git a/funasr/utils/types.py b/funasr/utils/type_utils.py
similarity index 100%
rename from funasr/utils/types.py
rename to funasr/utils/type_utils.py
--
Gitblit v1.9.1