From 2196844d1d6e5b8732c95896bb46f0eacdd9cf9d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 25 Sep 2024 15:10:50 +0800
Subject: [PATCH] Dev kws (#2105)

---
 examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_250_128_fdim80_t2599.yaml       |   95 +
 examples/industrial_data_pretraining/fsmn_kws/convert.sh                                         |   26 
 funasr/models/fsmn_kws_mt/model.py                                                               |  353 +++
 examples/industrial_data_pretraining/fsmn_kws_mt/path.sh                                         |    5 
 funasr/models/sanm_kws/model.py                                                                  |  266 ++
 funasr/datasets/audio_datasets/scp2jsonl.py                                                      |   15 
 funasr/models/sanm_kws_streaming/model.py                                                        |  442 ++++
 examples/industrial_data_pretraining/sanm_kws_streaming/funasr                                   |    1 
 funasr/models/sanm_kws/export_meta.py                                                            |   98 +
 examples/industrial_data_pretraining/sanm_kws_streaming/export.sh                                |   17 
 examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh                             |   44 
 examples/industrial_data_pretraining/fsmn_kws/finetune.sh                                        |  173 +
 funasr/models/fsmn_kws_mt/encoder.py                                                             |  213 ++
 funasr/models/fsmn_kws/__init__.py                                                               |    0 
 funasr/tokenizer/char_tokenizer.py                                                               |    2 
 examples/industrial_data_pretraining/fsmn_kws/path.sh                                            |    5 
 funasr/train_utils/average_nbest_models.py                                                       |    5 
 examples/industrial_data_pretraining/sanm_kws/export.sh                                          |   17 
 examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh                                     |  186 ++
 funasr/models/fsmn_vad_streaming/encoder.py                                                      |  136 -
 funasr/utils/type_utils.py                                                                       |    0 
 examples/industrial_data_pretraining/fsmn_kws_mt/convert.py                                      |  141 +
 examples/industrial_data_pretraining/sanm_kws/funasr                                             |    1 
 funasr/models/sanm_kws/__init__.py                                                               |    0 
 examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml |  103 +
 funasr/utils/compute_det_ctc.py                                                                  |  286 +++
 funasr/models/fsmn_kws_mt/__init__.py                                                            |    0 
 examples/industrial_data_pretraining/fsmn_kws/convert.py                                         |  134 +
 examples/industrial_data_pretraining/sanm_kws_streaming/finetune.sh                              |  258 ++
 examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh                                        |   20 
 funasr/train_utils/trainer.py                                                                    |   93 
 examples/industrial_data_pretraining/sanm_kws/infer.sh                                           |   20 
 funasr/models/transformer/scorers/ctc.py                                                         |    1 
 examples/industrial_data_pretraining/fsmn_kws_mt/funasr                                          |    1 
 funasr/download/download_model_from_hub.py                                                       |   37 
 examples/industrial_data_pretraining/fsmn_kws/funasr                                             |    1 
 funasr/datasets/kws_datasets/datasets.py                                                         |  132 +
 funasr/models/fsmn_kws/model.py                                                                  |  285 +++
 examples/industrial_data_pretraining/sanm_kws/finetune.sh                                        |  172 +
 funasr/models/sanm_kws_streaming/export_meta.py                                                  |   98 +
 funasr/bin/train.py                                                                              |    3 
 examples/industrial_data_pretraining/sanm_kws/path.sh                                            |    5 
 funasr/utils/kws_utils.py                                                                        |  284 +++
 examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_280_200_fdim40_t2602.yaml       |   95 +
 examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml |  103 +
 funasr/datasets/audio_datasets/datasets.py                                                       |   33 
 examples/industrial_data_pretraining/sanm_kws_streaming/path.sh                                  |    5 
 funasr/datasets/kws_datasets/__init__.py                                                         |    0 
 examples/industrial_data_pretraining/fsmn_kws/infer.sh                                           |   20 
 examples/industrial_data_pretraining/sanm_kws_streaming/conf/sanm_6e_320_256_fdim40_t2602.yaml   |  109 +
 funasr/models/fsmn_kws/encoder.py                                                                |  534 +++++
 examples/industrial_data_pretraining/sanm_kws/infer_from_local.sh                                |   41 
 funasr/auto/auto_model.py                                                                        |   69 
 funasr/models/sanm_kws_streaming/__init__.py                                                     |    0 
 examples/industrial_data_pretraining/sanm_kws/conf/sanm_6e_320_256_fdim40_t2602.yaml             |   94 +
 examples/industrial_data_pretraining/sanm_kws_streaming/infer.sh                                 |   34 
 funasr/train_utils/load_pretrained_model.py                                                      |    7 
 examples/industrial_data_pretraining/fsmn_kws/infer_from_local.sh                                |   41 
 examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh                                      |   36 
 funasr/models/ctc/ctc.py                                                                         |   29 
 examples/industrial_data_pretraining/sanm_kws_streaming/infer_from_local.sh                      |   62 
 funasr/models/sanm/encoder.py                                                                    |    8 
 62 files changed, 5302 insertions(+), 192 deletions(-)

diff --git a/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_250_128_fdim80_t2599.yaml b/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_250_128_fdim80_t2599.yaml
new file mode 100644
index 0000000..666da0c
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_250_128_fdim80_t2599.yaml
@@ -0,0 +1,95 @@
+
+# network architecture
+model: FsmnKWS
+model_conf:
+    ctc_weight: 1.0
+
+# encoder related
+encoder: FSMN
+encoder_conf:
+    input_dim: 400
+    input_affine_dim: 140
+    fsmn_layers: 4
+    linear_dim: 250
+    proj_dim: 128
+    lorder: 10
+    rorder: 2
+    lstride: 1
+    rstride: 1
+    output_affine_dim: 140
+    output_dim: 2599
+    use_softmax: false
+
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 5
+    lfr_n: 3
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 3
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 100
+  keep_nbest_models: 10
+  avg_nbest_model: 10
+  avg_keep_nbest_models_type: loss
+  validate_interval: 50000
+  save_checkpoint_interval: 50000
+  avg_checkpoint_interval: 1000
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 10000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: EspnetStyleBatchSampler
+    batch_type: length # example or length
+    batch_size: 32000 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
+    max_token_length: 1600 # filter out samples if source_token_len+target_token_len > max_token_length
+    buffer_size: 2048
+    shuffle: true
+    num_workers: 8
+    preprocessor_speech: SpeechPreprocessSpeedPerturb
+    preprocessor_speech_conf:
+      speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+    unk_symbol: <unk>
+    split_with_space: true
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+    extra_linear: false
+
+normalize: null
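
A quick sanity check on the dimensions above: WavFrontend stacks lfr_m low-frame-rate frames of n_mels filterbank features, so the FSMN encoder's input_dim is n_mels * lfr_m. The short Python sketch below is illustrative only (it assumes the conventional LFR stacking behaviour) and simply reproduces the numbers used by the two configs in this directory.

    # Illustrative sketch: how encoder_conf.input_dim follows from the LFR frontend
    # settings, assuming WavFrontend stacks lfr_m frames of n_mels features.
    def lfr_input_dim(n_mels: int, lfr_m: int) -> int:
        return n_mels * lfr_m

    assert lfr_input_dim(n_mels=80, lfr_m=5) == 400   # fsmn_4e_l10r2_250_128_fdim80_t2599.yaml
    assert lfr_input_dim(n_mels=40, lfr_m=9) == 360   # fsmn_4e_l10r2_280_200_fdim40_t2602.yaml
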
diff --git a/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_280_200_fdim40_t2602.yaml b/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_280_200_fdim40_t2602.yaml
new file mode 100644
index 0000000..6ad59de
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/conf/fsmn_4e_l10r2_280_200_fdim40_t2602.yaml
@@ -0,0 +1,95 @@
+
+# network architecture
+model: FsmnKWS
+model_conf:
+    ctc_weight: 1.0
+
+# encoder related
+encoder: FSMN
+encoder_conf:
+    input_dim: 360
+    input_affine_dim: 280
+    fsmn_layers: 4
+    linear_dim: 280
+    proj_dim: 200
+    lorder: 10
+    rorder: 2
+    lstride: 1
+    rstride: 1
+    output_affine_dim: 400
+    output_dim: 2602
+    use_softmax: false
+
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 40
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 9
+    lfr_n: 3
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 3
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 100
+  keep_nbest_models: 10
+  avg_nbest_model: 10
+  avg_keep_nbest_models_type: loss
+  validate_interval: 50000
+  save_checkpoint_interval: 50000
+  avg_checkpoint_interval: 1000
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 10000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: EspnetStyleBatchSampler
+    batch_type: length # example or length
+    batch_size: 32000 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
+    max_token_length: 1600 # filter out samples if source_token_len+target_token_len > max_token_length
+    buffer_size: 2048
+    shuffle: true
+    num_workers: 8
+    preprocessor_speech: SpeechPreprocessSpeedPerturb
+    preprocessor_speech_conf:
+      speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+    unk_symbol: <unk>
+    split_with_space: true
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin
+    reduce: true
+    ignore_nan_grad: true
+    extra_linear: false
+
+normalize: null
diff --git a/examples/industrial_data_pretraining/fsmn_kws/convert.py b/examples/industrial_data_pretraining/fsmn_kws/convert.py
new file mode 100644
index 0000000..1609ef4
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/convert.py
@@ -0,0 +1,134 @@
+from __future__ import print_function
+
+import argparse
+import copy
+import logging
+import os
+from shutil import copyfile
+
+import torch
+import yaml
+from typing import Union
+
+
+from funasr.models.fsmn_kws.model import FsmnKWSConvert
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        description=
+        'load and convert network to each other between kaldi/pytorch format')
+    parser.add_argument('--config', required=True, help='config file')
+    parser.add_argument(
+        '--network_file',
+        default='',
+        required=True,
+        help='input network, support kaldi.txt/pytorch.pt')
+    parser.add_argument('--model_dir', required=True, help='save model dir')
+    parser.add_argument('--model_name', required=True, help='save model name')
+    parser.add_argument('--convert_to',
+                        default='kaldi',
+                        required=True,
+                        help='target network type, kaldi/pytorch')
+
+    args = parser.parse_args()
+    return args
+
+
+def convert_to_kaldi(
+    configs,
+    network_file,
+    model_dir,
+    model_name="convert.kaldi.txt"
+):
+    copyfile(network_file, os.path.join(model_dir, 'origin.torch.pt'))
+
+    model = FsmnKWSConvert(
+            vocab_size=configs['encoder_conf']['output_dim'],
+            encoder='FSMNConvert',
+            encoder_conf=configs['encoder_conf'],
+            ctc_conf=configs['ctc_conf'],
+    )
+    print(model)
+    num_params = count_parameters(model)
+    print('the number of model params: {}'.format(num_params))
+
+    states = torch.load(network_file, map_location='cpu')
+    model.load_state_dict(states["state_dict"])
+
+    kaldi_text = os.path.join(model_dir, model_name)
+    with open(kaldi_text, 'w', encoding='utf8') as fout:
+        nnet_desp = model.to_kaldi_net()
+        fout.write(nnet_desp)
+
+
+def convert_to_pytorch(
+    configs,
+    network_file,
+    model_dir,
+    model_name="convert.torch.pt"
+):
+    model = FsmnKWSConvert(
+            vocab_size=configs['encoder_conf']['output_dim'],
+            frontend=None,
+            specaug=None,
+            normalize=None,
+            encoder='FSMNConvert',
+            encoder_conf=configs['encoder_conf'],
+            ctc_conf=configs['ctc_conf'],
+    )
+
+    num_params = count_parameters(model)
+    print('the number of model params: {}'.format(num_params))
+
+    copyfile(network_file, os.path.join(model_dir, 'origin.kaldi.txt'))
+    model.to_pytorch_net(network_file)
+
+    save_model_path = os.path.join(model_dir, model_name)
+    torch.save({"model": model.state_dict()}, save_model_path)
+
+    print('convert torch format back to kaldi')
+    kaldi_text = os.path.join(model_dir, 'convert.kaldi.txt')
+    with open(kaldi_text, 'w', encoding='utf8') as fout:
+        nnet_desp = model.to_kaldi_net()
+        fout.write(nnet_desp)
+
+    print('Done!')
+
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+    print(args)
+    with open(args.config, 'r') as fin:
+        configs = yaml.load(fin, Loader=yaml.FullLoader)
+
+    if args.convert_to == 'pytorch':
+        print('convert kaldi net to pytorch...')
+        convert_to_pytorch(
+            configs,
+            args.network_file,
+            args.model_dir,
+            args.model_name
+        )
+    elif args.convert_to == 'kaldi':
+        print('convert pytorch net to kaldi...')
+        convert_to_kaldi(
+            configs,
+            args.network_file,
+            args.model_dir,
+            args.model_name
+        )
+    else:
+        print('unsupported target network type: {}'.format(args.convert_to))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/industrial_data_pretraining/fsmn_kws/convert.sh b/examples/industrial_data_pretraining/fsmn_kws/convert.sh
new file mode 100644
index 0000000..1ec12f0
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/convert.sh
@@ -0,0 +1,26 @@
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models_kws
+mkdir -p ${local_path_root}
+
+local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
+if [ ! -d "$local_path" ]; then
+    git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
+fi
+
+export PATH=${local_path}/runtime:$PATH
+export LD_LIBRARY_PATH=${local_path}/runtime:$LD_LIBRARY_PATH
+
+config=./conf/fsmn_4e_l10r2_250_128_fdim80_t2599.yaml
+torch_nnet=exp/finetune_outputs/model.pt.avg10
+out_dir=exp/finetune_outputs
+
+if [ ! -d "$out_dir" ]; then
+    mkdir -p $out_dir
+fi
+
+python convert.py --config $config --network_file $torch_nnet --model_dir $out_dir --model_name "convert.kaldi.txt" --convert_to kaldi
+
+nnet-copy --binary=true ${out_dir}/convert.kaldi.txt ${out_dir}/convert.kaldi.net
diff --git a/examples/industrial_data_pretraining/fsmn_kws/finetune.sh b/examples/industrial_data_pretraining/fsmn_kws/finetune.sh
new file mode 100755
index 0000000..59a2922
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/finetune.sh
@@ -0,0 +1,173 @@
+#!/usr/bin/env bash
+
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# Set bash to 'debug' mode; it will exit on:
+# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline', -x 'print commands'.
+set -e
+set -u
+set -o pipefail
+
+. ./path.sh
+workspace=`pwd`
+
+CUDA_VISIBLE_DEVICES="0,1"
+
+stage=2
+stop_stage=3
+
+inference_device="cuda" #"cpu"
+inference_checkpoint="model.pt.avg10"
+inference_scp="wav.scp"
+inference_batch_size=32
+nj=32
+test_sets="test"
+
+# model_name from model_hub, or model_dir in local path
+
+## option 1, download model automatically (currently unsupported)
+model_name_or_model_dir="iic/speech_charctc_kws_phone-xiaoyun"
+
+## option 2, download model by git
+local_path_root=${workspace}/modelscope_models
+model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
+if [ ! -d $model_name_or_model_dir ]; then
+  mkdir -p ${model_name_or_model_dir}
+  git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${model_name_or_model_dir}
+fi
+
+config=fsmn_4e_l10r2_250_128_fdim80_t2599.yaml
+token_list=${model_name_or_model_dir}/funasr/tokens_2599.txt
+lexicon_list=${model_name_or_model_dir}/funasr/lexicon.txt
+cmvn_file=${model_name_or_model_dir}/funasr/am.mvn.dim80_l2r2
+init_param="${model_name_or_model_dir}/funasr/basetrain_fsmn_4e_l10r2_250_128_fdim80_t2599.pt"
+
+
+# data prepare
+# data dir, which contains: train.jsonl, val.jsonl
+data_dir=../../data
+
+train_data="${data_dir}/train.jsonl"
+val_data="${data_dir}/val.jsonl"
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "stage 1: Generate audio json list"
+  # generate train.jsonl and val.jsonl from wav.scp and text.txt
+  python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+  ++scp_file_list='['''${data_dir}/train_wav.scp''', '''${data_dir}/train_text.txt''']' \
+  ++data_type_list='["source", "target"]' \
+  ++jsonl_file_out="${train_data}"
+
+  python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+  ++scp_file_list='['''${data_dir}/val_wav.scp''', '''${data_dir}/val_text.txt''']' \
+  ++data_type_list='["source", "target"]' \
+  ++jsonl_file_out="${val_data}"
+fi
+
+# exp output dir
+output_dir="${workspace}/exp/finetune_outputs"
+
+# Training Stage
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "stage 2: KWS Training"
+
+  mkdir -p ${output_dir}
+  current_time=$(date "+%Y-%m-%d_%H-%M")
+  log_file="${output_dir}/train.log.txt.${current_time}"
+  echo "log_file: ${log_file}"
+  echo "finetune using basetrain model: ${init_param}"
+
+  export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
+  gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+  torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+  ../../../funasr/bin/train.py \
+  --config-path "${workspace}/conf" \
+  --config-name "${config}" \
+  ++init_param="${init_param}" \
+  ++disable_update=true \
+  ++train_data_set_list="${train_data}" \
+  ++valid_data_set_list="${val_data}" \
+  ++tokenizer_conf.token_list="${token_list}" \
+  ++tokenizer_conf.seg_dict="${lexicon_list}" \
+  ++frontend_conf.cmvn_file="${cmvn_file}" \
+  ++output_dir="${output_dir}" &> ${log_file}
+fi
+
+
+# Testing Stage
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "stage 3: Inference"
+  keywords=(小云小云)
+  keywords_string=$(IFS=,; echo "${keywords[*]}")
+  echo "keywords: $keywords_string"
+
+  if [ ${inference_device} == "cuda" ]; then
+      nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+  else
+      inference_batch_size=1
+      CUDA_VISIBLE_DEVICES=""
+      for JOB in $(seq ${nj}); do
+          CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+      done
+  fi
+
+  for dset in ${test_sets}; do
+    inference_dir="${output_dir}/inference-${inference_checkpoint}/${dset}"
+    _logdir="${inference_dir}/logdir"
+    echo "inference_dir: ${inference_dir}"
+
+    mkdir -p "${_logdir}"
+    test_data_dir="${data_dir}/${dset}"
+    key_file=${test_data_dir}/${inference_scp}
+
+    split_scps=
+    for JOB in $(seq "${nj}"); do
+        split_scps+=" ${_logdir}/keys.${JOB}.scp"
+    done
+    $FUNASR_DIR/examples/aishell/paraformer/utils/split_scp.pl "${key_file}" ${split_scps}
+
+    gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+    for JOB in $(seq ${nj}); do
+        {
+          id=$((JOB-1))
+          gpuid=${gpuid_list_array[$id]}
+
+          echo "${output_dir}"
+
+          export CUDA_VISIBLE_DEVICES=${gpuid}
+          python ../../../funasr/bin/inference.py \
+          --config-path="${output_dir}" \
+          --config-name="config.yaml" \
+          ++init_param="${output_dir}/${inference_checkpoint}" \
+          ++tokenizer_conf.token_list="${token_list}" \
+          ++tokenizer_conf.seg_dict="${lexicon_list}" \
+          ++frontend_conf.cmvn_file="${cmvn_file}" \
+          ++keywords="\"$keywords_string"\" \
+          ++input="${_logdir}/keys.${JOB}.scp" \
+          ++output_dir="${inference_dir}/${JOB}" \
+          ++device="${inference_device}" \
+          ++ncpu=1 \
+          ++disable_log=true \
+          ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
+        }&
+
+    done
+    wait
+
+    for f in detect; do
+        if [ -f "${inference_dir}/${JOB}/${f}" ]; then
+          for JOB in $(seq "${nj}"); do
+              cat "${inference_dir}/${JOB}/${f}"
+          done | sort -k1 >"${inference_dir}/${f}"
+        fi
+    done
+
+    python funasr/utils/compute_det_ctc.py \
+        --keywords ${keywords_string} \
+        --test_data ${test_data_dir}/wav.scp \
+        --trans_data ${test_data_dir}/text \
+        --score_file ${inference_dir}/detect \
+        --stats_dir ${inference_dir}
+  done
+
+fi
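
Stage 1 above turns Kaldi-style wav.scp/text pairs into the train.jsonl/val.jsonl lists consumed by training. The minimal sketch below shows the expected inputs; the jsonl field names are an assumption that mirrors the "source"/"target" data types passed to scp2jsonl and may differ between FunASR versions.

    # Sketch of the stage-1 inputs: each line of train_wav.scp maps an utterance id
    # to a wav path, and train_text.txt maps the same id to its transcript.
    # scp2jsonl pairs them by id; the field names below are illustrative.
    wav_scp_line = "kws_utt_0001 /data/kws/wav/kws_utt_0001.wav"
    text_line = "kws_utt_0001 小 云 小 云"

    key, wav_path = wav_scp_line.split(maxsplit=1)
    _, transcript = text_line.split(maxsplit=1)
    print({"key": key, "source": wav_path, "target": transcript})
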
diff --git a/examples/industrial_data_pretraining/fsmn_kws/funasr b/examples/industrial_data_pretraining/fsmn_kws/funasr
new file mode 120000
index 0000000..39a970f
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/funasr
@@ -0,0 +1 @@
+../../../funasr
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_kws/infer.sh b/examples/industrial_data_pretraining/fsmn_kws/infer.sh
new file mode 100644
index 0000000..6e03b89
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/infer.sh
@@ -0,0 +1,20 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+model="iic/speech_charctc_kws_phone-xiaoyun"
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python funasr/bin/inference.py \
++model=${model} \
++input=${input} \
++output_dir="./outputs/debug" \
++device="cpu" \
+++keywords="\"$keywords_string"\"
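
The same hub-based inference can also be driven from Python. A minimal sketch with FunASR's AutoModel follows; passing keywords as a keyword argument mirrors the ++keywords override above, but the exact argument routing (constructor vs. generate) is an assumption and may vary between FunASR versions.

    # Sketch only: Python equivalent of the shell command above.
    from funasr import AutoModel

    model = AutoModel(model="iic/speech_charctc_kws_phone-xiaoyun")
    res = model.generate(
        input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav",
        keywords="小云小云",  # assumption: mirrors the ++keywords CLI override
    )
    print(res)
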
diff --git a/examples/industrial_data_pretraining/fsmn_kws/infer_from_local.sh b/examples/industrial_data_pretraining/fsmn_kws/infer_from_local.sh
new file mode 100644
index 0000000..ca1b7c8
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/infer_from_local.sh
@@ -0,0 +1,41 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method2, inference from local model
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+output_dir="./outputs/debug"
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models
+mkdir -p ${local_path_root}
+local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
+git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
+
+device="cuda:0" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
+
+config="inference_fsmn_4e_l10r2_250_128_fdim80_t2599.yaml"
+tokens="${local_path}/funasr/tokens_2599.txt"
+seg_dict="${local_path}/funasr/lexicon.txt"
+init_param="${local_path}/funasr/finetune_fsmn_4e_l10r2_250_128_fdim80_t2599_xiaoyun_xiaoyun.pt"
+cmvn_file="${local_path}/funasr/am.mvn.dim80_l2r2"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python -m funasr.bin.inference \
+--config-path "${local_path}/funasr" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++tokenizer_conf.token_list="${tokens}" \
+++tokenizer_conf.seg_dict="${seg_dict}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++device="${device}" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/fsmn_kws/path.sh b/examples/industrial_data_pretraining/fsmn_kws/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml b/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml
new file mode 100644
index 0000000..fdf52e4
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml
@@ -0,0 +1,103 @@
+
+# network architecture
+model: FsmnKWSMT
+model_conf:
+    ctc_weight: 1.0
+
+# encoder related
+encoder: FSMNMT
+encoder_conf:
+    input_dim: 400
+    input_affine_dim: 140
+    fsmn_layers: 4
+    linear_dim: 250
+    proj_dim: 128
+    lorder: 10
+    rorder: 2
+    lstride: 1
+    rstride: 1
+    output_affine_dim: 140
+    output_dim: 2599
+    output_dim2: 4
+    use_softmax: false
+
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 80
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 5
+    lfr_n: 3
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 3
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 100
+  keep_nbest_models: 100
+  avg_nbest_model: 10
+  avg_keep_nbest_models_type: loss
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 10000
+
+dataset: KwsMTDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: EspnetStyleBatchSampler
+    batch_type: length # example or length
+    batch_size: 64000 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
+    max_token_length: 1600 # filter out samples if source_token_len+target_token_len > max_token_length
+    buffer_size: 2048
+    shuffle: true
+    num_workers: 8
+    preprocessor_speech: SpeechPreprocessSpeedPerturb
+    preprocessor_speech_conf:
+      speed_perturb: [0.9, 1.0, 1.1]
+    dataloader: DataloaderMapStyle
+
+tokenizer:
+    - CharTokenizer
+    - CharTokenizer
+
+tokenizer_conf:
+    - unk_symbol: <unk>
+      split_with_space: true
+      token_list: null
+      seg_dict: null
+    - unk_symbol: <unk>
+      split_with_space: true
+      token_list: null
+      seg_dict: null
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin  # ctc_type: focalctc, builtin
+    reduce: true
+    ignore_nan_grad: true
+    extra_linear: false
+
+normalize: null
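
This config differs from the single-task FSMN setup by adding a second output head (output_dim2: 4) and a second CharTokenizer, i.e. two CTC branches over a shared FSMN trunk. The sketch below is illustrative only: the head layout and the equal loss weighting are assumptions, not taken from funasr/models/fsmn_kws_mt/model.py.

    # Illustrative multi-task CTC sketch (not the actual FsmnKWSMT implementation):
    # two linear heads over a shared encoder output, one per vocabulary.
    import torch
    import torch.nn as nn

    T, N, D = 50, 2, 128          # frames, batch, shared projection size
    vocab1, vocab2 = 2599, 4      # output_dim, output_dim2 from this config

    encoder_out = torch.randn(T, N, D)   # stand-in for the shared FSMN output
    head1, head2 = nn.Linear(D, vocab1), nn.Linear(D, vocab2)
    ctc = nn.CTCLoss(blank=0, zero_infinity=True)

    log_probs1 = head1(encoder_out).log_softmax(-1)
    log_probs2 = head2(encoder_out).log_softmax(-1)
    in_lens = torch.full((N,), T, dtype=torch.long)
    targets1, tgt_lens1 = torch.randint(1, vocab1, (N, 10)), torch.full((N,), 10, dtype=torch.long)
    targets2, tgt_lens2 = torch.randint(1, vocab2, (N, 4)), torch.full((N,), 4, dtype=torch.long)

    # equal weighting is assumed here purely for illustration
    loss = ctc(log_probs1, targets1, in_lens, tgt_lens1) + ctc(log_probs2, targets2, in_lens, tgt_lens2)
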
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml b/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml
new file mode 100644
index 0000000..ca9413e
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml
@@ -0,0 +1,103 @@
+
+# network architecture
+model: FsmnKWSMT
+model_conf:
+    ctc_weight: 1.0
+
+# encoder related
+encoder: FSMNMT
+encoder_conf:
+    input_dim: 360
+    input_affine_dim: 280
+    fsmn_layers: 4
+    linear_dim: 280
+    proj_dim: 200
+    lorder: 10
+    rorder: 2
+    lstride: 1
+    rstride: 1
+    output_affine_dim: 400
+    output_dim: 2602
+    output_dim2: 4
+    use_softmax: false
+
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 40
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 9
+    lfr_n: 3
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 3
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 100
+  keep_nbest_models: 100
+  avg_nbest_model: 10
+  avg_keep_nbest_models_type: loss
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 10000
+
+dataset: KwsMTDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: EspnetStyleBatchSampler
+    batch_type: length # example or length
+    batch_size: 64000 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
+    max_token_length: 1600 # filter out samples if source_token_len+target_token_len > max_token_length
+    buffer_size: 2048
+    shuffle: true
+    num_workers: 8
+    preprocessor_speech: SpeechPreprocessSpeedPerturb
+    preprocessor_speech_conf:
+      speed_perturb: [0.9, 1.0, 1.1]
+    dataloader: DataloaderMapStyle
+
+tokenizer:
+    - CharTokenizer
+    - CharTokenizer
+
+tokenizer_conf:
+    - unk_symbol: <unk>
+      split_with_space: true
+      token_list: null
+      seg_dict: null
+    - unk_symbol: <unk>
+      split_with_space: true
+      token_list: null
+      seg_dict: null
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin  # ctc_type: focalctc, builtin
+    reduce: true
+    ignore_nan_grad: true
+    extra_linear: false
+
+normalize: null
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/convert.py b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.py
new file mode 100644
index 0000000..e63e689
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.py
@@ -0,0 +1,141 @@
+from __future__ import print_function
+
+import argparse
+import copy
+import logging
+import os
+from shutil import copyfile
+
+import torch
+import yaml
+from typing import Union
+from funasr.models.fsmn_kws_mt.encoder import FSMNMTConvert
+from funasr.models.fsmn_kws_mt.model import FsmnKWSMTConvert
+
+
+def count_parameters(model):
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        description=
+        'load and convert network to each other between kaldi/pytorch format')
+    parser.add_argument('--config', required=True, help='config file')
+    parser.add_argument(
+        '--network_file',
+        default='',
+        required=True,
+        help='input network, support kaldi.txt/pytorch.pt')
+    parser.add_argument('--model_dir', required=True, help='save model dir')
+    parser.add_argument('--model_name', required=True, help='save model name')
+    parser.add_argument('--model_name2', required=True, help='save model name')
+    parser.add_argument('--convert_to',
+                        default='kaldi',
+                        required=True,
+                        help='target network type, kaldi/pytorch')
+
+    args = parser.parse_args()
+    return args
+
+
+def convert_to_kaldi(
+    configs,
+    network_file,
+    model_dir,
+    model_name="convert.kaldi.txt",
+    model_name2="convert.kaldi2.txt"
+):
+    copyfile(network_file, os.path.join(model_dir, 'origin.torch.pt'))
+
+    model = FsmnKWSMTConvert(
+        vocab_size=configs['encoder_conf']['output_dim'],
+        vocab_size2=configs['encoder_conf']['output_dim2'],
+        encoder='FSMNMTConvert',
+        encoder_conf=configs['encoder_conf'],
+        ctc_conf=configs['ctc_conf'],
+    )
+    print(model)
+    num_params = count_parameters(model)
+    print('the number of model params: {}'.format(num_params))
+
+    states = torch.load(network_file, map_location='cpu')
+    model.load_state_dict(states["state_dict"])
+
+    kaldi_text = os.path.join(model_dir, model_name)
+    with open(kaldi_text, 'w', encoding='utf8') as fout:
+        nnet_desp = model.to_kaldi_net()
+        fout.write(nnet_desp)
+
+    kaldi_text2 = os.path.join(model_dir, model_name2)
+    with open(kaldi_text2, 'w', encoding='utf8') as fout:
+        nnet_desp2 = model.to_kaldi_net2()
+        fout.write(nnet_desp2)
+
+
+def convert_to_pytorch(
+    configs,
+    network_file,
+    model_dir,
+    model_name="convert.torch.pt"
+):
+    model = FsmnKWSMTConvert(
+        vocab_size=configs['encoder_conf']['output_dim'],
+        vocab_size2=configs['encoder_conf']['output_dim2'],
+        encoder='FSMNMTConvert',
+        encoder_conf=configs['encoder_conf'],
+        ctc_conf=configs['ctc_conf'],
+    )
+
+    num_params = count_parameters(model)
+    print('the number of model params: {}'.format(num_params))
+
+    copyfile(network_file, os.path.join(model_dir, 'origin.kaldi.txt'))
+    model.to_pytorch_net(network_file)
+
+    save_model_path = os.path.join(model_dir, model_name)
+    torch.save({"model": model.state_dict()}, save_model_path)
+
+    print('convert torch format back to kaldi')
+    kaldi_text = os.path.join(model_dir, 'convert.kaldi.txt')
+    with open(kaldi_text, 'w', encoding='utf8') as fout:
+        nnet_desp = model.to_kaldi_net()
+        fout.write(nnet_desp)
+
+    print('Done!')
+
+
+def main():
+    args = get_args()
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+    print(args)
+    with open(args.config, 'r') as fin:
+        configs = yaml.load(fin, Loader=yaml.FullLoader)
+
+    if args.convert_to == 'pytorch':
+        print('convert kaldi net to pytorch...')
+        convert_to_pytorch(
+            configs,
+            args.network_file,
+            args.model_dir,
+            args.model_name,
+        )
+    elif args.convert_to == 'kaldi':
+        print('convert pytorch net to kaldi...')
+        convert_to_kaldi(
+            configs,
+            args.network_file,
+            args.model_dir,
+            args.model_name,
+            args.model_name2,
+        )
+    else:
+        print('unsupported target network type: {}'.format(args.convert_to))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh
new file mode 100644
index 0000000..30e2eed
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/convert.sh
@@ -0,0 +1,36 @@
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models
+mkdir -p ${local_path_root}
+
+local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
+if [ ! -d "$local_path" ]; then
+    git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
+fi
+
+export PATH=${local_path}/runtime:$PATH
+export LD_LIBRARY_PATH=${local_path}/runtime:$LD_LIBRARY_PATH
+
+# finetune config file
+config=./conf/fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml
+
+# finetune output checkpoint
+torch_nnet=exp/finetune_outputs/model.pt.avg10
+
+out_dir=exp/finetune_outputs
+
+if [ ! -d "$out_dir" ]; then
+    mkdir -p $out_dir
+fi
+
+python convert.py --config $config \
+	--network_file $torch_nnet \
+	--model_dir $out_dir \
+	--model_name "convert.kaldi.txt" \
+	--model_name2 "convert.kaldi2.txt" \
+	--convert_to kaldi
+
+nnet-copy --binary=true ${out_dir}/convert.kaldi.txt ${out_dir}/convert.kaldi.net
+nnet-copy --binary=true ${out_dir}/convert.kaldi2.txt ${out_dir}/convert.kaldi2.net
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh
new file mode 100755
index 0000000..1e87021
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/finetune.sh
@@ -0,0 +1,186 @@
+#!/usr/bin/env bash
+
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# Set bash to 'debug' mode; it will exit on:
+# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline', -x 'print commands'.
+set -e
+set -u
+set -o pipefail
+
+. ./path.sh
+workspace=`pwd`
+
+CUDA_VISIBLE_DEVICES="0,1"
+
+stage=2
+stop_stage=3
+
+inference_device="cuda" #"cpu"
+inference_checkpoint="model.pt.avg10"
+inference_scp="wav.scp"
+inference_batch_size=32
+nj=32
+test_sets="test"
+
+# model_name from model_hub, or model_dir in local path
+
+## option 1, download model automatically (currently unsupported)
+model_name_or_model_dir="iic/speech_charctc_kws_phone-xiaoyun"
+
+## option 2, download model by git
+local_path_root=${workspace}/modelscope_models
+model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
+if [ ! -d $model_name_or_model_dir ]; then
+  mkdir -p ${model_name_or_model_dir}
+  git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${model_name_or_model_dir}
+fi
+
+config=fsmn_4e_l10r2_250_128_fdim80_t2599_t4.yaml
+token_list=${model_name_or_model_dir}/funasr/tokens_2599.txt
+token_list2=${model_name_or_model_dir}/funasr/tokens_xiaoyun_char.txt
+lexicon_list=${model_name_or_model_dir}/funasr/lexicon.txt
+cmvn_file=${model_name_or_model_dir}/funasr/am.mvn.dim80_l2r2
+init_param="${model_name_or_model_dir}/funasr/basetrain_fsmn_4e_l10r2_250_128_fdim80_t2599.pt"
+
+
+# data prepare
+# data dir, which contains: train.jsonl, val.jsonl
+data_dir=../../data
+
+train_data="${data_dir}/train.jsonl"
+val_data="${data_dir}/val.jsonl"
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "stage 1: Generate audio json list"
+  # generate train.jsonl and val.jsonl from wav.scp and text.txt
+  python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+  ++scp_file_list='['''${data_dir}/train_wav.scp''', '''${data_dir}/train_text.txt''']' \
+  ++data_type_list='["source", "target"]' \
+  ++jsonl_file_out="${train_data}"
+
+  python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+  ++scp_file_list='['''${data_dir}/val_wav.scp''', '''${data_dir}/val_text.txt''']' \
+  ++data_type_list='["source", "target"]' \
+  ++jsonl_file_out="${val_data}"
+fi
+
+# exp output dir
+output_dir="${workspace}/exp/finetune_outputs"
+
+
+# Training Stage
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "stage 2: KWS Training"
+
+  mkdir -p ${output_dir}
+  current_time=$(date "+%Y-%m-%d_%H-%M")
+  log_file="${output_dir}/train.log.txt.${current_time}"
+  echo "log_file: ${log_file}"
+  echo "finetune using basetrain model: ${init_param}"
+
+  export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
+  gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+  torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+  ../../../funasr/bin/train.py \
+  --config-path "${workspace}/conf" \
+  --config-name "${config}" \
+  ++init_param="${init_param}" \
+  ++token_lists='['''${token_list}''', '''${token_list2}''']' \
+  ++seg_dicts='['''${lexicon_list}''', '''${lexicon_list}''']' \
+  ++disable_update=true \
+  ++train_data_set_list="${train_data}" \
+  ++valid_data_set_list="${val_data}" \
+  ++frontend_conf.cmvn_file="${cmvn_file}" \
+  ++output_dir="${output_dir}" &> ${log_file}
+fi
+
+
+# Testing Stage
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "stage 3: Inference"
+  keywords=(小云小云)
+  keywords_string=$(IFS=,; echo "${keywords[*]}")
+  echo "keywords: $keywords_string"
+
+  if [ ${inference_device} == "cuda" ]; then
+      nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+  else
+      inference_batch_size=1
+      CUDA_VISIBLE_DEVICES=""
+      for JOB in $(seq ${nj}); do
+          CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+      done
+  fi
+
+  for dset in ${test_sets}; do
+    inference_dir="${output_dir}/inference-${inference_checkpoint}/${dset}"
+    _logdir="${inference_dir}/logdir"
+    echo "inference_dir: ${inference_dir}"
+
+    mkdir -p "${_logdir}"
+    test_data_dir="${data_dir}/${dset}"
+    key_file=${test_data_dir}/${inference_scp}
+
+    split_scps=
+    for JOB in $(seq "${nj}"); do
+        split_scps+=" ${_logdir}/keys.${JOB}.scp"
+    done
+    $FUNASR_DIR/examples/aishell/paraformer/utils/split_scp.pl "${key_file}" ${split_scps}
+
+    gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+    for JOB in $(seq ${nj}); do
+        {
+          id=$((JOB-1))
+          gpuid=${gpuid_list_array[$id]}
+
+          echo "${output_dir}"
+
+          export CUDA_VISIBLE_DEVICES=${gpuid}
+          python ../../../funasr/bin/inference.py \
+          --config-path="${output_dir}" \
+          --config-name="config.yaml" \
+          ++init_param="${output_dir}/${inference_checkpoint}" \
+          ++tokenizer_conf.token_list="${token_list}" \
+          ++tokenizer_conf.seg_dict="${lexicon_list}" \
+          ++tokenizer2_conf.token_list="${token_list2}" \
+          ++tokenizer2_conf.seg_dict="${lexicon_list}" \
+          ++frontend_conf.cmvn_file="${cmvn_file}" \
+          ++keywords="\"$keywords_string"\" \
+          ++input="${_logdir}/keys.${JOB}.scp" \
+          ++output_dir="${inference_dir}/${JOB}" \
+          ++device="${inference_device}" \
+          ++ncpu=1 \
+          ++disable_log=true \
+          ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
+        }&
+
+    done
+    wait
+
+    for f in detect detect2; do
+        if [ -f "${inference_dir}/${JOB}/${f}" ]; then
+          for JOB in $(seq "${nj}"); do
+              cat "${inference_dir}/${JOB}/${f}"
+          done | sort -k1 >"${inference_dir}/${f}"
+        fi
+    done
+
+    mkdir -p ${inference_dir}/task1
+    python funasr/utils/compute_det_ctc.py \
+        --keywords ${keywords_string} \
+        --test_data ${test_data_dir}/wav.scp \
+        --trans_data ${test_data_dir}/text \
+        --score_file ${inference_dir}/detect \
+        --stats_dir ${inference_dir}/task1
+
+    mkdir -p ${inference_dir}/task2
+    python funasr/utils/compute_det_ctc.py \
+        --keywords ${keywords_string} \
+        --test_data ${test_data_dir}/wav.scp \
+        --trans_data ${test_data_dir}/text \
+        --score_file ${inference_dir}/detect2 \
+        --stats_dir ${inference_dir}/task2
+  done
+
+fi
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/funasr b/examples/industrial_data_pretraining/fsmn_kws_mt/funasr
new file mode 120000
index 0000000..39a970f
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/funasr
@@ -0,0 +1 @@
+../../../funasr
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh
new file mode 100644
index 0000000..6e03b89
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/infer.sh
@@ -0,0 +1,20 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+model="iic/speech_charctc_kws_phone-xiaoyun"
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python funasr/bin/inference.py \
++model=${model} \
++input=${input} \
++output_dir="./outputs/debug" \
++device="cpu" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh
new file mode 100644
index 0000000..51d2312
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/infer_from_local.sh
@@ -0,0 +1,44 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method2, inference from local model
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+output_dir="./outputs/debug"
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models
+mkdir -p ${local_path_root}
+local_path=${local_path_root}/speech_charctc_kws_phone-xiaoyun
+git clone https://www.modelscope.cn/iic/speech_charctc_kws_phone-xiaoyun.git ${local_path}
+
+device="cuda:0" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
+
+config="inference_fsmn_4e_l10r2_280_200_fdim40_t2602_t4.yaml"
+tokens="${local_path}/funasr/tokens_2602.txt"
+tokens2="${local_path}/funasr/tokens_xiaoyun_char.txt"
+seg_dict="${local_path}/funasr/lexicon.txt"
+init_param="${local_path}/funasr/finetune_fsmn_4e_l10r2_280_200_fdim40_t2602_t4_xiaoyun_xiaoyun.pt"
+cmvn_file="${local_path}/funasr/am.mvn.dim40_l4r4"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python -m funasr.bin.inference \
+--config-path "${local_path}/funasr" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++tokenizer_conf.token_list="${tokens}" \
+++tokenizer_conf.seg_dict="${seg_dict}" \
+++tokenizer2_conf.token_list="${tokens2}" \
+++tokenizer2_conf.seg_dict="${seg_dict}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++device="${device}" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/fsmn_kws_mt/path.sh b/examples/industrial_data_pretraining/fsmn_kws_mt/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/examples/industrial_data_pretraining/fsmn_kws_mt/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/examples/industrial_data_pretraining/sanm_kws/conf/sanm_6e_320_256_fdim40_t2602.yaml b/examples/industrial_data_pretraining/sanm_kws/conf/sanm_6e_320_256_fdim40_t2602.yaml
new file mode 100644
index 0000000..c4d8c18
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/conf/sanm_6e_320_256_fdim40_t2602.yaml
@@ -0,0 +1,94 @@
+
+# network architecture
+model: SanmKWS
+model_conf:
+    ctc_weight: 1.0
+
+# encoder
+encoder: SANMEncoder
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 320  # the number of units of position-wise feed forward
+    num_blocks: 6      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: pe
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 40
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 7
+    lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 6
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 100
+  keep_nbest_models: 20
+  avg_nbest_model: 10
+  avg_keep_nbest_models_type: loss
+  validate_interval: 50000
+  save_checkpoint_interval: 50000
+  avg_checkpoint_interval: 1000
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 10000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: EspnetStyleBatchSampler
+    batch_type: length # example or length
+    batch_size: 96000 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
+    max_token_length: 1600 # filter out samples if source_token_len+target_token_len > max_token_length
+    buffer_size: 2048
+    shuffle: true
+    num_workers: 8
+    preprocessor_speech: SpeechPreprocessSpeedPerturb
+    preprocessor_speech_conf:
+      speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin  # ctc_type: focalctc, builtin
+    reduce: true
+    ignore_nan_grad: true
+normalize: null
diff --git a/examples/industrial_data_pretraining/sanm_kws/export.sh b/examples/industrial_data_pretraining/sanm_kws/export.sh
new file mode 100644
index 0000000..1106292
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/export.sh
@@ -0,0 +1,17 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# path to the directory containing the model config (e.g. ./conf, or a finetune exp dir with config.yaml)
+config_path="./conf"
+
+config_file="sanm_6e_320_256_fdim40_t2602.yaml"
+
+model_path="./modelscope_models_kws/speech_charctc_kws_phone-xiaoyun/funasr/finetune_sanm_6e_320_256_fdim40_t2602_online_xiaoyun_commands.pt"
+
+python -m funasr.bin.export \
+    --config-path="${config_path}" \
+    --config-name="${config_file}" \
+    ++init_param=${model_path} \
+    ++type="onnx" \
+    ++quantize=true
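
The export above writes an ONNX graph (plus a quantized copy when ++quantize=true). A short sketch for inspecting the result with onnxruntime follows; the output file name and location are assumptions, so point it at whatever funasr.bin.export actually produced.

    # Sketch: inspect the exported graph with onnxruntime ("model.onnx" is an assumed name).
    import onnxruntime as ort

    sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    for inp in sess.get_inputs():
        print("input:", inp.name, inp.shape, inp.type)
    for out in sess.get_outputs():
        print("output:", out.name, out.shape, out.type)
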
diff --git a/examples/industrial_data_pretraining/sanm_kws/finetune.sh b/examples/industrial_data_pretraining/sanm_kws/finetune.sh
new file mode 100755
index 0000000..cf6d488
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/finetune.sh
@@ -0,0 +1,172 @@
+#!/usr/bin/env bash
+
+# Set bash to 'debug' mode; it will exit on:
+# -e 'error', -u 'undefined variable', -o pipefail 'error in pipeline', -x 'print commands'.
+set -e
+set -u
+set -o pipefail
+
+. ./path.sh
+workspace=`pwd`
+
+CUDA_VISIBLE_DEVICES="0,1"
+
+stage=2
+stop_stage=3
+
+inference_device="cuda" #"cpu"
+inference_checkpoint="model.pt.avg10"
+inference_scp="wav.scp"
+inference_batch_size=32
+nj=32
+test_sets="test"
+
+# model_name from model_hub, or model_dir in local path
+
+## option 1, download model automatically (currently unsupported)
+model_name_or_model_dir="iic/speech_sanm_kws_phone-xiaoyun-commands-offline"
+
+## option 2, download model by git
+local_path_root=${workspace}/modelscope_models
+model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
+if [ ! -d $model_name_or_model_dir ]; then
+  mkdir -p ${model_name_or_model_dir}
+  git clone https://www.modelscope.cn/iic/speech_sanm_kws_phone-xiaoyun-commands-offline.git ${model_name_or_model_dir}
+fi
+
+config=sanm_6e_320_256_fdim40_t2602.yaml
+token_list=${model_name_or_model_dir}/tokens_2602.txt
+lexicon_list=${model_name_or_model_dir}/lexicon.txt
+cmvn_file=${model_name_or_model_dir}/am.mvn.dim40_l3r3
+init_param="${model_name_or_model_dir}/basetrain_sanm_6e_320_256_fdim40_t2602_offline.pt"
+
+
+# data prepare
+# data dir, which contains: train.jsonl, val.jsonl
+data_dir=../../data
+
+train_data="${data_dir}/train.jsonl"
+val_data="${data_dir}/val.jsonl"
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "stage 1: Generate audio json list"
+  # generate train.jsonl and val.jsonl from wav.scp and text.txt
+  python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+  ++scp_file_list='['''${data_dir}/train_wav.scp''', '''${data_dir}/train_text.txt''']' \
+  ++data_type_list='["source", "target"]' \
+  ++jsonl_file_out="${train_data}"
+
+  python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+  ++scp_file_list='['''${data_dir}/val_wav.scp''', '''${data_dir}/val_text.txt''']' \
+  ++data_type_list='["source", "target"]' \
+  ++jsonl_file_out="${val_data}"
+fi
+
+# exp output dir
+output_dir="${workspace}/exp/finetune_outputs"
+
+# Training Stage
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "stage 2: KWS Training"
+
+  mkdir -p ${output_dir}
+  current_time=$(date "+%Y-%m-%d_%H-%M")
+  log_file="${output_dir}/train.log.txt.${current_time}"
+  echo "log_file: ${log_file}"
+  echo "finetune using basetrain model: ${init_param}"
+
+  export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
+  gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+  torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+  ../../../funasr/bin/train.py \
+  --config-path "${workspace}/conf" \
+  --config-name "${config}" \
+  ++init_param="${init_param}" \
+  ++disable_update=true \
+  ++train_data_set_list="${train_data}" \
+  ++valid_data_set_list="${val_data}" \
+  ++tokenizer_conf.token_list="${token_list}" \
+  ++tokenizer_conf.seg_dict="${lexicon_list}" \
+  ++frontend_conf.cmvn_file="${cmvn_file}" \
+  ++output_dir="${output_dir}" &> ${log_file}
+fi
+
+
+# Testing Stage
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "stage 3: Inference"
+  keywords=(小云小云)
+  keywords_string=$(IFS=,; echo "${keywords[*]}")
+  echo "keywords: $keywords_string"
+
+  if [ ${inference_device} == "cuda" ]; then
+      nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+  else
+      inference_batch_size=1
+      CUDA_VISIBLE_DEVICES=""
+      for JOB in $(seq ${nj}); do
+          CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+      done
+  fi
+
+  for dset in ${test_sets}; do
+    inference_dir="${output_dir}/inference-${inference_checkpoint}/${dset}"
+    _logdir="${inference_dir}/logdir"
+    echo "inference_dir: ${inference_dir}"
+
+    mkdir -p "${_logdir}"
+    test_data_dir="${data_dir}/${dset}"
+    key_file=${test_data_dir}/${inference_scp}
+
+    split_scps=
+    for JOB in $(seq "${nj}"); do
+        split_scps+=" ${_logdir}/keys.${JOB}.scp"
+    done
+    $FUNASR_DIR/examples/aishell/paraformer/utils/split_scp.pl "${key_file}" ${split_scps}
+
+    gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+    for JOB in $(seq ${nj}); do
+        {
+          id=$((JOB-1))
+          gpuid=${gpuid_list_array[$id]}
+
+          echo "${output_dir}"
+
+          export CUDA_VISIBLE_DEVICES=${gpuid}
+          python ../../../funasr/bin/inference.py \
+          --config-path="${output_dir}" \
+          --config-name="config.yaml" \
+          ++init_param="${output_dir}/${inference_checkpoint}" \
+          ++tokenizer_conf.token_list="${token_list}" \
+          ++tokenizer_conf.seg_dict="${lexicon_list}" \
+          ++frontend_conf.cmvn_file="${cmvn_file}" \
+          ++keywords="\"$keywords_string"\" \
+          ++input="${_logdir}/keys.${JOB}.scp" \
+          ++output_dir="${inference_dir}/${JOB}" \
+          ++device="${inference_device}" \
+          ++ncpu=1 \
+          ++disable_log=true \
+          ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
+        }&
+
+    done
+    wait
+
+    for f in detect score; do
+        if [ -f "${inference_dir}/${JOB}/${f}" ]; then
+          for JOB in $(seq "${nj}"); do
+              cat "${inference_dir}/${JOB}/${f}"
+          done | sort -k1 >"${inference_dir}/${f}"
+        fi
+    done
+
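+    # compute_det_ctc.py compares the merged "detect" results with the reference
+    # transcripts and writes per-keyword detection statistics into ${inference_dir}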
+    python funasr/utils/compute_det_ctc.py \
+        --keywords ${keywords_string} \
+        --test_data ${test_data_dir}/wav.scp \
+        --trans_data ${test_data_dir}/text \
+        --score_file ${inference_dir}/detect \
+        --stats_dir ${inference_dir}
+  done
+
+fi
diff --git a/examples/industrial_data_pretraining/sanm_kws/funasr b/examples/industrial_data_pretraining/sanm_kws/funasr
new file mode 120000
index 0000000..39a970f
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/funasr
@@ -0,0 +1 @@
+../../../funasr
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/sanm_kws/infer.sh b/examples/industrial_data_pretraining/sanm_kws/infer.sh
new file mode 100644
index 0000000..74455bb
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/infer.sh
@@ -0,0 +1,20 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+model="iic/speech_sanm_kws_phone-xiaoyun-commands-offline"
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python funasr/bin/inference.py \
++model=${model} \
++input=${input} \
++output_dir="./outputs/debug" \
++device="cpu" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/sanm_kws/infer_from_local.sh b/examples/industrial_data_pretraining/sanm_kws/infer_from_local.sh
new file mode 100644
index 0000000..6be8bab
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/infer_from_local.sh
@@ -0,0 +1,41 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method2, inference from local model
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+output_dir="./outputs/debug"
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models
+mkdir -p ${local_path_root}
+local_path=${local_path_root}/speech_sanm_kws_phone-xiaoyun-commands-offline
+git clone https://www.modelscope.cn/iic/speech_sanm_kws_phone-xiaoyun-commands-offline.git ${local_path}
+
+device="cpu" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
+
+config="inference_sanm_6e_320_256_fdim40_t2602_offline.yaml"
+tokens="${local_path}/tokens_2602.txt"
+seg_dict="${local_path}/lexicon.txt"
+init_param="${local_path}/finetune_sanm_6e_320_256_fdim40_t2602_offline_xiaoyun_commands.pt"
+cmvn_file="${local_path}/am.mvn.dim40_l3r3"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python -m funasr.bin.inference \
+--config-path "${local_path}/" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++tokenizer_conf.token_list="${tokens}" \
+++tokenizer_conf.seg_dict="${seg_dict}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++device="${device}" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/sanm_kws/path.sh b/examples/industrial_data_pretraining/sanm_kws/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/conf/sanm_6e_320_256_fdim40_t2602.yaml b/examples/industrial_data_pretraining/sanm_kws_streaming/conf/sanm_6e_320_256_fdim40_t2602.yaml
new file mode 100644
index 0000000..664997c
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/conf/sanm_6e_320_256_fdim40_t2602.yaml
@@ -0,0 +1,109 @@
+
+# network architecture
+model: SanmKWSStreaming
+model_conf:
+    ctc_weight: 1.0
+
+# encoder
+encoder: SANMEncoderChunkOpt
+encoder_conf:
+    output_size: 256    # dimension of attention
+    attention_heads: 4
+    linear_units: 320  # the number of units of position-wise feed forward
+    num_blocks: 6      # the number of encoder blocks
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.1
+    input_layer: pe_online
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
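+    # streaming chunk settings: each (chunk_size, stride, pad_left) triple below
+    # corresponds to an inference option chunk_size=[pad_left, stride, chunk_size-stride-pad_left],
+    # i.e. [4, 8, 4] and [5, 10, 5] as used in finetune.sh / infer.sh of this example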
+    chunk_size:
+    - 16
+    - 20
+    stride:
+    - 8
+    - 10
+    pad_left:
+    - 4
+    - 5
+    encoder_att_look_back_factor:
+    - 0
+    - 0
+    decoder_att_look_back_factor:
+    - 0
+    - 0
+
+# frontend related
+frontend: WavFrontendOnline
+frontend_conf:
+    fs: 16000
+    window: hamming
+    n_mels: 40
+    frame_length: 25
+    frame_shift: 10
+    lfr_m: 7
+    lfr_n: 6
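+    # LFR: stack lfr_m=7 consecutive 40-dim fbank frames (280-dim input) and advance
+    # by lfr_n=6 frames, i.e. one encoder input every 60 ms (frame_shift * lfr_n)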
+
+specaug: SpecAugLFR
+specaug_conf:
+    apply_time_warp: false
+    time_warp_window: 5
+    time_warp_mode: bicubic
+    apply_freq_mask: true
+    freq_mask_width_range:
+    - 0
+    - 30
+    lfr_rate: 6
+    num_freq_mask: 1
+    apply_time_mask: true
+    time_mask_width_range:
+    - 0
+    - 12
+    num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 100
+  keep_nbest_models: 20
+  avg_nbest_model: 10
+  avg_keep_nbest_models_type: loss
+  validate_interval: 50000
+  save_checkpoint_interval: 50000
+  avg_checkpoint_interval: 1000
+  log_interval: 50
+
+optim: adam
+optim_conf:
+   lr: 0.001
+scheduler: warmuplr
+scheduler_conf:
+   warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+    index_ds: IndexDSJsonl
+    batch_sampler: EspnetStyleBatchSampler
+    batch_type: length # example or length
+    batch_size: 64000 # if batch_type is example, batch_size is the number of samples; if length, it is the total source_token_len + target_token_len per batch
+    max_token_length: 1600 # filter out samples whose source_token_len + target_token_len exceeds max_token_length
+    buffer_size: 2048
+    shuffle: true
+    num_workers: 8
+    preprocessor_speech: SpeechPreprocessSpeedPerturb
+    preprocessor_speech_conf:
+      speed_perturb: [0.9, 1.0, 1.1]
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+
+ctc_conf:
+    dropout_rate: 0.0
+    ctc_type: builtin  # ctc_type: focalctc, builtin
+    reduce: true
+    ignore_nan_grad: true
+normalize: null
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/export.sh b/examples/industrial_data_pretraining/sanm_kws_streaming/export.sh
new file mode 100644
index 0000000..eef6875
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/export.sh
@@ -0,0 +1,17 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# option 1: use the example conf dir with the training yaml
+# config_path="/home/pengteng.spt/source/FunASR_KWS/examples/industrial_data_pretraining/sanm_kws_streaming/conf"
+# config_file="sanm_6e_320_256_fdim40_t2602.yaml"
+
+# option 2: use a finetuned experiment dir with its exported config.yaml
+config_path="/home/pengteng.spt/source/FunASR_KWS/examples/industrial_data_pretraining/sanm_kws_streaming/exp/20240618_xiaoyun_finetune_sanm_6e_320_256_feats_dim40_char_t2602_online_6"
+config_file="config.yaml"
+
+model_path="./modelscope_models_kws/speech_charctc_kws_phone-xiaoyun/funasr/finetune_sanm_6e_320_256_fdim40_t2602_online_xiaoyun_commands.pt"
+
+python -m funasr.bin.export \
+    --config-path="${config_path}" \
+    --config-name="${config_file}" \
+    ++init_param=${model_path} \
+    ++type="onnx" \
+    ++quantize=true
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/finetune.sh b/examples/industrial_data_pretraining/sanm_kws_streaming/finetune.sh
new file mode 100755
index 0000000..8914786
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/finetune.sh
@@ -0,0 +1,258 @@
+#!/usr/bin/env bash
+
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# Set bash to 'debug' mode, it will exit on :
+# -e 'error', -u 'undefined variable', -o ... 'error in pipeline', -x 'print commands',
+set -e
+set -u
+set -o pipefail
+
+. ./path.sh
+workspace=`pwd`
+
+CUDA_VISIBLE_DEVICES="0,1"
+
+stage=2
+stop_stage=4
+
+inference_device="cpu" # "cuda" or "cpu"
+inference_checkpoint="model.pt.avg10"
+inference_scp="wav.scp"
+inference_batch_size=32
+nj=32
+test_sets="test"
+
+# model_name from model_hub, or model_dir in local path
+
+## option 1, download model automatically (currently unsupported)
+model_name_or_model_dir="iic/speech_sanm_kws_phone-xiaoyun-commands-online"
+
+## option 2, download model by git
+local_path_root=${workspace}/modelscope_models
+model_name_or_model_dir=${local_path_root}/${model_name_or_model_dir}
+if [ ! -d $model_name_or_model_dir ]; then
+  mkdir -p ${model_name_or_model_dir}
+  git clone https://www.modelscope.cn/iic/speech_sanm_kws_phone-xiaoyun-commands-online.git ${model_name_or_model_dir}
+fi
+
+config=sanm_6e_320_256_fdim40_t2602.yaml
+token_list=${model_name_or_model_dir}/tokens_2602.txt
+lexicon_list=${model_name_or_model_dir}/lexicon.txt
+cmvn_file=${model_name_or_model_dir}/am.mvn.dim40_l3r3
+init_param="${model_name_or_model_dir}/basetrain_sanm_6e_320_256_fdim40_t2602_online.pt"
+
+
+# data prepare
+# data dir, which contains: train.jsonl, val.jsonl
+data_dir=../../data
+
+train_data="${data_dir}/train.jsonl"
+val_data="${data_dir}/val.jsonl"
+if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+  echo "stage 1: Generate audio json list"
+  # generate train.jsonl and val.jsonl from wav.scp and text.txt
+  python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+  ++scp_file_list='['''${data_dir}/train_wav.scp''', '''${data_dir}/train_text.txt''']' \
+  ++data_type_list='["source", "target"]' \
+  ++jsonl_file_out="${train_data}"
+
+  python $FUNASR_DIR/funasr/datasets/audio_datasets/scp2jsonl.py \
+  ++scp_file_list='['''${data_dir}/val_wav.scp''', '''${data_dir}/val_text.txt''']' \
+  ++data_type_list='["source", "target"]' \
+  ++jsonl_file_out="${val_data}"
+fi
+
+# exp output dir
+output_dir="${workspace}/exp/finetune_outputs"
+
+# Training Stage
+if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+  echo "stage 2: KWS Training"
+
+  mkdir -p ${output_dir}
+  current_time=$(date "+%Y-%m-%d_%H-%M")
+  log_file="${output_dir}/train.log.txt.${current_time}"
+  echo "log_file: ${log_file}"
+  echo "finetune from basetrain model: ${init_param}"
+
+  export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
+  gpu_num=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+  torchrun --nnodes 1 --nproc_per_node ${gpu_num} \
+  ../../../funasr/bin/train.py \
+  --config-path "${workspace}/conf" \
+  --config-name "${config}" \
+  ++init_param="${init_param}" \
+  ++disable_update=true \
+  ++train_data_set_list="${train_data}" \
+  ++valid_data_set_list="${val_data}" \
+  ++tokenizer_conf.token_list="${token_list}" \
+  ++tokenizer_conf.seg_dict="${lexicon_list}" \
+  ++frontend_conf.cmvn_file="${cmvn_file}" \
+  ++output_dir="${output_dir}" &> ${log_file}
+fi
+
+
+# Testing Stage
+if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+  echo "stage 3: Inference chunk_size: [4, 8, 4]"
+  keywords=(小云小云)
+  keywords_string=$(IFS=,; echo "${keywords[*]}")
+  echo "keywords: $keywords_string"
+
+  if [ ${inference_device} == "cuda" ]; then
+      nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+  else
+      inference_batch_size=1
+      CUDA_VISIBLE_DEVICES=""
+      for JOB in $(seq ${nj}); do
+          CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+      done
+  fi
+
+  for dset in ${test_sets}; do
+    inference_dir="${output_dir}/inference-${inference_checkpoint}/${dset}/chunk-4-8-4_elb-0_dlb_0"
+    _logdir="${inference_dir}/logdir"
+    echo "inference_dir: ${inference_dir}"
+
+    mkdir -p "${_logdir}"
+    test_data_dir="${data_dir}/${dset}"
+    key_file=${test_data_dir}/${inference_scp}
+
+    split_scps=
+    for JOB in $(seq "${nj}"); do
+        split_scps+=" ${_logdir}/keys.${JOB}.scp"
+    done
+    $FUNASR_DIR/examples/aishell/paraformer/utils/split_scp.pl "${key_file}" ${split_scps}
+
+    gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+    for JOB in $(seq ${nj}); do
+        {
+          id=$((JOB-1))
+          gpuid=${gpuid_list_array[$id]}
+
+          echo "${output_dir}"
+
+          export CUDA_VISIBLE_DEVICES=${gpuid}
+          python ../../../funasr/bin/inference.py \
+          --config-path="${output_dir}" \
+          --config-name="config.yaml" \
+          ++init_param="${output_dir}/${inference_checkpoint}" \
+          ++tokenizer_conf.token_list="${token_list}" \
+          ++tokenizer_conf.seg_dict="${lexicon_list}" \
+          ++frontend_conf.cmvn_file="${cmvn_file}" \
+          ++keywords="\"$keywords_string"\" \
+          ++input="${_logdir}/keys.${JOB}.scp" \
+          ++output_dir="${inference_dir}/${JOB}" \
+          ++chunk_size='[4, 8, 4]' \
+          ++encoder_chunk_look_back=0 \
+          ++decoder_chunk_look_back=0 \
+          ++device="${inference_device}" \
+          ++ncpu=1 \
+          ++disable_log=true \
+          ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
+        }&
+
+    done
+    wait
+
+    for f in detect score; do
+        if [ -f "${inference_dir}/${JOB}/${f}" ]; then
+          for JOB in $(seq "${nj}"); do
+              cat "${inference_dir}/${JOB}/${f}"
+          done | sort -k1 >"${inference_dir}/${f}"
+        fi
+    done
+
+    python funasr/utils/compute_det_ctc.py \
+        --keywords ${keywords_string} \
+        --test_data ${test_data_dir}/wav.scp \
+        --trans_data ${test_data_dir}/text \
+        --score_file ${inference_dir}/detect \
+        --stats_dir ${inference_dir}
+  done
+
+fi
+
+
+# Testing Stage
+if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+  echo "stage 4: Inference chunk_size: [5, 10, 5]"
+  keywords=(小云小云)
+  keywords_string=$(IFS=,; echo "${keywords[*]}")
+  echo "keywords: $keywords_string"
+
+  if [ ${inference_device} == "cuda" ]; then
+      nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+  else
+      inference_batch_size=1
+      CUDA_VISIBLE_DEVICES=""
+      for JOB in $(seq ${nj}); do
+          CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+      done
+  fi
+
+  for dset in ${test_sets}; do
+    inference_dir="${output_dir}/inference-${inference_checkpoint}/${dset}/chunk-5-10-5_elb-0_dlb_0"
+    _logdir="${inference_dir}/logdir"
+    echo "inference_dir: ${inference_dir}"
+
+    mkdir -p "${_logdir}"
+    test_data_dir="${data_dir}/${dset}"
+    key_file=${test_data_dir}/${inference_scp}
+
+    split_scps=
+    for JOB in $(seq "${nj}"); do
+        split_scps+=" ${_logdir}/keys.${JOB}.scp"
+    done
+    $FUNASR_DIR/examples/aishell/paraformer/utils/split_scp.pl "${key_file}" ${split_scps}
+
+    gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+    for JOB in $(seq ${nj}); do
+        {
+          id=$((JOB-1))
+          gpuid=${gpuid_list_array[$id]}
+
+          echo "${output_dir}"
+
+          export CUDA_VISIBLE_DEVICES=${gpuid}
+          python ../../../funasr/bin/inference.py \
+          --config-path="${output_dir}" \
+          --config-name="config.yaml" \
+          ++init_param="${output_dir}/${inference_checkpoint}" \
+          ++tokenizer_conf.token_list="${token_list}" \
+          ++tokenizer_conf.seg_dict="${lexicon_list}" \
+          ++frontend_conf.cmvn_file="${cmvn_file}" \
+          ++keywords="\"$keywords_string"\" \
+          ++input="${_logdir}/keys.${JOB}.scp" \
+          ++output_dir="${inference_dir}/${JOB}" \
+          ++chunk_size='[5, 10, 5]' \
+          ++encoder_chunk_look_back=0 \
+          ++decoder_chunk_look_back=0 \
+          ++device="${inference_device}" \
+          ++ncpu=1 \
+          ++disable_log=true \
+          ++batch_size="${inference_batch_size}" &> ${_logdir}/log.${JOB}.txt
+        }&
+
+    done
+    wait
+
+    for f in detect; do
+        if [ -f "${inference_dir}/${JOB}/${f}" ]; then
+          for JOB in $(seq "${nj}"); do
+              cat "${inference_dir}/${JOB}/${f}"
+          done | sort -k1 >"${inference_dir}/${f}"
+        fi
+    done
+
+    python funasr/utils/compute_det_ctc.py \
+        --keywords ${keywords_string} \
+        --test_data ${test_data_dir}/wav.scp \
+        --trans_data ${test_data_dir}/text \
+        --score_file ${inference_dir}/detect \
+        --stats_dir ${inference_dir}
+  done
+
+fi
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/funasr b/examples/industrial_data_pretraining/sanm_kws_streaming/funasr
new file mode 120000
index 0000000..39a970f
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/funasr
@@ -0,0 +1 @@
+../../../funasr
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/infer.sh b/examples/industrial_data_pretraining/sanm_kws_streaming/infer.sh
new file mode 100644
index 0000000..ce34a9b
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/infer.sh
@@ -0,0 +1,34 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method1, inference from model hub
+
+model="iic/speech_sanm_kws_phone-xiaoyun-commands-online"
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+python funasr/bin/inference.py \
++model=${model} \
++input=${input} \
++output_dir="./outputs/debug" \
+++chunk_size='[4, 8, 4]' \
+++encoder_chunk_look_back=0 \
+++decoder_chunk_look_back=0 \
++device="cpu" \
+++keywords="\"$keywords_string"\"
+
+
+python funasr/bin/inference.py \
++model=${model} \
++input=${input} \
++output_dir="./outputs/debug" \
+++chunk_size='[5, 10, 5]' \
+++encoder_chunk_look_back=0 \
+++decoder_chunk_look_back=0 \
++device="cpu" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/infer_from_local.sh b/examples/industrial_data_pretraining/sanm_kws_streaming/infer_from_local.sh
new file mode 100644
index 0000000..a1f0639
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/infer_from_local.sh
@@ -0,0 +1,62 @@
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+# method2, inference from local model
+
+# for more input type, please ref to readme.md
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/pos_testset/kws_xiaoyunxiaoyun.wav"
+
+output_dir="./outputs/debug"
+
+workspace=`pwd`
+
+# download model
+local_path_root=${workspace}/modelscope_models
+mkdir -p ${local_path_root}
+local_path=${local_path_root}/speech_sanm_kws_phone-xiaoyun-commands-online
+git clone https://www.modelscope.cn/iic/speech_sanm_kws_phone-xiaoyun-commands-online.git ${local_path}
+
+device="cpu" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
+
+config="inference_sanm_6e_320_256_fdim40_t2602_online.yaml"
+tokens="${local_path}/tokens_2602.txt"
+seg_dict="${local_path}/lexicon.txt"
+init_param="${local_path}/finetune_sanm_6e_320_256_fdim40_t2602_online_xiaoyun_commands.pt"
+cmvn_file="${local_path}/am.mvn.dim40_l3r3"
+
+keywords=(小云小云)
+keywords_string=$(IFS=,; echo "${keywords[*]}")
+echo "keywords: $keywords_string"
+
+echo "inference sanm streaming with chunk_size=[4, 8, 4]"
+python -m funasr.bin.inference \
+--config-path "${local_path}/" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++tokenizer_conf.token_list="${tokens}" \
+++tokenizer_conf.seg_dict="${seg_dict}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++chunk_size='[4, 8, 4]' \
+++encoder_chunk_look_back=0 \
+++decoder_chunk_look_back=0 \
+++device="${device}" \
+++keywords="\"$keywords_string"\"
+
+
+echo "inference sanm streaming with chunk_size=[5, 10, 5]"
+python -m funasr.bin.inference \
+--config-path "${local_path}/" \
+--config-name "${config}" \
+++init_param="${init_param}" \
+++frontend_conf.cmvn_file="${cmvn_file}" \
+++tokenizer_conf.token_list="${tokens}" \
+++tokenizer_conf.seg_dict="${seg_dict}" \
+++input="${input}" \
+++output_dir="${output_dir}" \
+++chunk_size='[5, 10, 5]' \
+++encoder_chunk_look_back=0 \
+++decoder_chunk_look_back=0 \
+++device="${device}" \
+++keywords="\"$keywords_string"\"
diff --git a/examples/industrial_data_pretraining/sanm_kws_streaming/path.sh b/examples/industrial_data_pretraining/sanm_kws_streaming/path.sh
new file mode 100755
index 0000000..7972642
--- /dev/null
+++ b/examples/industrial_data_pretraining/sanm_kws_streaming/path.sh
@@ -0,0 +1,5 @@
+export FUNASR_DIR=$PWD/../../..
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index ca1f202..9f5f4fb 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -14,6 +14,7 @@
 import numpy as np
 from tqdm import tqdm
 
+from omegaconf import DictConfig, ListConfig
 from funasr.utils.misc import deep_update
 from funasr.register import tables
 from funasr.utils.load_utils import load_bytes
@@ -187,21 +188,59 @@
 
         # build tokenizer
         tokenizer = kwargs.get("tokenizer", None)
-        if tokenizer is not None:
-            tokenizer_class = tables.tokenizer_classes.get(tokenizer)
-            tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
-            kwargs["token_list"] = (
-                tokenizer.token_list if hasattr(tokenizer, "token_list") else None
-            )
-            kwargs["token_list"] = (
-                tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else kwargs["token_list"]
-            )
-            vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
-            if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
-                vocab_size = tokenizer.get_vocab_size()
-        else:
-            vocab_size = -1
         kwargs["tokenizer"] = tokenizer
+        kwargs["vocab_size"] = -1
+
+        if tokenizer is not None:
+            tokenizers = (
+                tokenizer.split(",") if isinstance(tokenizer, str) else tokenizer
+            )  # tokenizers is always a list here
+            tokenizers_conf = kwargs.get("tokenizer_conf", {})
+            tokenizers_build = []
+            vocab_sizes = []
+            token_lists = []
+            ### === only for kws ===
+            token_list_files = kwargs.get("token_lists", [])
+            seg_dicts = kwargs.get("seg_dicts", [])
+            ### === only for kws ===
+
+            if not isinstance(tokenizers_conf, (list, tuple, ListConfig)):
+                tokenizers_conf = [tokenizers_conf] * len(tokenizers)
+
+            for i, tokenizer in enumerate(tokenizers):
+                tokenizer_class = tables.tokenizer_classes.get(tokenizer)
+                tokenizer_conf = tokenizers_conf[i]
+
+                ### === only for kws ===
+                if len(token_list_files) > 1:
+                    tokenizer_conf.token_list = token_list_files[i]
+                if len(seg_dicts) > 1:
+                    tokenizer_conf.seg_dict = seg_dicts[i]
+                ### === only for kws ===
+
+                tokenizer = tokenizer_class(**tokenizer_conf)
+                tokenizers_build.append(tokenizer)
+                token_list = tokenizer.token_list if hasattr(tokenizer, "token_list") else None
+                token_list = (
+                    tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else token_list
+                )
+                vocab_size = -1
+                if token_list is not None:
+                    vocab_size = len(token_list)
+
+                if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
+                    vocab_size = tokenizer.get_vocab_size()
+                token_lists.append(token_list)
+                vocab_sizes.append(vocab_size)
+
+            if len(tokenizers_build) <= 1:
+                tokenizers_build = tokenizers_build[0]
+                token_lists = token_lists[0]
+                vocab_sizes = vocab_sizes[0]
+
+            kwargs["tokenizer"] = tokenizers_build
+            kwargs["vocab_size"] = vocab_sizes
+            kwargs["token_list"] = token_lists
 
         # build frontend
         frontend = kwargs.get("frontend", None)
@@ -219,7 +258,7 @@
         model_conf = {}
         deep_update(model_conf, kwargs.get("model_conf", {}))
         deep_update(model_conf, kwargs)
-        model = model_class(**model_conf, vocab_size=vocab_size)
+        model = model_class(**model_conf)
 
         # init_param
         init_param = kwargs.get("init_param", None)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 2729b80..fcd763f 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -20,6 +20,7 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.algorithms.join import Join
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+from tensorboardX import SummaryWriter
 from funasr.train_utils.average_nbest_models import average_checkpoints
 
 from funasr.register import tables
@@ -191,8 +192,6 @@
     tensorboard_dir = os.path.join(kwargs.get("output_dir"), "tensorboard")
     os.makedirs(tensorboard_dir, exist_ok=True)
     try:
-        from tensorboardX import SummaryWriter
-
         writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
     except:
         writer = None
diff --git a/funasr/datasets/audio_datasets/datasets.py b/funasr/datasets/audio_datasets/datasets.py
index 2aafde3..68b2d3c 100644
--- a/funasr/datasets/audio_datasets/datasets.py
+++ b/funasr/datasets/audio_datasets/datasets.py
@@ -1,6 +1,7 @@
 import torch
 import random
 
+
 from funasr.register import tables
 from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
 
@@ -17,6 +18,7 @@
         index_ds: str = None,
         frontend=None,
         tokenizer=None,
+        is_training: bool = True,
         int_pad_value: int = -1,
         float_pad_value: float = 0.0,
         **kwargs,
@@ -24,18 +26,23 @@
         super().__init__()
         index_ds_class = tables.index_ds_classes.get(index_ds)
         self.index_ds = index_ds_class(path, **kwargs)
-        preprocessor_speech = kwargs.get("preprocessor_speech", None)
-        if preprocessor_speech:
-            preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
-            preprocessor_speech = preprocessor_speech_class(
-                **kwargs.get("preprocessor_speech_conf")
-            )
-        self.preprocessor_speech = preprocessor_speech
-        preprocessor_text = kwargs.get("preprocessor_text", None)
-        if preprocessor_text:
-            preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
-            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
-        self.preprocessor_text = preprocessor_text
+
+        self.preprocessor_speech = None
+        self.preprocessor_text = None
+
+        if is_training:
+            preprocessor_speech = kwargs.get("preprocessor_speech", None)
+            if preprocessor_speech:
+                preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+                preprocessor_speech = preprocessor_speech_class(
+                    **kwargs.get("preprocessor_speech_conf")
+                )
+            self.preprocessor_speech = preprocessor_speech
+            preprocessor_text = kwargs.get("preprocessor_text", None)
+            if preprocessor_text:
+                preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+                preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+            self.preprocessor_text = preprocessor_text
 
         self.frontend = frontend
         self.fs = 16000 if frontend is None else frontend.fs
@@ -64,6 +71,7 @@
         data_src = load_audio_text_image_video(source, fs=self.fs)
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src, fs=self.fs)
+
         speech, speech_lengths = extract_fbank(
             data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
         )  # speech: [b, T, d]
@@ -71,6 +79,7 @@
         target = item["target"]
         if self.preprocessor_text:
             target = self.preprocessor_text(target)
+
         if self.tokenizer:
             ids = self.tokenizer.encode(target)
             text = torch.tensor(ids, dtype=torch.int64)
diff --git a/funasr/datasets/audio_datasets/scp2jsonl.py b/funasr/datasets/audio_datasets/scp2jsonl.py
index f4c9d74..48c64d2 100644
--- a/funasr/datasets/audio_datasets/scp2jsonl.py
+++ b/funasr/datasets/audio_datasets/scp2jsonl.py
@@ -58,7 +58,8 @@
             for key in json_dict[data_type_list[0]].keys():
                 jsonl_line = {"key": key}
                 for data_file in data_type_list:
-                    jsonl_line.update(json_dict[data_file][key])
+                    if key in json_dict[data_file]:
+                        jsonl_line.update(json_dict[data_file][key])
                 jsonl_line = json.dumps(jsonl_line, ensure_ascii=False)
                 f.write(jsonl_line + "\n")
                 f.flush()
@@ -81,10 +82,14 @@
         key = lines[0]
         line = lines[1] if len(lines) > 1 else ""
         line = line.strip()
-        if os.path.exists(line):
-            waveform, _ = librosa.load(line, sr=16000)
-            sample_num = len(waveform)
-            context_len = int(sample_num / 16000 * 1000 / 10)
+        if data_type == "source":
+            if os.path.exists(line):
+                waveform, _ = librosa.load(line, sr=16000)
+                sample_num = len(waveform)
+                context_len = int(sample_num * 1000 / 16000 / 10)
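+                # source_len: audio duration expressed in 10 ms frames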
+            else:
+                print("source file not found: {}".format(line))
+                continue
         else:
             context_len = len(line.split()) if " " in line else len(line)
         res[key] = {data_type: line, f"{data_type}_len": context_len}
diff --git a/funasr/datasets/kws_datasets/__init__.py b/funasr/datasets/kws_datasets/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/datasets/kws_datasets/__init__.py
diff --git a/funasr/datasets/kws_datasets/datasets.py b/funasr/datasets/kws_datasets/datasets.py
new file mode 100644
index 0000000..4679295
--- /dev/null
+++ b/funasr/datasets/kws_datasets/datasets.py
@@ -0,0 +1,132 @@
+import torch
+import random
+
+
+from funasr.register import tables
+from funasr.utils.load_utils import extract_fbank, load_audio_text_image_video
+
+
+@tables.register("dataset_classes", "KwsMTDataset")
+class KwsMTDataset(torch.utils.data.Dataset):
+    """
+    KwsMTDataset: dataset that supports multiple tokenizers (parallel label sets)
+    """
+    def __init__(self,
+                 path,
+                 index_ds: str = None,
+                 frontend=None,
+                 tokenizer=None,
+                 is_training: bool = True,
+                 int_pad_value: int = -1,
+                 float_pad_value: float = 0.0,
+                 **kwargs,
+    ):
+        super().__init__()
+        index_ds_class = tables.index_ds_classes.get(index_ds)
+        self.index_ds = index_ds_class(path, **kwargs)
+
+        self.preprocessor_speech = None
+        self.preprocessor_text = None
+
+        if is_training:
+            preprocessor_speech = kwargs.get("preprocessor_speech", None)
+            if preprocessor_speech:
+                preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
+                preprocessor_speech = preprocessor_speech_class(
+                    **kwargs.get("preprocessor_speech_conf")
+                )
+            self.preprocessor_speech = preprocessor_speech
+
+            preprocessor_text = kwargs.get("preprocessor_text", None)
+            if preprocessor_text:
+                preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
+                preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+            self.preprocessor_text = preprocessor_text
+
+        self.frontend = frontend
+        self.fs = 16000 if frontend is None else frontend.fs
+        self.data_type = "sound"
+        self.tokenizer = tokenizer
+
+        self.int_pad_value = int_pad_value
+        self.float_pad_value = float_pad_value
+
+    def get_source_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_source_len(item)
+
+    def get_target_len(self, index):
+        item = self.index_ds[index]
+        return self.index_ds.get_target_len(item)
+
+    def __len__(self):
+        return len(self.index_ds)
+
+    def __getitem__(self, index):
+        item = self.index_ds[index]
+        source = item["source"]
+        data_src = load_audio_text_image_video(source, fs=self.fs)
+        if self.preprocessor_speech:
+            data_src = self.preprocessor_speech(data_src, fs=self.fs)
+        speech, speech_lengths = extract_fbank(
+            data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
+        ) # speech: [b, T, d]
+
+        target = item["target"]
+        if self.preprocessor_text:
+            target = self.preprocessor_text(target)
+
+        if self.tokenizer[0]:
+            ids = self.tokenizer[0].encode(target)
+            text = torch.tensor(ids, dtype=torch.int64)
+        else:
+            ids = target
+            text = ids
+
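+        # Encode the same target a second time with tokenizer[1]; the two tokenizers
+        # typically use different token lists (e.g. a full ASR vocabulary and a smaller
+        # KWS vocabulary), producing the parallel text2 / text2_lengths labels.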
+        if self.tokenizer[1]:
+            ids2 = self.tokenizer[1].encode(target)
+            text2 = torch.tensor(ids2, dtype=torch.int64)
+        else:
+            ids2 = target
+            text2 = ids2
+
+        ids_lengths = len(ids)
+        text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
+
+        ids2_lengths = len(ids2)
+        text2_lengths = torch.tensor([ids2_lengths], dtype=torch.int32)
+
+        return {"speech": speech[0, :, :],
+                "speech_lengths": speech_lengths,
+                "text": text,
+                "text_lengths": text_lengths,
+                "text2": text2,
+                "text2_lengths": text2_lengths,
+                }
+
+
+    def collator(self, samples: list=None):
+        outputs = {}
+        for sample in samples:
+            for key in sample.keys():
+                if key not in outputs:
+                    outputs[key] = []
+                outputs[key].append(sample[key])
+
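+        # Pad each field to the batch maximum: integer tensors (token ids) with
+        # int_pad_value (-1), float tensors (fbank features) with float_pad_value (0.0).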
+        for key, data_list in outputs.items():
+            if isinstance(data_list[0], torch.Tensor):
+                if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
+                    pad_value = self.int_pad_value
+                else:
+                    pad_value = self.float_pad_value
+
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(
+                    data_list, batch_first=True, padding_value=pad_value
+                )
+        return outputs
diff --git a/funasr/download/download_model_from_hub.py b/funasr/download/download_model_from_hub.py
index df4f33d..8e51144 100644
--- a/funasr/download/download_model_from_hub.py
+++ b/funasr/download/download_model_from_hub.py
@@ -51,7 +51,8 @@
             cfg = {}
             if "file_path_metas" in conf_json:
                 add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
-            cfg.update(kwargs)
+            cfg = OmegaConf.merge(cfg, kwargs)
             if "config" in cfg:
                 config = OmegaConf.load(cfg["config"])
                 kwargs = OmegaConf.merge(config, cfg)
@@ -159,15 +160,41 @@
 def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg={}):
 
     if isinstance(file_path_metas, dict):
+        if isinstance(cfg, list):
+            cfg.append({})
         for k, v in file_path_metas.items():
             if isinstance(v, str):
                 p = os.path.join(model_or_path, v)
                 if os.path.exists(p):
-                    cfg[k] = p
+                    if isinstance(cfg, dict):
+                        cfg[k] = p
+                    elif isinstance(cfg, list):
+                        cfg[-1][k] = p
+
             elif isinstance(v, dict):
-                if k not in cfg:
-                    cfg[k] = {}
-                add_file_root_path(model_or_path, v, cfg[k])
+                if isinstance(cfg, dict):
+                    if k not in cfg:
+                        cfg[k] = {}
+                    add_file_root_path(model_or_path, v, cfg[k])
+            elif isinstance(v, (list, tuple)):
+                for i, vv in enumerate(v):
+                    if k not in cfg:
+                        cfg[k] = []
+                    if isinstance(vv, str):
+                        p = os.path.join(model_or_path, vv)
+                        v[i] = p
+                        if os.path.exists(p):
+                            if isinstance(cfg[k], dict):
+                                cfg[k] = p
+                            elif isinstance(cfg[k], list):
+                                cfg[k].append(p)
+                    elif isinstance(vv, dict):
+                        add_file_root_path(model_or_path, vv, cfg[k])
 
     return cfg
 
diff --git a/funasr/models/ctc/ctc.py b/funasr/models/ctc/ctc.py
index bdfb3a6..8eb64d1 100644
--- a/funasr/models/ctc/ctc.py
+++ b/funasr/models/ctc/ctc.py
@@ -23,11 +23,17 @@
         ctc_type: str = "builtin",
         reduce: bool = True,
         ignore_nan_grad: bool = True,
+        extra_linear: bool = True,
     ):
         super().__init__()
         eprojs = encoder_output_size
         self.dropout_rate = dropout_rate
-        self.ctc_lo = torch.nn.Linear(eprojs, odim)
+
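+        # extra_linear=False is for encoders that already emit vocab-sized logits
+        # (e.g. the KWS encoders added in this patch); ctc_lo is then skipped entirely.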
+        if extra_linear:
+            self.ctc_lo = torch.nn.Linear(eprojs, odim)
+        else:
+            self.ctc_lo = None
+
         self.ctc_type = ctc_type
         self.ignore_nan_grad = ignore_nan_grad
 
@@ -130,7 +136,10 @@
             ys_lens: batch of lengths of character sequence (B)
         """
         # hs_pad: (B, L, NProj) -> ys_hat: (B, L, Nvocab)
-        ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+        if self.ctc_lo is not None:
+            ys_hat = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
+        else:
+            ys_hat = hs_pad
 
         if self.ctc_type == "gtnctc":
             # gtn expects list form for ys
@@ -141,6 +150,7 @@
             # (B, L) -> (BxL,)
             ys_true = torch.cat([ys_pad[i, :l] for i, l in enumerate(ys_lens)])
 
+        hlens = hlens.to(hs_pad.device)
         loss = self.loss_fn(ys_hat, ys_true, hlens, ys_lens).to(
             device=hs_pad.device, dtype=hs_pad.dtype
         )
@@ -155,7 +165,10 @@
         Returns:
             torch.Tensor: softmax applied 3d tensor (B, Tmax, odim)
         """
-        return F.softmax(self.ctc_lo(hs_pad), dim=2)
+        if self.ctc_lo is not None:
+            return F.softmax(self.ctc_lo(hs_pad), dim=2)
+        else:
+            return F.softmax(hs_pad, dim=2)
 
     def log_softmax(self, hs_pad):
         """log_softmax of frame activations
@@ -165,7 +178,10 @@
         Returns:
             torch.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
         """
-        return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
+        if self.ctc_lo is not None:
+            return F.log_softmax(self.ctc_lo(hs_pad), dim=2)
+        else:
+            return F.log_softmax(hs_pad, dim=2)
 
     def argmax(self, hs_pad):
         """argmax of frame activations
@@ -175,4 +191,7 @@
         Returns:
             torch.Tensor: argmax applied 2d tensor (B, Tmax)
         """
-        return torch.argmax(self.ctc_lo(hs_pad), dim=2)
+        if self.ctc_lo is not None:
+            return torch.argmax(self.ctc_lo(hs_pad), dim=2)
+        else:
+            return torch.argmax(hs_pad, dim=2)
diff --git a/funasr/models/fsmn_kws/__init__.py b/funasr/models/fsmn_kws/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/fsmn_kws/__init__.py
diff --git a/funasr/models/fsmn_kws/encoder.py b/funasr/models/fsmn_kws/encoder.py
new file mode 100755
index 0000000..2c31687
--- /dev/null
+++ b/funasr/models/fsmn_kws/encoder.py
@@ -0,0 +1,534 @@
+from typing import Tuple, Dict
+import copy
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from funasr.register import tables
+
+
+def toKaldiMatrix(np_mat):
+    np.set_printoptions(threshold=np.inf, linewidth=np.nan)
+    out_str = str(np_mat)
+    out_str = out_str.replace('[', '')
+    out_str = out_str.replace(']', '')
+    return '[ %s ]\n' % out_str
+
+
+class LinearTransform(nn.Module):
+
+    def __init__(self, input_dim, output_dim):
+        super(LinearTransform, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.linear = nn.Linear(input_dim, output_dim, bias=False)
+
+    def forward(self, input):
+        output = self.linear(input)
+
+        return output
+
+    def to_kaldi_net(self):
+        re_str = ''
+        re_str += '<LinearTransform> %d %d\n' % (self.output_dim,
+                                                 self.input_dim)
+        re_str += '<LearnRateCoef> 1\n'
+
+        linear_weights = self.state_dict()['linear.weight']
+        x = linear_weights.squeeze().numpy()
+        re_str += toKaldiMatrix(x)
+
+        return re_str
+
+    def to_pytorch_net(self, fread):
+        linear_line = fread.readline()
+        linear_split = linear_line.strip().split()
+        assert len(linear_split) == 3
+        assert linear_split[0] == '<LinearTransform>'
+        self.output_dim = int(linear_split[1])
+        self.input_dim = int(linear_split[2])
+
+        learn_rate_line = fread.readline()
+        assert learn_rate_line.find('LearnRateCoef') != -1
+
+        self.linear.reset_parameters()
+
+        linear_weights = self.state_dict()['linear.weight']
+        #print(linear_weights.shape)
+        new_weights = torch.zeros((self.output_dim, self.input_dim),
+                                  dtype=torch.float32)
+        for i in range(self.output_dim):
+            line = fread.readline()
+            splits = line.strip().strip('[]').strip().split()
+            assert len(splits) == self.input_dim
+            cols = torch.tensor([float(item) for item in splits],
+                                dtype=torch.float32)
+            new_weights[i, :] = cols
+
+        self.linear.weight.data = new_weights
+
+
+class AffineTransform(nn.Module):
+
+    def __init__(self, input_dim, output_dim):
+        super(AffineTransform, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.linear = nn.Linear(input_dim, output_dim)
+
+    def forward(self, input):
+        output = self.linear(input)
+
+        return output
+
+    def to_kaldi_net(self):
+        re_str = ''
+        re_str += '<AffineTransform> %d %d\n' % (self.output_dim,
+                                                 self.input_dim)
+        re_str += '<LearnRateCoef> 1 <BiasLearnRateCoef> 1 <MaxNorm> 0\n'
+
+        linear_weights = self.state_dict()['linear.weight']
+        x = linear_weights.squeeze().numpy()
+        re_str += toKaldiMatrix(x)
+
+        linear_bias = self.state_dict()['linear.bias']
+        x = linear_bias.squeeze().numpy()
+        re_str += toKaldiMatrix(x)
+
+        return re_str
+
+    def to_pytorch_net(self, fread):
+        affine_line = fread.readline()
+        affine_split = affine_line.strip().split()
+        assert len(affine_split) == 3
+        assert affine_split[0] == '<AffineTransform>'
+        self.output_dim = int(affine_split[1])
+        self.input_dim = int(affine_split[2])
+        print('AffineTransform output/input dim: %d %d' %
+              (self.output_dim, self.input_dim))
+
+        learn_rate_line = fread.readline()
+        assert learn_rate_line.find('LearnRateCoef') != -1
+
+        #linear_weights = self.state_dict()['linear.weight']
+        #print(linear_weights.shape)
+        self.linear.reset_parameters()
+
+        new_weights = torch.zeros((self.output_dim, self.input_dim),
+                                  dtype=torch.float32)
+        for i in range(self.output_dim):
+            line = fread.readline()
+            splits = line.strip().strip('[]').strip().split()
+            assert len(splits) == self.input_dim
+            cols = torch.tensor([float(item) for item in splits],
+                                dtype=torch.float32)
+            new_weights[i, :] = cols
+
+        self.linear.weight.data = new_weights
+
+        linear_bias = self.state_dict()['linear.bias']
+        #print(linear_bias.shape)
+        bias_line = fread.readline()
+        splits = bias_line.strip().strip('[]').strip().split()
+        assert len(splits) == self.output_dim
+        new_bias = torch.tensor([float(item) for item in splits],
+                                dtype=torch.float32)
+
+        self.linear.bias.data = new_bias
+
+
+class RectifiedLinear(nn.Module):
+
+    def __init__(self, input_dim, output_dim):
+        super(RectifiedLinear, self).__init__()
+        self.dim = input_dim
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, input):
+        out = self.relu(input)
+        return out
+
+    def to_kaldi_net(self):
+        re_str = ''
+        re_str += '<RectifiedLinear> %d %d\n' % (self.dim, self.dim)
+        return re_str
+
+    def to_pytorch_net(self, fread):
+        line = fread.readline()
+        splits = line.strip().split()
+        assert len(splits) == 3
+        assert splits[0] == '<RectifiedLinear>'
+        assert int(splits[1]) == int(splits[2])
+        assert int(splits[1]) == self.dim
+        self.dim = int(splits[1])
+
+
+class FSMNBlock(nn.Module):
+
+    def __init__(
+        self,
+        input_dim: int,
+        output_dim: int,
+        lorder=None,
+        rorder=None,
+        lstride=1,
+        rstride=1,
+    ):
+        super(FSMNBlock, self).__init__()
+
+        self.dim = input_dim
+
+        if lorder is None:
+            return
+
+        self.lorder = lorder
+        self.rorder = rorder
+        self.lstride = lstride
+        self.rstride = rstride
+
+        self.conv_left = nn.Conv2d(
+            self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False
+        )
+
+        if self.rorder > 0:
+            self.conv_right = nn.Conv2d(
+                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False
+            )
+        else:
+            self.conv_right = None
+
+    def forward(self, input: torch.Tensor, cache: torch.Tensor = None):
+        x = torch.unsqueeze(input, 1)
+        x_per = x.permute(0, 3, 2, 1)  # B D T C
+
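+        # streaming: `cache` carries the last (lorder - 1) * lstride frames of the
+        # previous chunk as left context; offline (cache is None): zero left-padding is used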
+        if cache is not None:
+            cache = cache.to(x_per.device)
+            y_left = torch.cat((cache, x_per), dim=2)
+            cache = y_left[:, :, -(self.lorder - 1) * self.lstride :, :]
+        else:
+            y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
+
+        y_left = self.conv_left(y_left)
+        out = x_per + y_left
+
+        if self.conv_right is not None:
+            # maybe need to check
+            y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
+            y_right = y_right[:, :, self.rstride :, :]
+            y_right = self.conv_right(y_right)
+            out += y_right
+
+        out_per = out.permute(0, 3, 2, 1)
+        output = out_per.squeeze(1)
+
+        return output, cache
+
+    def to_kaldi_net(self):
+        re_str = ''
+        re_str += '<Fsmn> %d %d\n' % (self.dim, self.dim)
+        re_str += '<LearnRateCoef> %d <LOrder> %d <ROrder> %d <LStride> %d <RStride> %d <MaxNorm> 0\n' % (
+            1, self.lorder, self.rorder, self.lstride, self.rstride)
+
+        #print(self.conv_left.weight,self.conv_right.weight)
+        lfiters = self.state_dict()['conv_left.weight']
+        x = np.flipud(lfiters.squeeze().numpy().T)
+        re_str += toKaldiMatrix(x)
+
+        if self.conv_right is not None:
+            rfiters = self.state_dict()['conv_right.weight']
+            x = (rfiters.squeeze().numpy().T)
+            re_str += toKaldiMatrix(x)
+
+        return re_str
+
+    def to_pytorch_net(self, fread):
+        fsmn_line = fread.readline()
+        fsmn_split = fsmn_line.strip().split()
+        assert len(fsmn_split) == 3
+        assert fsmn_split[0] == '<Fsmn>'
+        self.dim = int(fsmn_split[1])
+
+        params_line = fread.readline()
+        params_split = params_line.strip().strip('[]').strip().split()
+        assert len(params_split) == 12
+        assert params_split[0] == '<LearnRateCoef>'
+        assert params_split[2] == '<LOrder>'
+        self.lorder = int(params_split[3])
+        assert params_split[4] == '<ROrder>'
+        self.rorder = int(params_split[5])
+        assert params_split[6] == '<LStride>'
+        self.lstride = int(params_split[7])
+        assert params_split[8] == '<RStride>'
+        self.rstride = int(params_split[9])
+        assert params_split[10] == '<MaxNorm>'
+
+        #lfilters = self.state_dict()['conv_left.weight']
+        #print(lfilters.shape)
+        print('read conv_left weight')
+        new_lfilters = torch.zeros((self.lorder, 1, self.dim, 1),
+                                   dtype=torch.float32)
+        for i in range(self.lorder):
+            print('read conv_left weight -- %d' % i)
+            line = fread.readline()
+            splits = line.strip().strip('[]').strip().split()
+            assert len(splits) == self.dim
+            cols = torch.tensor([float(item) for item in splits],
+                                dtype=torch.float32)
+            new_lfilters[self.lorder - 1 - i, 0, :, 0] = cols
+
+        new_lfilters = torch.transpose(new_lfilters, 0, 2)
+        #print(new_lfilters.shape)
+
+        self.conv_left.reset_parameters()
+        self.conv_left.weight.data = new_lfilters
+        #print(self.conv_left.weight.shape)
+
+        if self.rorder > 0:
+            #rfilters = self.state_dict()['conv_right.weight']
+            #print(rfilters.shape)
+            print('read conv_right weight')
+            new_rfilters = torch.zeros((self.rorder, 1, self.dim, 1),
+                                       dtype=torch.float32)
+            line = fread.readline()
+            for i in range(self.rorder):
+                print('read conv_right weight -- %d' % i)
+                line = fread.readline()
+                splits = line.strip().strip('[]').strip().split()
+                assert len(splits) == self.dim
+                cols = torch.tensor([float(item) for item in splits],
+                                    dtype=torch.float32)
+                new_rfilters[i, 0, :, 0] = cols
+
+            new_rfilters = torch.transpose(new_rfilters, 0, 2)
+            #print(new_rfilters.shape)
+            self.conv_right.reset_parameters()
+            self.conv_right.weight.data = new_rfilters
+            #print(self.conv_right.weight.shape)
+
+class BasicBlock(nn.Module):
+    def __init__(
+        self,
+        linear_dim: int,
+        proj_dim: int,
+        lorder: int,
+        rorder: int,
+        lstride: int,
+        rstride: int,
+        stack_layer: int,
+    ):
+        super(BasicBlock, self).__init__()
+        self.lorder = lorder
+        self.rorder = rorder
+        self.lstride = lstride
+        self.rstride = rstride
+        self.stack_layer = stack_layer
+        self.linear = LinearTransform(linear_dim, proj_dim)
+        self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
+        self.affine = AffineTransform(proj_dim, linear_dim)
+        self.relu = RectifiedLinear(linear_dim, linear_dim)
+
+    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor] = None):
+        x1 = self.linear(input)  # B T D
+
+        if cache is not None:
+            cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+            if cache_layer_name not in cache:
+                cache[cache_layer_name] = torch.zeros(
+                    x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1
+                )
+            x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
+        else:
+            x2, _ = self.fsmn_block(x1, None)
+        x3 = self.affine(x2)
+        x4 = self.relu(x3)
+        return x4
+
+    def to_kaldi_net(self):
+        re_str = ''
+        re_str += self.linear.to_kaldi_net()
+        re_str += self.fsmn_block.to_kaldi_net()
+        re_str += self.affine.to_kaldi_net()
+        re_str += self.relu.to_kaldi_net()
+
+        return re_str
+
+    def to_pytorch_net(self, fread):
+        self.linear.to_pytorch_net(fread)
+        self.fsmn_block.to_pytorch_net(fread)
+        self.affine.to_pytorch_net(fread)
+        self.relu.to_pytorch_net(fread)
+
+
+class BasicBlock_export(nn.Module):
+    def __init__(
+        self,
+        model,
+    ):
+        super(BasicBlock_export, self).__init__()
+        self.linear = model.linear
+        self.fsmn_block = model.fsmn_block
+        self.affine = model.affine
+        self.relu = model.relu
+
+    def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
+        x = self.linear(input)  # B T D
+        # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+        # if cache_layer_name not in in_cache:
+        #     in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+        x, out_cache = self.fsmn_block(x, in_cache)
+        x = self.affine(x)
+        x = self.relu(x)
+        return x, out_cache
+
+
+class FsmnStack(nn.Sequential):
+    def __init__(self, *args):
+        super(FsmnStack, self).__init__(*args)
+
+    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
+        x = input
+        for module in self._modules.values():
+            x = module(x, cache)
+        return x
+
+    def to_kaldi_net(self):
+        re_str = ''
+        for module in self._modules.values():
+            re_str += module.to_kaldi_net()
+
+        return re_str
+
+    def to_pytorch_net(self, fread):
+        for module in self._modules.values():
+            module.to_pytorch_net(fread)
+
+
+"""
+FSMN net for keyword spotting
+input_dim:              input dimension
+linear_dim:             fsmn input dimension
+proj_dim:               fsmn projection dimension
+lorder:                 fsmn left order
+rorder:                 fsmn right order
+num_syn:                output dimension
+fsmn_layers:            no. of sequential fsmn layers
+"""
+
+
+@tables.register("encoder_classes", "FSMNConvert")
+class FSMNConvert(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        input_affine_dim: int,
+        fsmn_layers: int,
+        linear_dim: int,
+        proj_dim: int,
+        lorder: int,
+        rorder: int,
+        lstride: int,
+        rstride: int,
+        output_affine_dim: int,
+        output_dim: int,
+        use_softmax: bool = True,
+    ):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.input_affine_dim = input_affine_dim
+        self.fsmn_layers = fsmn_layers
+        self.linear_dim = linear_dim
+        self.proj_dim = proj_dim
+        self.output_affine_dim = output_affine_dim
+        self.output_dim = output_dim
+
+        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+        self.relu = RectifiedLinear(linear_dim, linear_dim)
+        self.fsmn = FsmnStack(
+            *[
+                BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i)
+                for i in range(fsmn_layers)
+            ]
+        )
+        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+
+        self.use_softmax = use_softmax
+        if self.use_softmax:
+            self.softmax = nn.Softmax(dim=-1)
+
+    def output_size(self) -> int:
+        return self.output_dim
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        cache: Dict[str, torch.Tensor] = None
+    ) -> torch.Tensor:
+        """
+        Args:
+            input (torch.Tensor): Input tensor (B, T, D)
+            cache: when cache is not None, forward runs in streaming mode. cache is a dict, e.g.,
+            {'cache_layer_0': torch.Tensor(B, D, T1, 1)}, where T1 = (lorder - 1) * lstride; pass {} for the first chunk
+        """
+
+        x1 = self.in_linear1(input)
+        x2 = self.in_linear2(x1)
+        x3 = self.relu(x2)
+        x4 = self.fsmn(x3, cache)  # the cache dict is updated in place inside self.fsmn
+        x5 = self.out_linear1(x4)
+        x6 = self.out_linear2(x5)
+
+        if self.use_softmax:
+            x7 = self.softmax(x6)
+            return x7
+
+        return x6
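+
+    # Streaming usage sketch (illustrative only): reuse one dict across chunks so each
+    # BasicBlock keeps its left context.
+    #   cache = {}                         # empty dict for the first chunk
+    #   for chunk in feature_chunks:       # each chunk: (B, T_chunk, input_dim)
+    #       posterior = encoder(chunk, cache)  # cache entries are updated in place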
+
+    def to_kaldi_net(self):
+        re_str = ''
+        re_str += '<Nnet>\n'
+        re_str += self.in_linear1.to_kaldi_net()
+        re_str += self.in_linear2.to_kaldi_net()
+        re_str += self.relu.to_kaldi_net()
+
+        for fsmn in self.fsmn:
+            re_str += fsmn.to_kaldi_net()
+
+        re_str += self.out_linear1.to_kaldi_net()
+        re_str += self.out_linear2.to_kaldi_net()
+        re_str += '<Softmax> %d %d\n' % (self.output_dim, self.output_dim)
+        re_str += '</Nnet>\n'
+
+        return re_str
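+
+    # The serialized text follows the framing visible above: an opening <Nnet> tag,
+    # the text block produced by each component's to_kaldi_net(), a trailing
+    # "<Softmax> <out_dim> <out_dim>" line, and a closing </Nnet> tag.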
+
+    def to_pytorch_net(self, kaldi_file):
+        with open(kaldi_file, 'r', encoding='utf8') as fread:
+            nnet_start_line = fread.readline()
+            assert nnet_start_line.strip() == '<Nnet>'
+
+            self.in_linear1.to_pytorch_net(fread)
+            self.in_linear2.to_pytorch_net(fread)
+            self.relu.to_pytorch_net(fread)
+
+            for fsmn in self.fsmn:
+                fsmn.to_pytorch_net(fread)
+
+            self.out_linear1.to_pytorch_net(fread)
+            self.out_linear2.to_pytorch_net(fread)
+
+            softmax_line = fread.readline()
+            softmax_split = softmax_line.strip().split()
+            assert softmax_split[0].strip() == '<Softmax>'
+            assert int(softmax_split[1]) == self.output_dim
+            assert int(softmax_split[2]) == self.output_dim
+
+            nnet_end_line = fread.readline()
+            assert nnet_end_line.strip() == '</Nnet>'
diff --git a/funasr/models/fsmn_kws/model.py b/funasr/models/fsmn_kws/model.py
new file mode 100644
index 0000000..066730f
--- /dev/null
+++ b/funasr/models/fsmn_kws/model.py
@@ -0,0 +1,285 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import time
+import torch
+import logging
+from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
+
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+
+@tables.register("model_classes", "FsmnKWS")
+class FsmnKWS(torch.nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
+    def __init__(
+        self,
+        specaug: Optional[str] = None,
+        specaug_conf: Optional[Dict] = None,
+        normalize: str = None,
+        normalize_conf: Optional[Dict] = None,
+        encoder: str = None,
+        encoder_conf: Optional[Dict] = None,
+        ctc: str = None,
+        ctc_conf: Optional[Dict] = None,
+        ctc_weight: float = 1.0,
+        input_size: int = 360,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        **kwargs,
+    ):
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(**encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        if ctc_conf is None:
+            ctc_conf = {}
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+        )
+
+        self.blank_id = blank_id
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.ctc_weight = ctc_weight
+
+        # self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        self.encoder = encoder
+        self.ctc = ctc
+
+        self.error_calculator = None
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+        batch_size = speech.shape[0]
+
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        loss_ctc, cer_ctc = self._calc_ctc_loss(
+            encoder_out, encoder_out_lens, text, text_lengths
+        )
+
+        # Collect CTC branch stats
+        stats = dict()
+        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+        stats["cer_ctc"] = cer_ctc
+
+        loss = self.ctc_weight * loss_ctc
+
+        stats["cer"] = cer_ctc
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+        # Forward encoder
+        encoder_out = self.encoder(speech)
+        encoder_out_lens = speech_lengths
+
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        return encoder_out, encoder_out_lens
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+
+        return loss_ctc, cer_ctc
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list=None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+        keywords = kwargs.get("keywords")
+        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
+        self.kws_decoder = KwsCtcPrefixDecoder(
+            ctc=self.ctc,
+            keywords=keywords,
+            token_list=tokenizer.token_list,
+            seg_dict=tokenizer.seg_dict,
+        )
+
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is not None:
+                speech_lengths = speech_lengths.squeeze(-1)
+            else:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list,
+                data_type=kwargs.get("data_type", "sound"),
+                frontend=frontend,
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
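+            # total audio duration of the batch in seconds (assuming frame_shift is in milliseconds)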
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        results = []
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+
+        for i in range(encoder_out.size(0)):
+            x = encoder_out[i, :encoder_out_lens[i], :]
+            detect_result = self.kws_decoder.decode(x)
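+            # decode() returns (is_detected, detected_keyword, detection_score)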
+            is_deted, det_keyword, det_score = detect_result[0], detect_result[1], detect_result[2]
+
+            if is_deted:
+                det_info = "detected " + det_keyword + " " + str(det_score)
+            else:
+                det_info = "rejected"
+            if hasattr(self, "writer"):
+                self.writer["detect"][key[i]] = det_info
+
+            result_i = {"key": key[i], "text": det_info}
+            results.append(result_i)
+
+        return results, meta_data
+
+
+@tables.register("model_classes", "FsmnKWSConvert")
+class FsmnKWSConvert(torch.nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
+    def __init__(
+        self,
+        encoder: str = None,
+        encoder_conf: Optional[Dict] = None,
+        ctc: str = None,
+        ctc_conf: Optional[Dict] = None,
+        ctc_weight: float = 1.0,
+        input_size: int = 360,
+        vocab_size: int = -1,
+        blank_id: int = 0,
+        **kwargs,
+    ):
+        super().__init__()
+
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(**encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        if ctc_conf is None:
+            ctc_conf = {}
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+        )
+
+        self.blank_id = blank_id
+        self.vocab_size = vocab_size
+        self.ctc_weight = ctc_weight
+        self.encoder = encoder
+        self.ctc = ctc
+
+        self.error_calculator = None
+
+    def to_kaldi_net(self):
+        return self.encoder.to_kaldi_net()
+
+    def to_pytorch_net(self, kaldi_file):
+        return self.encoder.to_pytorch_net(kaldi_file)
diff --git a/funasr/models/fsmn_kws_mt/__init__.py b/funasr/models/fsmn_kws_mt/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/fsmn_kws_mt/__init__.py
diff --git a/funasr/models/fsmn_kws_mt/encoder.py b/funasr/models/fsmn_kws_mt/encoder.py
new file mode 100755
index 0000000..4b36d83
--- /dev/null
+++ b/funasr/models/fsmn_kws_mt/encoder.py
@@ -0,0 +1,213 @@
+from typing import Tuple, Dict
+import copy
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from funasr.models.fsmn_kws.encoder import (toKaldiMatrix, LinearTransform, AffineTransform, RectifiedLinear, FSMNBlock, FsmnStack, BasicBlock)
+
+
+from funasr.register import tables
+
+
+'''
+FSMN net for keyword spotting
+input_dim:              input dimension
+linear_dim:             fsmn input dimension
+proj_dim:               fsmn projection dimension
+lorder:                 fsmn left order
+rorder:                 fsmn right order
+num_syn:                output dimension
+fsmn_layers:            no. of sequential fsmn layers
+'''
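+
+# Illustrative two-branch instantiation (all dimensions hypothetical; the second head
+# predicts a separate, typically smaller, label set via output_dim2):
+#   encoder = FSMNMT(input_dim=400, input_affine_dim=140, fsmn_layers=4,
+#                    linear_dim=250, proj_dim=128, lorder=10, rorder=2,
+#                    lstride=1, rstride=1, output_affine_dim=140,
+#                    output_dim=2599, output_dim2=4, use_softmax=True)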
+
+@tables.register("encoder_classes", "FSMNMT")
+class FSMNMT(nn.Module):
+    def __init__(
+            self,
+            input_dim: int,
+            input_affine_dim: int,
+            fsmn_layers: int,
+            linear_dim: int,
+            proj_dim: int,
+            lorder: int,
+            rorder: int,
+            lstride: int,
+            rstride: int,
+            output_affine_dim: int,
+            output_dim: int,
+            output_dim2: int,
+            use_softmax: bool = True,
+    ):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.input_affine_dim = input_affine_dim
+        self.fsmn_layers = fsmn_layers
+        self.linear_dim = linear_dim
+        self.proj_dim = proj_dim
+        self.output_affine_dim = output_affine_dim
+        self.output_dim = output_dim
+        self.output_dim2 = output_dim2
+
+        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+        self.relu = RectifiedLinear(linear_dim, linear_dim)
+        self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+                                range(fsmn_layers)])
+        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+        self.out_linear1_2 = AffineTransform(linear_dim, output_affine_dim)
+        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+        self.out_linear2_2 = AffineTransform(output_affine_dim, output_dim2)
+
+        self.use_softmax = use_softmax
+        if self.use_softmax:
+            self.softmax = nn.Softmax(dim=-1)
+
+    def output_size(self) -> int:
+        return self.output_dim
+
+    def output_size2(self) -> int:
+        return self.output_dim2
+
+    def forward(
+            self,
+            input: torch.Tensor,
+            cache: Dict[str, torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            input (torch.Tensor): Input tensor (B, T, D)
+            cache: when cache is not None, forward runs in streaming mode. cache is a dict, e.g.,
+            {'cache_layer_0': torch.Tensor(B, D, T1, 1)}, where T1 = (lorder - 1) * lstride; pass {} for the first chunk
+        """
+
+        x1 = self.in_linear1(input)
+        x2 = self.in_linear2(x1)
+        x3 = self.relu(x2)
+        x4 = self.fsmn(x3, cache)  # the cache dict is updated in place inside self.fsmn
+        x5 = self.out_linear1(x4)
+        x6 = self.out_linear2(x5)
+
+        x5_2 = self.out_linear1_2(x4)
+        x6_2 = self.out_linear2_2(x5_2)
+
+        if self.use_softmax:
+            x7 = self.softmax(x6)
+            x7_2 = self.softmax(x6_2)
+            return x7, x7_2
+
+        return x6, x6_2
+
+
+@tables.register("encoder_classes", "FSMNMTConvert")
+class FSMNMTConvert(nn.Module):
+    def __init__(
+            self,
+            input_dim: int,
+            input_affine_dim: int,
+            fsmn_layers: int,
+            linear_dim: int,
+            proj_dim: int,
+            lorder: int,
+            rorder: int,
+            lstride: int,
+            rstride: int,
+            output_affine_dim: int,
+            output_dim: int,
+            output_dim2: int,
+            use_softmax: bool = True,
+    ):
+        super().__init__()
+
+        self.input_dim = input_dim
+        self.input_affine_dim = input_affine_dim
+        self.fsmn_layers = fsmn_layers
+        self.linear_dim = linear_dim
+        self.proj_dim = proj_dim
+        self.output_affine_dim = output_affine_dim
+        self.output_dim = output_dim
+        self.output_dim2 = output_dim2
+
+        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+        self.relu = RectifiedLinear(linear_dim, linear_dim)
+        self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+                                range(fsmn_layers)])
+        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+        self.out_linear1_2 = AffineTransform(linear_dim, output_affine_dim)
+        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+        self.out_linear2_2 = AffineTransform(output_affine_dim, output_dim2)
+
+        self.use_softmax = use_softmax
+        if self.use_softmax:
+            self.softmax = nn.Softmax(dim=-1)
+
+    def output_size(self) -> int:
+        return self.output_dim
+
+    def output_size2(self) -> int:
+        return self.output_dim2
+
+    def to_kaldi_net(self):
+        re_str = ''
+        re_str += '<Nnet>\n'
+        re_str += self.in_linear1.to_kaldi_net()
+        re_str += self.in_linear2.to_kaldi_net()
+        re_str += self.relu.to_kaldi_net()
+
+        for fsmn in self.fsmn:
+            re_str += fsmn.to_kaldi_net()
+
+        re_str += self.out_linear1.to_kaldi_net()
+        re_str += self.out_linear2.to_kaldi_net()
+        re_str += '<Softmax> %d %d\n' % (self.output_dim, self.output_dim)
+        re_str += '</Nnet>\n'
+
+        return re_str
+
+    def to_kaldi_net2(self):
+        re_str = ''
+        re_str += '<Nnet>\n'
+        re_str += self.in_linear1.to_kaldi_net()
+        re_str += self.in_linear2.to_kaldi_net()
+        re_str += self.relu.to_kaldi_net()
+
+        for fsmn in self.fsmn:
+            re_str += fsmn.to_kaldi_net()
+
+        re_str += self.out_linear1_2.to_kaldi_net()
+        re_str += self.out_linear2_2.to_kaldi_net()
+        re_str += '<Softmax> %d %d\n' % (self.output_dim2, self.output_dim2)
+        re_str += '</Nnet>\n'
+
+        return re_str
+
+    def to_pytorch_net(self, kaldi_file):
+        with open(kaldi_file, 'r', encoding='utf8') as fread:
+            nnet_start_line = fread.readline()
+            assert nnet_start_line.strip() == '<Nnet>'
+
+            self.in_linear1.to_pytorch_net(fread)
+            self.in_linear2.to_pytorch_net(fread)
+            self.relu.to_pytorch_net(fread)
+
+            for fsmn in self.fsmn:
+                fsmn.to_pytorch_net(fread)
+
+            self.out_linear1.to_pytorch_net(fread)
+            self.out_linear2.to_pytorch_net(fread)
+
+            softmax_line = fread.readline()
+            softmax_split = softmax_line.strip().split()
+            assert softmax_split[0].strip() == '<Softmax>'
+            assert int(softmax_split[1]) == self.output_dim
+            assert int(softmax_split[2]) == self.output_dim
+
+            nnet_end_line = fread.readline()
+            assert nnet_end_line.strip() == '</Nnet>'
diff --git a/funasr/models/fsmn_kws_mt/model.py b/funasr/models/fsmn_kws_mt/model.py
new file mode 100644
index 0000000..c4645bb
--- /dev/null
+++ b/funasr/models/fsmn_kws_mt/model.py
@@ -0,0 +1,353 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import time
+import torch
+import logging
+from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
+
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+
+@tables.register("model_classes", "FsmnKWSMT")
+class FsmnKWSMT(torch.nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
+    def __init__(
+        self,
+        specaug: Optional[str] = None,
+        specaug_conf: Optional[Dict] = None,
+        normalize: str = None,
+        normalize_conf: Optional[Dict] = None,
+        encoder: str = None,
+        encoder_conf: Optional[Dict] = None,
+        ctc_conf: Optional[Dict] = None,
+        input_size: int = 360,
+        vocab_size: int = -1,
+        vocab_size2: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        **kwargs,
+    ):
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(**encoder_conf)
+        encoder_output_size = encoder.output_size()
+        encoder_output_size2 = encoder.output_size2()
+
+        if ctc_conf is None:
+            ctc_conf = {}
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+        )
+        ctc2 = CTC(
+            odim=vocab_size2, encoder_output_size=encoder_output_size2, **ctc_conf
+        )
+
+        self.blank_id = blank_id
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+
+        # self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        self.encoder = encoder
+        self.ctc = ctc
+        self.ctc2 = ctc2
+
+        self.error_calculator = None
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        text2: torch.Tensor,
+        text2_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+                text2: (Batch, Length)
+                text2_lengths: (Batch,)
+        """
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(text2_lengths.size()) > 1:
+            text2_lengths = text2_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+        batch_size = speech.shape[0]
+
+        # Encoder
+        encoder_out, encoder_out2, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        loss_ctc, cer_ctc = self._calc_ctc_loss(
+            encoder_out, encoder_out_lens, text, text_lengths
+        )
+        loss_ctc2, cer_ctc2 = self._calc_ctc2_loss(
+            encoder_out2, encoder_out_lens, text2, text2_lengths
+        )
+
+        # Collect CTC branch stats
+        stats = dict()
+        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+        stats["cer_ctc"] = cer_ctc
+        stats["loss_ctc2"] = loss_ctc2.detach() if loss_ctc2 is not None else None
+        stats["cer_ctc2"] = cer_ctc2
+
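+        # equal-weight multi-task objective: average the CTC losses of the two branches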
+        loss = 0.5 * loss_ctc + 0.5 * loss_ctc2
+
+        stats["cer"] = cer_ctc
+        stats["cer2"] = cer_ctc2
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+        # Forward encoder
+        encoder_out, encoder_out2 = self.encoder(speech)
+        encoder_out_lens = speech_lengths
+
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        if isinstance(encoder_out2, tuple):
+            encoder_out2 = encoder_out2[0]
+
+        return encoder_out, encoder_out2, encoder_out_lens
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
+
+    def _calc_ctc2_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc2(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc2.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
+
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list=None,
+        tokenizer=None,
+        tokenizer2=None,
+        frontend=None,
+        **kwargs,
+    ):
+        keywords = kwargs.get("keywords")
+        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
+        self.kws_decoder = KwsCtcPrefixDecoder(
+            ctc=self.ctc,
+            keywords=keywords,
+            token_list=tokenizer.token_list,
+            seg_dict=tokenizer.seg_dict,
+        )
+        self.kws_decoder2 = KwsCtcPrefixDecoder(
+            ctc=self.ctc2,
+            keywords=keywords,
+            token_list=tokenizer2.token_list,
+            seg_dict=tokenizer2.seg_dict,
+        )
+
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is not None:
+                speech_lengths = speech_lengths.squeeze(-1)
+            else:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer
+            )
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list,
+                data_type=kwargs.get("data_type", "sound"),
+                frontend=frontend
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        # Encoder
+        encoder_out, encoder_out2, encoder_out_lens = self.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        if isinstance(encoder_out2, tuple):
+            encoder_out2 = encoder_out2[0]
+
+        results = []
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+
+        for i in range(encoder_out.size(0)):
+            x = encoder_out[i, :encoder_out_lens[i], :]
+            detect_result = self.kws_decoder.decode(x)
+            is_deted, det_keyword, det_score = detect_result[0], detect_result[1], detect_result[2]
+
+            if is_deted:
+                det_info = "detected " + det_keyword + " " + str(det_score)
+            else:
+                det_info = "rejected"
+
+            x2 = encoder_out2[i, :encoder_out_lens[i], :]
+            detect_result2 = self.kws_decoder2.decode(x2)
+            is_deted2, det_keyword2, det_score2 = detect_result2[0], detect_result2[1], detect_result2[2]
+
+            if is_deted2:
+                det_info2 = "detected " + det_keyword2 + " " + str(det_score2)
+            else:
+                det_info2 = "rejected"
+
+            if hasattr(self, "writer"):
+                self.writer["detect"][key[i]] = det_info
+                self.writer["detect2"][key[i]] = det_info2
+
+            result_i = {"key": key[i], "text": det_info, "text2": det_info2}
+            results.append(result_i)
+
+        return results, meta_data
+
+
+@tables.register("model_classes", "FsmnKWSMTConvert")
+class FsmnKWSMTConvert(torch.nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+
+    def __init__(
+        self,
+        encoder: str = None,
+        encoder_conf: Optional[Dict] = None,
+        ctc: str = None,
+        ctc_conf: Optional[Dict] = None,
+        ctc_weight: float = 1.0,
+        input_size: int = 360,
+        vocab_size: int = -1,
+        vocab_size2: int = -1,
+        blank_id: int = 0,
+        **kwargs,
+    ):
+        super().__init__()
+
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(**encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        if ctc_conf is None:
+            ctc_conf = {}
+        ctc = CTC(
+            odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+        )
+
+        self.blank_id = blank_id
+        self.vocab_size = vocab_size
+        self.ctc_weight = ctc_weight
+        self.encoder = encoder
+        self.ctc = ctc
+
+        self.error_calculator = None
+
+    def to_kaldi_net(self):
+        return self.encoder.to_kaldi_net()
+
+    def to_kaldi_net2(self):
+        return self.encoder.to_kaldi_net2()
+
+    def to_pytorch_net(self, kaldi_file):
+        return self.encoder.to_pytorch_net(kaldi_file)
diff --git a/funasr/models/fsmn_vad_streaming/encoder.py b/funasr/models/fsmn_vad_streaming/encoder.py
index 6668c5d..14c2f5f 100755
--- a/funasr/models/fsmn_vad_streaming/encoder.py
+++ b/funasr/models/fsmn_vad_streaming/encoder.py
@@ -85,13 +85,17 @@
         else:
             self.conv_right = None
 
-    def forward(self, input: torch.Tensor, cache: torch.Tensor):
+    def forward(self, input: torch.Tensor, cache: torch.Tensor = None):
         x = torch.unsqueeze(input, 1)
         x_per = x.permute(0, 3, 2, 1)  # B D T C
 
-        cache = cache.to(x_per.device)
-        y_left = torch.cat((cache, x_per), dim=2)
-        cache = y_left[:, :, -(self.lorder - 1) * self.lstride :, :]
+        if cache is not None:
+            cache = cache.to(x_per.device)
+            y_left = torch.cat((cache, x_per), dim=2)
+            cache = y_left[:, :, -(self.lorder - 1) * self.lstride :, :]
+        else:
+            y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
+
         y_left = self.conv_left(y_left)
         out = x_per + y_left
 
@@ -130,14 +134,18 @@
         self.affine = AffineTransform(proj_dim, linear_dim)
         self.relu = RectifiedLinear(linear_dim, linear_dim)
 
-    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
+    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor] = None):
         x1 = self.linear(input)  # B T D
-        cache_layer_name = "cache_layer_{}".format(self.stack_layer)
-        if cache_layer_name not in cache:
-            cache[cache_layer_name] = torch.zeros(
-                x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1
-            )
-        x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
+
+        if cache is not None:
+            cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+            if cache_layer_name not in cache:
+                cache[cache_layer_name] = torch.zeros(
+                    x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1
+                )
+            x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
+        else:
+            x2, _ = self.fsmn_block(x1, None)
         x3 = self.affine(x2)
         x4 = self.relu(x3)
         return x4
@@ -203,6 +211,7 @@
         rstride: int,
         output_affine_dim: int,
         output_dim: int,
+        use_softmax: bool = True,
     ):
         super().__init__()
 
@@ -225,13 +234,21 @@
         )
         self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
         self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
-        self.softmax = nn.Softmax(dim=-1)
+
+        self.use_softmax = use_softmax
+        if self.use_softmax:
+            self.softmax = nn.Softmax(dim=-1)
 
     def fuse_modules(self):
         pass
 
+    def output_size(self) -> int:
+        return self.output_dim
+
     def forward(
-        self, input: torch.Tensor, cache: Dict[str, torch.Tensor]
+        self,
+        input: torch.Tensor,
+        cache: Dict[str, torch.Tensor] = None
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Args:
@@ -246,9 +263,12 @@
         x4 = self.fsmn(x3, cache)  # self.cache will update automatically in self.fsmn
         x5 = self.out_linear1(x4)
         x6 = self.out_linear2(x5)
-        x7 = self.softmax(x6)
 
-        return x7
+        if self.use_softmax:
+            x7 = self.softmax(x6)
+            return x7
+
+        return x6
 
 
 @tables.register("encoder_classes", "FSMNExport")
@@ -276,6 +296,7 @@
         # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
         # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
         # self.softmax = nn.Softmax(dim=-1)
+
         self.in_linear1 = model.in_linear1
         self.in_linear2 = model.in_linear2
         self.relu = model.relu
@@ -317,88 +338,3 @@
         x = self.softmax(x)
 
         return x, out_caches
-
-
-"""
-one deep fsmn layer
-dimproj:                projection dimension, input and output dimension of memory blocks
-dimlinear:              dimension of mapping layer
-lorder:                 left order
-rorder:                 right order
-lstride:                left stride
-rstride:                right stride
-"""
-
-
-@tables.register("encoder_classes", "DFSMN")
-class DFSMN(nn.Module):
-
-    def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
-        super(DFSMN, self).__init__()
-
-        self.lorder = lorder
-        self.rorder = rorder
-        self.lstride = lstride
-        self.rstride = rstride
-
-        self.expand = AffineTransform(dimproj, dimlinear)
-        self.shrink = LinearTransform(dimlinear, dimproj)
-
-        self.conv_left = nn.Conv2d(
-            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False
-        )
-
-        if rorder > 0:
-            self.conv_right = nn.Conv2d(
-                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False
-            )
-        else:
-            self.conv_right = None
-
-    def forward(self, input):
-        f1 = F.relu(self.expand(input))
-        p1 = self.shrink(f1)
-
-        x = torch.unsqueeze(p1, 1)
-        x_per = x.permute(0, 3, 2, 1)
-
-        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
-
-        if self.conv_right is not None:
-            y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
-            y_right = y_right[:, :, self.rstride :, :]
-            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
-        else:
-            out = x_per + self.conv_left(y_left)
-
-        out1 = out.permute(0, 3, 2, 1)
-        output = input + out1.squeeze(1)
-
-        return output
-
-
-"""
-build stacked dfsmn layers
-"""
-
-
-def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
-    repeats = [
-        nn.Sequential(DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) for i in range(fsmn_layers)
-    ]
-
-    return nn.Sequential(*repeats)
-
-
-if __name__ == "__main__":
-    fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
-    print(fsmn)
-
-    num_params = sum(p.numel() for p in fsmn.parameters())
-    print("the number of model params: {}".format(num_params))
-    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
-    y, _ = fsmn(x)  # batch-size * time * dim
-    print("input shape: {}".format(x.shape))
-    print("output shape: {}".format(y.shape))
-
-    print(fsmn.to_kaldi_net())
diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index dc30a94..0d39ca7 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -523,6 +523,7 @@
         feats_dim=560,
         model_name="encoder",
         onnx: bool = True,
+        ctc_linear: nn.Module = None,
     ):
         super().__init__()
         self.embed = model.embed
@@ -553,6 +554,8 @@
         self.num_heads = model.encoders[0].self_attn.h
         self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
 
+        self.ctc_linear = ctc_linear
+
     def prepare_mask(self, mask):
         mask_3d_btd = mask[:, :, None]
         if len(mask.shape) == 2:
@@ -566,6 +569,7 @@
     def forward(self, speech: torch.Tensor, speech_lengths: torch.Tensor, online: bool = False):
         if not online:
             speech = speech * self._output_size**0.5
+
         mask = self.make_pad_mask(speech_lengths)
         mask = self.prepare_mask(mask)
         if self.embed is None:
@@ -581,6 +585,10 @@
 
         xs_pad = self.model.after_norm(xs_pad)
 
+        if self.ctc_linear is not None:
+            xs_pad = self.ctc_linear(xs_pad)
+            xs_pad = F.softmax(xs_pad, dim=2)
+
         return xs_pad, speech_lengths
 
     def get_output_size(self):
diff --git a/funasr/models/sanm_kws/__init__.py b/funasr/models/sanm_kws/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/sanm_kws/__init__.py
diff --git a/funasr/models/sanm_kws/export_meta.py b/funasr/models/sanm_kws/export_meta.py
new file mode 100644
index 0000000..a91a22c
--- /dev/null
+++ b/funasr/models/sanm_kws/export_meta.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import types
+import copy
+import torch
+from funasr.register import tables
+
+
+def export_rebuild_model(model, **kwargs):
+    # self.device = kwargs.get("device")
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+
+    if hasattr(model, "ctc"):
+        model.encoder = encoder_class(
+            model.encoder,
+            onnx=is_onnx,
+            feats_dim=kwargs.get("input_size", 560),
+            ctc_linear=model.ctc.ctc_lo
+        )
+    else:
+        assert False, "export_rebuild_model expects the model to have a ctc module, got: {}".format(model)
+        model.encoder = encoder_class(model.encoder, onnx=is_onnx, feats_dim=kwargs.get("input_size", 560))
+
+    # from funasr.utils.torch_function import sequence_mask
+    # model.make_pad_mask = sequence_mask(max_seq_len=None, flip=False)
+
+    encoder_model = copy.copy(model)
+
+    # encoder
+    encoder_model.forward = types.MethodType(export_encoder_forward, encoder_model)
+    encoder_model.export_dummy_inputs = types.MethodType(export_encoder_dummy_inputs, encoder_model)
+    encoder_model.export_input_names = types.MethodType(export_encoder_input_names, encoder_model)
+    encoder_model.export_output_names = types.MethodType(export_encoder_output_names, encoder_model)
+    encoder_model.export_dynamic_axes = types.MethodType(export_encoder_dynamic_axes, encoder_model)
+    encoder_model.export_name = types.MethodType(export_encoder_name, encoder_model)
+
+    return encoder_model
+
+
+def export_encoder_forward(
+    self,
+    speech: torch.Tensor,
+    speech_lengths: torch.Tensor,
+):
+    # a. To device
+    batch = {
+        "speech": speech,
+        "speech_lengths": speech_lengths,
+        "online": True
+    }
+    # batch = to_device(batch, device=self.device)
+
+    encoder_out, encoder_out_len = self.encoder(**batch)
+    # mask = self.make_pad_mask(encoder_out_len)[:, None, :]
+    # alphas, _ = self.predictor.forward_cnn(enc, mask)
+
+    # return encoder_out, encoder_out_len, alphas
+    return encoder_out, encoder_out_len
+
+
+def export_encoder_dummy_inputs(self):
+    speech = torch.randn(2, 30, 280)
+    speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+    return (speech, speech_lengths)
+
+
+def export_encoder_input_names(self):
+    return ["speech", "speech_lengths"]
+
+
+def export_encoder_output_names(self):
+    # return ["encoder_out", "encoder_out_len", "alphas"]
+    return ["encoder_out", "encoder_out_len"]
+
+
+def export_encoder_dynamic_axes(self):
+    return {
+        "speech": {
+            0: "batch_size", 1: "feats_length"
+        },
+        "speech_lengths": {
+            0: "batch_size",
+        },
+        "encoder_out": {
+            0: "batch_size", 1: "feats_length"
+        },
+        "encoder_out_len": {
+            0: "batch_size",
+        },
+    }
+
+
+def export_encoder_name(self):
+    return "encoder.onnx"
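+
+
+# Usage sketch (illustrative; assumes FunASR's high-level AutoModel API, exact arguments may differ):
+#   from funasr import AutoModel
+#   model = AutoModel(model="<path_to_sanm_kws_model>")
+#   model.export(type="onnx")  # rebuilds the encoder via export_rebuild_model and writes encoder.onnx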
diff --git a/funasr/models/sanm_kws/model.py b/funasr/models/sanm_kws/model.py
new file mode 100644
index 0000000..d8d8f95
--- /dev/null
+++ b/funasr/models/sanm_kws/model.py
@@ -0,0 +1,266 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import time
+import torch
+import logging
+from torch.cuda.amp import autocast
+from typing import Union, Dict, List, Tuple, Optional
+
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import to_device
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+
+@tables.register("model_classes", "SanmKWS")
+class SanmKWS(torch.nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+
+    def __init__(
+        self,
+        specaug: Optional[str] = None,
+        specaug_conf: Optional[Dict] = None,
+        normalize: str = None,
+        normalize_conf: Optional[Dict] = None,
+        encoder: str = None,
+        encoder_conf: Optional[Dict] = None,
+        ctc: str = None,
+        ctc_conf: Optional[Dict] = None,
+        ctc_weight: float = 1.0,
+        input_size: int = 360,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        **kwargs,
+    ):
+
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        if ctc_conf is None:
+            ctc_conf = {}
+        ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
+
+        # note that eos is the same as sos (equivalent ID)
+        self.blank_id = blank_id
+        self.sos = sos if sos is not None else vocab_size - 1
+        self.eos = eos if eos is not None else vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.ctc_weight = ctc_weight
+        # self.token_list = token_list.copy()
+        #
+        # self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        self.encoder = encoder
+
+        self.ctc = ctc
+        self.error_calculator = None
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+        batch_size = speech.shape[0]
+
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # decoder: CTC branch
+        loss_ctc, cer_ctc = self._calc_ctc_loss(
+            encoder_out, encoder_out_lens, text, text_lengths
+        )
+
+        # Collect CTC branch stats
+        stats = dict()
+        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+        stats["cer_ctc"] = cer_ctc
+        stats["cer"] = cer_ctc
+
+        loss = loss_ctc
+
+        stats["loss"] = torch.clone(loss.detach())
+        stats["batch_size"] = batch_size
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+        # Forward encoder
+        encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        return encoder_out, encoder_out_lens
+
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+        keywords = kwargs.get("keywords")
+        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
+        self.kws_decoder = KwsCtcPrefixDecoder(
+            ctc=self.ctc,
+            keywords=keywords,
+            token_list=tokenizer.token_list,
+            seg_dict=tokenizer.seg_dict,
+        )
+
+        meta_data = {}
+        if (
+            isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+        ):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is not None:
+                speech_lengths = speech_lengths.squeeze(-1)
+            else:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(
+                data_in,
+                fs=frontend.fs,
+                audio_fs=kwargs.get("fs", 16000),
+                data_type=kwargs.get("data_type", "sound"),
+                tokenizer=tokenizer,
+            )
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(
+                audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+            )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = (
+                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+            )
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+        # Encoder
+        if kwargs.get("fp16", False):
+            speech = speech.half()
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        results = []
+        if kwargs.get("output_dir") is not None:
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+
+        for i in range(encoder_out.size(0)):
+            x = encoder_out[i, : encoder_out_lens[i], :]
+            detect_result = self.kws_decoder.decode(x)
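+            # decode() returns (is_detected, detected_keyword, detection_score)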
+            is_deted, det_keyword, det_score = detect_result[0], detect_result[1], detect_result[2]
+
+            if is_deted:
+                det_info = "detected " + det_keyword + " " + str(det_score)
+            else:
+                det_info = "rejected"
+            if hasattr(self, "writer"):
+                self.writer["detect"][key[i]] = det_info
+
+            result_i = {"key": key[i], "text": det_info}
+            results.append(result_i)
+
+        return results, meta_data
+
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+
+        if "max_seq_len" not in kwargs:
+            kwargs["max_seq_len"] = 512
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
diff --git a/funasr/models/sanm_kws_streaming/__init__.py b/funasr/models/sanm_kws_streaming/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/sanm_kws_streaming/__init__.py
diff --git a/funasr/models/sanm_kws_streaming/export_meta.py b/funasr/models/sanm_kws_streaming/export_meta.py
new file mode 100644
index 0000000..a91a22c
--- /dev/null
+++ b/funasr/models/sanm_kws_streaming/export_meta.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import types
+import copy
+import torch
+from funasr.register import tables
+
+
+def export_rebuild_model(model, **kwargs):
+    # self.device = kwargs.get("device")
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+
+    if hasattr(model, "ctc"):
+        model.encoder = encoder_class(
+            model.encoder,
+            onnx=is_onnx,
+            feats_dim=kwargs.get("input_size", 560),
+            ctc_linear=model.ctc.ctc_lo
+        )
+    else:
+        # export currently requires a CTC head for the KWS decoder
+        raise RuntimeError(f"export expects a model with a ctc head, got: {type(model)}")
+
+    # from funasr.utils.torch_function import sequence_mask
+    # model.make_pad_mask = sequence_mask(max_seq_len=None, flip=False)
+
+    encoder_model = copy.copy(model)
+
+    # encoder
+    encoder_model.forward = types.MethodType(export_encoder_forward, encoder_model)
+    encoder_model.export_dummy_inputs = types.MethodType(export_encoder_dummy_inputs, encoder_model)
+    encoder_model.export_input_names = types.MethodType(export_encoder_input_names, encoder_model)
+    encoder_model.export_output_names = types.MethodType(export_encoder_output_names, encoder_model)
+    encoder_model.export_dynamic_axes = types.MethodType(export_encoder_dynamic_axes, encoder_model)
+    encoder_model.export_name = types.MethodType(export_encoder_name, encoder_model)
+
+    return encoder_model
+
+
+def export_encoder_forward(
+    self,
+    speech: torch.Tensor,
+    speech_lengths: torch.Tensor,
+):
+    # a. To device
+    batch = {
+        "speech": speech,
+        "speech_lengths": speech_lengths,
+        "online": True
+    }
+    # batch = to_device(batch, device=self.device)
+
+    encoder_out, encoder_out_len = self.encoder(**batch)
+    # mask = self.make_pad_mask(encoder_out_len)[:, None, :]
+    # alphas, _ = self.predictor.forward_cnn(enc, mask)
+
+    # return encoder_out, encoder_out_len, alphas
+    return encoder_out, encoder_out_len
+
+
+def export_encoder_dummy_inputs(self):
+    speech = torch.randn(2, 30, 280)
+    speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+    return (speech, speech_lengths)
+
+
+def export_encoder_input_names(self):
+    return ["speech", "speech_lengths"]
+
+
+def export_encoder_output_names(self):
+    # return ["encoder_out", "encoder_out_len", "alphas"]
+    return ["encoder_out", "encoder_out_len"]
+
+
+def export_encoder_dynamic_axes(self):
+    return {
+        "speech": {
+            0: "batch_size", 1: "feats_length"
+        },
+        "speech_lengths": {
+            0: "batch_size",
+        },
+        "encoder_out": {
+            0: "batch_size", 1: "feats_length"
+        },
+        "encoder_out_len": {
+            0: "batch_size",
+        },
+    }
+
+
+def export_encoder_name(self):
+    return "encoder.onnx"
diff --git a/funasr/models/sanm_kws_streaming/model.py b/funasr/models/sanm_kws_streaming/model.py
new file mode 100644
index 0000000..c459e7c
--- /dev/null
+++ b/funasr/models/sanm_kws_streaming/model.py
@@ -0,0 +1,442 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import time
+import torch
+import logging
+from typing import Dict, Tuple
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+
+from funasr.register import tables
+from funasr.models.ctc.ctc import CTC
+from funasr.utils import postprocess_utils
+from funasr.metrics.compute_acc import th_accuracy
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.sanm_kws.model import SanmKWS
+from funasr.models.paraformer.search import Hypothesis
+from funasr.models.paraformer.cif_predictor import mae_loss
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+@tables.register("model_classes", "SanmKWSStreaming")
+class SanmKWSStreaming(SanmKWS):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        decoding_ind = kwargs.get("decoding_ind")
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+
+        batch_size = speech.shape[0]
+
+        # Encoder
+        if hasattr(self.encoder, "overlap_chunk_cls"):
+            ind = self.encoder.overlap_chunk_cls.random_choice(self.training, decoding_ind)
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, ind=ind)
+        else:
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        # decoder: CTC branch
+        if hasattr(self.encoder, "overlap_chunk_cls"):
+            encoder_out_ctc, encoder_out_lens_ctc = self.encoder.overlap_chunk_cls.remove_chunk(
+                encoder_out, encoder_out_lens, chunk_outs=None
+            )
+        else:
+            encoder_out_ctc, encoder_out_lens_ctc = encoder_out, encoder_out_lens
+
+        loss_ctc, cer_ctc = self._calc_ctc_loss(
+            encoder_out_ctc, encoder_out_lens_ctc, text, text_lengths
+        )
+
+        # Collect CTC branch stats
+        stats = dict()
+        stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+        stats["cer_ctc"] = cer_ctc
+
+        loss = loss_ctc
+
+        stats["cer"] = cer_ctc
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode_chunk(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        cache: dict = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+
+        # Forward encoder
+        encoder_out, encoder_out_lens, _ = self.encoder.forward_chunk(
+            speech, speech_lengths, cache=cache["encoder"]
+        )
+
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        return encoder_out, torch.tensor([encoder_out.size(1)])
+
+    def init_cache(self, cache: dict = {}, **kwargs):
+        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
+        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
+        batch_size = 1
+
+        enc_output_size = kwargs["encoder_conf"]["output_size"]
+        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
+        cache_encoder = {
+            "start_idx": 0,
+            "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+            "cif_alphas": torch.zeros((batch_size, 1)),
+            "encoder_out": None,
+            "encoder_out_lens": None,
+            "chunk_size": chunk_size,
+            "encoder_chunk_look_back": encoder_chunk_look_back,
+            "last_chunk": False,
+            "opt": None,
+            "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+            "tail_chunk": False,
+        }
+        cache["encoder"] = cache_encoder
+
+        cache_decoder = {
+            "decode_fsmn": None,
+            "decoder_chunk_look_back": decoder_chunk_look_back,
+            "opt": None,
+            "chunk_size": chunk_size,
+        }
+        cache["decoder"] = cache_decoder
+        cache["frontend"] = {}
+        cache["prev_samples"] = torch.empty(0)
+
+        return cache
+
+    def generate_chunk(
+        self,
+        speech,
+        speech_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+        cache = kwargs.get("cache", {})
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        # Encoder
+        is_final = kwargs.get("is_final", False)
+        encoder_out, encoder_out_lens = self.encode_chunk(
+            speech, speech_lengths, cache=cache, is_final=is_final
+        )
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        chunk_size = cache["encoder"]["chunk_size"]
+        real_start_pos = chunk_size[0]
+
+        if encoder_out_lens[0] > chunk_size[0] + chunk_size[1] + chunk_size[2]:
+            raise RuntimeError("encoder output is longer than the configured chunk size")
+        if encoder_out_lens[0] >= chunk_size[0] + chunk_size[1]:
+            real_end_pos = chunk_size[0] + chunk_size[1]
+        elif encoder_out_lens[0] > chunk_size[0]:
+            real_end_pos = encoder_out_lens[0]
+        else:
+            raise RuntimeError("encoder output does not exceed the left-context part of the chunk")
+
+        encoder_out_accum = cache["encoder"]["encoder_out"]
+        if encoder_out_accum is not None:
+            encoder_out_accum = torch.cat((encoder_out_accum, encoder_out[:, real_start_pos:real_end_pos, :]), dim=1)
+        else:
+            encoder_out_accum = encoder_out[:, real_start_pos:real_end_pos, :]
+        cache["encoder"]["encoder_out"] = encoder_out_accum
+
+        if cache["encoder"]["encoder_out_lens"] is not None:
+            cache["encoder"]["encoder_out_lens"][0] += real_end_pos - real_start_pos
+        else:
+            cache["encoder"]["encoder_out_lens"] = encoder_out_lens
+            cache["encoder"]["encoder_out_lens"][0] = real_end_pos - real_start_pos
+
+        if is_final:
+            if kwargs.get("output_dir") is not None:
+                if not hasattr(self, "writer"):
+                    self.writer = DatadirWriter(kwargs.get("output_dir"))
+
+            results = []
+            for i in range(encoder_out_accum.size(0)):
+                x = encoder_out_accum[i, : cache["encoder"]["encoder_out_lens"][i], :]
+                detect_result = self.kws_decoder.decode(x)
+                is_deted, det_keyword, det_score = detect_result[0], detect_result[1], detect_result[2]
+
+                if is_deted:
+                    self.writer["detect"][key[i]] = "detected " + det_keyword + " " + str(det_score)
+                    det_info = "detected " + det_keyword + " " + str(det_score)
+                else:
+                    self.writer["detect"][key[i]] = "rejected"
+                    det_info = "rejected"
+
+                result_i = {"key": key[i], "text": det_info}
+                results.append(result_i)
+
+            return results
+        else:
+            return None
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        cache: dict = {},
+        **kwargs,
+    ):
+        keywords = kwargs.get("keywords")
+        from funasr.utils.kws_utils import KwsCtcPrefixDecoder
+        self.kws_decoder = KwsCtcPrefixDecoder(
+            ctc=self.ctc,
+            keywords=keywords,
+            token_list=tokenizer.token_list,
+            seg_dict=tokenizer.seg_dict,
+        )
+
+        meta_data = {}
+        chunk_size = kwargs["chunk_size"]
+        # each chunk_size unit corresponds to 960 samples (60 ms at 16 kHz),
+        # e.g. chunk_size[1] = 10 gives a 600 ms stride
+        chunk_stride_samples = int(chunk_size[1] * 960)
+        first_chunk_padding_samples = int(chunk_size[2] * 960)
+
+        if len(cache) == 0:
+            self.init_cache(cache, **kwargs)
+
+        time1 = time.perf_counter()
+        cfg = {"is_final": kwargs.get("is_final", False)}
+        audio_sample_list = load_audio_text_image_video(
+            data_in,
+            fs=frontend.fs,
+            audio_fs=kwargs.get("fs", 16000),
+            data_type=kwargs.get("data_type", "sound"),
+            tokenizer=tokenizer,
+            cache=cfg,
+        )
+        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
+
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        assert len(audio_sample_list) == 1, "batch_size must be 1 for streaming inference"
+
+        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+
+        if len(audio_sample) < first_chunk_padding_samples:
+            print("key: {}, audio is too short for inference {}".format(key, len(audio_sample)))
+
+        audio_sample_pre = audio_sample[0 : first_chunk_padding_samples]
+        feat_pre, feat_pre_lengths = extract_fbank(
+            [audio_sample_pre],
+            data_type=kwargs.get("data_type", "sound"),
+            frontend=frontend,
+            cache=cache["frontend"],
+            is_final=False,
+        )
+
+        audio_sample = audio_sample[first_chunk_padding_samples:]
+        audio_chunks = int(len(audio_sample) // chunk_stride_samples)
+
+        for i in range(audio_chunks):
+            if i == 0:
+                cache["encoder"]["feats"][:, chunk_size[2] :, :] = feat_pre
+
+            kwargs["is_final"] = False
+            audio_sample_i = audio_sample[i * chunk_stride_samples : (i + 1) * chunk_stride_samples]
+
+            if kwargs["is_final"] and len(audio_sample_i) < 960:
+                cache["encoder"]["tail_chunk"] = True
+                speech = cache["encoder"]["feats"]
+                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(
+                    speech.device
+                )
+            else:
+                # extract fbank feats
+                speech, speech_lengths = extract_fbank(
+                    [audio_sample_i],
+                    data_type=kwargs.get("data_type", "sound"),
+                    frontend=frontend,
+                    cache=cache["frontend"],
+                    is_final=kwargs["is_final"],
+                )
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = (
+                speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+            )
+
+            results_chunk_i = self.generate_chunk(
+                speech,
+                speech_lengths,
+                key=key,
+                tokenizer=tokenizer,
+                cache=cache,
+                frontend=frontend,
+                **kwargs,
+            )
+
+            # results_chunk_i must be None when is_final=False
+            assert results_chunk_i is None
+
+        # process tail samples
+        tail_audio_sample = audio_sample[audio_chunks * chunk_stride_samples :]
+        if len(tail_audio_sample) < 960:
+            kwargs["is_final"] = _is_final
+            cache["encoder"]["tail_chunk"] = True
+            speech = cache["encoder"]["feats"]
+            speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(
+                speech.device
+            )
+            results_chunk_tail = self.generate_chunk(
+                speech,
+                speech_lengths,
+                key=key,
+                tokenizer=tokenizer,
+                cache=cache,
+                frontend=frontend,
+                **kwargs,
+            )
+        elif len(tail_audio_sample) <= first_chunk_padding_samples:
+            kwargs["is_final"] = _is_final
+            # extract fbank feats
+            # cache["encoder"]["tail_chunk"] = True  # cannot be true
+            speech, speech_lengths = extract_fbank(
+                [tail_audio_sample],
+                data_type=kwargs.get("data_type", "sound"),
+                frontend=frontend,
+                cache=cache["frontend"],
+                is_final=kwargs["is_final"],
+            )
+            results_chunk_tail = self.generate_chunk(
+                speech,
+                speech_lengths,
+                key=key,
+                tokenizer=tokenizer,
+                cache=cache,
+                frontend=frontend,
+                **kwargs,
+            )
+        elif len(tail_audio_sample) > first_chunk_padding_samples and \
+             len(tail_audio_sample) < chunk_stride_samples:
+            kwargs["is_final"] = False
+            # extract fbank feats
+            speech, speech_lengths = extract_fbank(
+                [tail_audio_sample],
+                data_type=kwargs.get("data_type", "sound"),
+                frontend=frontend,
+                cache=cache["frontend"],
+                is_final=kwargs["is_final"],
+            )
+            results_chunk = self.generate_chunk(
+                speech,
+                speech_lengths,
+                key=key,
+                tokenizer=tokenizer,
+                cache=cache,
+                frontend=frontend,
+                **kwargs,
+            )
+            # results_chunk must be None when is_final=False
+            assert results_chunk is None
+
+            # push tail chunk
+            kwargs["is_final"] = _is_final
+            cache["encoder"]["tail_chunk"] = True
+            speech = cache["encoder"]["feats"]
+            speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(
+                speech.device
+            )
+            results_chunk_tail = self.generate_chunk(
+                speech,
+                speech_lengths,
+                key=key,
+                tokenizer=tokenizer,
+                cache=cache,
+                frontend=frontend,
+                **kwargs,
+            )
+
+        result = results_chunk_tail
+
+        if _is_final:
+            self.init_cache(cache, **kwargs)
+
+        if kwargs.get("output_dir"):
+            if not hasattr(self, "writer"):
+                self.writer = DatadirWriter(kwargs.get("output_dir"))
+
+        return result, meta_data
+
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
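+
+# Example (sketch): streaming keyword spotting by feeding 600 ms chunks through AutoModel,
+# assuming AutoModel forwards cache / is_final / chunk_size / keywords kwargs to inference()
+# above. The model directory, wav path and keyword string are hypothetical.
+#
+#   import soundfile
+#   from funasr import AutoModel
+#   model = AutoModel(model="./sanm_kws_streaming_exp")
+#   speech, fs = soundfile.read("test.wav")
+#   cache = {}
+#   stride = 10 * 960  # chunk_size[1] * 960 samples = 600 ms at 16 kHz
+#   for start in range(0, len(speech), stride):
+#       chunk = speech[start:start + stride]
+#       is_final = start + stride >= len(speech)
+#       res = model.generate(input=chunk, cache=cache, is_final=is_final,
+#                            chunk_size=[0, 10, 5], keywords="你好小云")
+#   # res is None-like until is_final=True, then holds the detection result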
diff --git a/funasr/models/transformer/scorers/ctc.py b/funasr/models/transformer/scorers/ctc.py
index 73a14bd..eb19d18 100644
--- a/funasr/models/transformer/scorers/ctc.py
+++ b/funasr/models/transformer/scorers/ctc.py
@@ -7,7 +7,6 @@
 from funasr.models.transformer.scorers.ctc_prefix_score import CTCPrefixScoreTH
 from funasr.models.transformer.scorers.scorer_interface import BatchPartialScorerInterface
 
-
 class CTCPrefixScorer(BatchPartialScorerInterface):
     """Decoder interface wrapper for CTCPrefixScore."""
 
diff --git a/funasr/tokenizer/char_tokenizer.py b/funasr/tokenizer/char_tokenizer.py
index 805ecd0..7b517da 100644
--- a/funasr/tokenizer/char_tokenizer.py
+++ b/funasr/tokenizer/char_tokenizer.py
@@ -50,9 +50,7 @@
         )
 
     def text2tokens(self, line: Union[str, list]) -> List[str]:
-
         # if self.split_with_space:
-
         if self.seg_dict is not None:
             tokens = line.strip().split(" ")
             tokens = seg_tokenize(tokens, self.seg_dict)
diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index 67f1e55..873f419 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -30,8 +30,8 @@
                 map_location="cpu",
             )
         avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
-        val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
-        sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
+        val_step_or_epoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_epoch"]
+        sorted_items = sorted(val_step_or_epoch.items(), key=lambda x: x[1], reverse=True)
         sorted_items = (
             sorted_items[:last_n] if avg_keep_nbest_models_type == "acc" else sorted_items[-last_n:]
         )
@@ -53,6 +53,7 @@
         checkpoint_files.sort(key=lambda x: int(re.search(r"(\d+)", x).group()), reverse=True)
         # Get the last 'last_n' checkpoint paths
         checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
+
     return checkpoint_paths
 
 
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 8ed613c..da2eed5 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -8,6 +8,7 @@
 import torch.nn
 import torch.optim
 import pdb
+import copy
 
 
 def load_pretrained_model(
@@ -35,11 +36,12 @@
     logging.info(f"ckpt: {path}")
 
     if oss_bucket is None:
-        src_state = torch.load(path, map_location=map_location)
+        ori_state = torch.load(path, map_location=map_location)
     else:
         buffer = BytesIO(oss_bucket.get_object(path).read())
-        src_state = torch.load(buffer, map_location=map_location)
+        ori_state = torch.load(buffer, map_location=map_location)
 
+    src_state = copy.deepcopy(ori_state)
     src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
     src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
     src_state = src_state["model"] if "model" in src_state else src_state
@@ -94,7 +96,6 @@
                 )
             else:
                 dst_state[k] = src_state[k_src]
-
         else:
             print(f"Warning, miss key in ckpt: {k}, {path}")
 
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 665a7af..5fe34b9 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -115,8 +115,8 @@
         self.saved_ckpts = {}
         self.step_or_epoch = -1
         self.best_step_or_epoch = ""
-        self.val_acc_step_or_eoch = {}
-        self.val_loss_step_or_eoch = {}
+        self.val_acc_step_or_epoch = {}
+        self.val_loss_step_or_epoch = {}
 
         self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
         self.start_data_split_i = 0
@@ -161,12 +161,14 @@
             # self.step_or_epoch += 1
             state = {
                 "epoch": epoch,
+                "total_step": self.batch_total,
                 "state_dict": model.state_dict(),
                 "optimizer": optim.state_dict(),
                 "scheduler": scheduler.state_dict(),
                 "saved_ckpts": self.saved_ckpts,
-                "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
-                "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+                "val_acc_step_or_epoch": self.val_acc_step_or_epoch,
+                "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
                 "best_step_or_epoch": self.best_step_or_epoch,
                 "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
                 "step": step,
@@ -183,6 +185,7 @@
 
             if scaler:
                 state["scaler_state"] = scaler.state_dict()
+
             # Create output directory if it does not exist
             os.makedirs(self.output_dir, exist_ok=True)
             if step is None:
@@ -191,47 +194,48 @@
                 ckpt_name = f"model.pt.ep{epoch}.{step}"
             filename = os.path.join(self.output_dir, ckpt_name)
             torch.save(state, filename)
+            logging.info(f"Checkpoint saved to {filename}")
 
-            logging.info(f"\nCheckpoint saved to {filename}\n")
-            latest = Path(os.path.join(self.output_dir, f"model.pt"))
+            latest = Path(os.path.join(self.output_dir, "model.pt"))
             torch.save(state, latest)
+
             if self.best_step_or_epoch == "":
                 self.best_step_or_epoch = ckpt_name
 
             if self.avg_keep_nbest_models_type == "acc":
                 if (
-                    self.val_acc_step_or_eoch[ckpt_name]
-                    >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
+                    self.val_acc_step_or_epoch[ckpt_name]
+                    >= self.val_acc_step_or_epoch[self.best_step_or_epoch]
                 ):
                     self.best_step_or_epoch = ckpt_name
                     best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
                     torch.save(state, best_ckpt)
                     logging.info(
-                        f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+                        f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
                     )
                 else:
                     logging.info(
-                        f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+                        f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
                     )
             elif self.avg_keep_nbest_models_type == "loss":
                 if (
-                    self.val_loss_step_or_eoch[ckpt_name]
-                    <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
+                    self.val_loss_step_or_epoch[ckpt_name]
+                    <= self.val_loss_step_or_epoch[self.best_step_or_epoch]
                 ):
                     self.best_step_or_epoch = ckpt_name
                     best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
                     torch.save(state, best_ckpt)
                     logging.info(
-                        f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+                        f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
                     )
                 else:
                     logging.info(
-                        f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+                        f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
                     )
             else:
                 print("Undo")
             self.saved_ckpts[ckpt_name] = getattr(
-                self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
+                self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
             )[ckpt_name]
             if self.keep_nbest_models > 0:
                 if len(self.saved_ckpts) > self.keep_nbest_models:
@@ -278,6 +282,7 @@
                         k_ddp = k.replace("module.", "", 1)
                     else:
                         k_ddp = k
+
                     if k_ddp in src_state.keys():
                         dst_state[k] = src_state[k_ddp]
                     else:
@@ -290,14 +295,14 @@
                     scaler.load_state_dict(checkpoint["scaler_state"])
 
                 self.saved_ckpts = checkpoint["saved_ckpts"]
-                self.val_acc_step_or_eoch = (
-                    checkpoint["val_acc_step_or_eoch"]
-                    if "val_acc_step_or_eoch" in checkpoint
+                self.val_acc_step_or_epoch = (
+                    checkpoint["val_acc_step_or_epoch"]
+                    if "val_acc_step_or_epoch" in checkpoint
                     else {}
                 )
-                self.val_loss_step_or_eoch = (
-                    checkpoint["val_loss_step_or_eoch"]
-                    if "val_loss_step_or_eoch" in checkpoint
+                self.val_loss_step_or_epoch = (
+                    checkpoint["val_loss_step_or_epoch"]
+                    if "val_loss_step_or_epoch" in checkpoint
                     else {}
                 )
                 self.best_step_or_epoch = (
@@ -327,6 +332,7 @@
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
+
 
     def train_epoch(
         self,
@@ -559,12 +565,14 @@
                 time1 = time.perf_counter()
                 speed_stats["data_load"] = f"{time1 - time5:0.3f}"
                 batch = to_device(batch, self.device)
+
                 time2 = time.perf_counter()
                 retval = model(**batch)
                 time3 = time.perf_counter()
                 speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 loss, stats, weight = retval
                 stats = {k: v for k, v in stats.items() if v is not None}
+
                 if self.use_ddp or self.use_fsdp:
                     # Apply weighted averaging for loss and stats
                     loss = (loss * weight.type(loss.dtype)).sum()
@@ -577,28 +585,33 @@
                     # Multiply world_size because DistributedDataParallel
                     # automatically normalizes the gradient by world_size.
                     loss *= self.world_size
+
                 # Scale the loss since we're not updating for every mini-batch
                 loss = loss
                 time4 = time.perf_counter()
 
-                self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
-                    batch_idx + 1
-                )
-                if "acc" in stats:
-                    self.val_acc_avg = (
-                        self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
-                    ) / (batch_idx + 1)
-                if self.use_ddp or self.use_fsdp:
-                    val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
-                        self.device
+                if torch.isfinite(loss):
+                    self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
+                        batch_idx + 1
                     )
-                    val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
-                        self.device
-                    )
-                    dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
-                    dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
-                    self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
-                    self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+
+                    if "acc" in stats:
+                        self.val_acc_avg = (
+                            self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
+                        ) / (batch_idx + 1)
+
+                    if self.use_ddp or self.use_fsdp:
+                        val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
+                            self.device
+                        )
+                        val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
+                            self.device
+                        )
+                        dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+                        dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+                        self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+                        self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+
                 time5 = time.perf_counter()
                 batch_num_epoch = 1
                 if hasattr(dataloader_val, "__len__"):
@@ -624,8 +637,8 @@
             ckpt_name = f"model.pt.ep{epoch}"
         else:
             ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
-        self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
-        self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+        self.val_acc_step_or_epoch[ckpt_name] = self.val_acc_avg
+        self.val_loss_step_or_epoch[ckpt_name] = self.val_loss_avg
         model.train()
 
         if self.use_ddp or self.use_fsdp:
diff --git a/funasr/utils/compute_det_ctc.py b/funasr/utils/compute_det_ctc.py
new file mode 100644
index 0000000..47c8182
--- /dev/null
+++ b/funasr/utils/compute_det_ctc.py
@@ -0,0 +1,286 @@
+""" This implementation is adapted from https://github.com/wenet-e2e/wekws/blob/main/wekws/bin/compute_det.py."""
+
+import os
+import json
+import logging
+import argparse
+import threading
+
+import kaldiio
+import torch
+from funasr.utils.kws_utils import split_mixed_label
+
+
+class thread_wrapper(threading.Thread):
+    def __init__(self, func, args=()):
+        super(thread_wrapper, self).__init__()
+        self.func = func
+        self.args = args
+        self.result = []
+
+    def run(self):
+        self.result = self.func(*self.args)
+
+    def get_result(self):
+        try:
+            return self.result
+        except Exception:
+            return None
+
+
+def space_mixed_label(input_str):
+    splits = split_mixed_label(input_str)
+    space_str = ''.join(f'{sub} ' for sub in splits)
+    return space_str.strip()
+
+
+def read_lists(list_file):
+    lists = []
+    with open(list_file, 'r', encoding='utf8') as fin:
+        for line in fin:
+            if line.strip() != '':
+                lists.append(line.strip())
+    return lists
+
+
+def make_pair(wav_lists, trans_lists):
+    logging.info('make pair for wav-trans list')
+
+    trans_table = {}
+    for line in trans_lists:
+        arr = line.strip().replace('\t', ' ').split()
+        if len(arr) < 2:
+            logging.debug('invalid line in trans file: {}'.format(
+                line.strip()))
+            continue
+
+        trans_table[arr[0]] = line.replace(arr[0],'').strip()
+
+    lists = []
+    for line in wav_lists:
+        arr = line.strip().replace('\t', ' ').split()
+        if len(arr) == 2 and arr[0] in trans_table:
+            lists.append(
+                dict(key=arr[0],
+                     txt=trans_table[arr[0]],
+                     wav=arr[1],
+                     sample_rate=16000))
+        else:
+            logging.debug("can't find corresponding trans for key: {}".format(
+                arr[0]))
+            continue
+
+    return lists
+
+
+def count_duration(tid, data_lists):
+    results = []
+
+    for obj in data_lists:
+        assert 'key' in obj
+        assert 'wav' in obj
+        assert 'txt' in obj
+        key = obj['key']
+        wav_file = obj['wav']
+        txt = obj['txt']
+
+        try:
+            rate, waveform = kaldiio.load_mat(wav_file)
+            waveform = torch.tensor(waveform, dtype=torch.float32)
+            waveform = waveform.unsqueeze(0)
+            frames = len(waveform[0])
+            duration = frames / float(rate)
+        except Exception:
+            logging.info(f'load file failed: {wav_file}')
+            duration = 0.0
+
+        obj['duration'] = duration
+        results.append(obj)
+
+    return results
+
+
+def load_data_and_score(keywords_list, data_file, trans_file, score_file):
+    # score_table: {uttid: [keywordlist]}
+    score_table = {}
+    with open(score_file, 'r', encoding='utf8') as fin:
+        # read score file and store in table
+        for line in fin:
+            arr = line.strip().split()
+            key = arr[0]
+            is_detected = arr[1]
+            if is_detected == 'detected':
+                if key not in score_table:
+                    score_table.update(
+                        {key: {
+                            'kw': space_mixed_label(arr[2]),
+                            'confi': float(arr[3])
+                        }})
+            else:
+                if key not in score_table:
+                    score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})
+
+    wav_lists = read_lists(data_file)
+    trans_lists = read_lists(trans_file)
+    data_lists = make_pair(wav_lists, trans_lists)
+    logging.info(f'origin list samples: {len(data_lists)}')
+
+    # count duration for each wave
+    num_workers = 8
+    start = 0
+    step = int(len(data_lists) / num_workers)
+    tasks = []
+    for idx in range(num_workers):
+        if idx != num_workers - 1:
+            task = thread_wrapper(count_duration,
+                                  (idx, data_lists[start:start + step]))
+        else:
+            task = thread_wrapper(count_duration, (idx, data_lists[start:]))
+        task.start()
+        tasks.append(task)
+        start += step
+
+    duration_lists = []
+    for task in tasks:
+        task.join()
+        duration_lists += task.get_result()
+    logging.info(f'after list samples: {len(duration_lists)}')
+
+    # build empty structure for keyword-filler infos
+    keyword_filler_table = {}
+    for keyword in keywords_list:
+        keyword = space_mixed_label(keyword)
+        keyword_filler_table[keyword] = {}
+        keyword_filler_table[keyword]['keyword_table'] = {}
+        keyword_filler_table[keyword]['keyword_duration'] = 0.0
+        keyword_filler_table[keyword]['filler_table'] = {}
+        keyword_filler_table[keyword]['filler_duration'] = 0.0
+
+    for obj in duration_lists:
+        assert 'key' in obj
+        assert 'wav' in obj
+        assert 'txt' in obj
+        assert 'duration' in obj
+
+        key = obj['key']
+        wav_file = obj['wav']
+        txt = obj['txt']
+        txt = space_mixed_label(txt)
+        txt_regstr_lrblk = ' ' + txt + ' '
+        duration = obj['duration']
+        assert key in score_table
+
+        for keyword in keywords_list:
+            keyword = space_mixed_label(keyword)
+            keyword_regstr_lrblk = ' ' + keyword + ' '
+            if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
+                if keyword == score_table[key]['kw']:
+                    keyword_filler_table[keyword]['keyword_table'].update(
+                        {key: score_table[key]['confi']})
+                else:
+                    # positive sample, but it was rejected or detected as another keyword -> miss
+                    keyword_filler_table[keyword]['keyword_table'].update(
+                        {key: -1.0})
+                keyword_filler_table[keyword]['keyword_duration'] += duration
+            else:
+                if keyword == score_table[key]['kw']:
+                    keyword_filler_table[keyword]['filler_table'].update(
+                        {key: score_table[key]['confi']})
+                else:
+                    # negative sample; any detection was for another keyword, so not a false alarm here
+                    keyword_filler_table[keyword]['filler_table'].update(
+                        {key: -1.0})
+                keyword_filler_table[keyword]['filler_duration'] += duration
+
+    return keyword_filler_table
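+
+# Example (sketch) of the structure returned above for a single keyword; keys and values
+# are hypothetical:
+#   {
+#       '你 好 小 云': {                                        # keyword after space_mixed_label()
+#           'keyword_table':  {'utt1': 0.93, 'utt2': -1.0},    # positive utterances: score, or -1.0 for a miss
+#           'keyword_duration': 1523.4,                        # total duration of positive utterances (seconds)
+#           'filler_table':   {'utt3': -1.0},                  # negative utterances: score (false alarm) or -1.0
+#           'filler_duration': 86400.0,                        # total duration of negative utterances (seconds)
+#       }
+#   }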
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='compute det curve')
+    parser.add_argument('--keywords',
+                        type=str,
+                        required=True,
+                        help='preset keyword str, input all keywords')
+    parser.add_argument('--test_data', required=True, help='test data file')
+    parser.add_argument('--trans_data',
+                        required=True,
+                        default='',
+                        help='transcription of test data')
+    parser.add_argument('--score_file', required=True, help='score file')
+    parser.add_argument('--step',
+                        type=float,
+                        default=0.001,
+                        help='threshold step')
+    parser.add_argument('--stats_dir',
+                        required=True,
+                        help='to save det stats files')
+    args = parser.parse_args()
+
+    root_logger = logging.getLogger()
+    handlers = root_logger.handlers[:]
+    for handler in handlers:
+        root_logger.removeHandler(handler)
+        handler.close()
+
+    logging.basicConfig(level=logging.DEBUG,
+                        format='%(asctime)s %(levelname)s %(message)s')
+
+    keywords_list = args.keywords.strip().split(',')
+    keyword_filler_table = load_data_and_score(keywords_list, args.test_data,
+                                               args.trans_data,
+                                               args.score_file)
+
+    stats_files = {}
+    for keyword in keywords_list:
+        keyword = space_mixed_label(keyword)
+        keyword_dur = keyword_filler_table[keyword]['keyword_duration']
+        keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
+        filler_dur = keyword_filler_table[keyword]['filler_duration']
+        filler_num = len(keyword_filler_table[keyword]['filler_table'])
+        if keyword_num <= 0:
+            logging.warning('Can not compute det for {} without positive samples'.format(keyword))
+            continue
+        if filler_num <= 0:
+            logging.warning('Can not compute det for {} without negative samples'.format(keyword))
+            continue
+
+        logging.info('Computing det for {}'.format(keyword))
+        logging.info('  Keyword duration: {} Hours, wave number: {}'.format(
+            keyword_dur / 3600.0, keyword_num))
+        logging.info('  Filler duration: {} Hours'.format(filler_dur / 3600.0))
+
+        stats_file = os.path.join(args.stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
+        with open(stats_file, 'w', encoding='utf8') as fout:
+            threshold = 0.0
+            while threshold <= 1.0:
+                num_false_reject = 0
+                num_true_detect = 0
+                # traverse all entries in keyword_table
+                for key, confi in keyword_filler_table[keyword][
+                        'keyword_table'].items():
+                    if confi < threshold:
+                        num_false_reject += 1
+                    else:
+                        num_true_detect += 1
+
+                num_false_alarm = 0
+                # traverse all entries in filler_table
+                for key, confi in keyword_filler_table[keyword][
+                        'filler_table'].items():
+                    if confi >= threshold:
+                        num_false_alarm += 1
+                        # print(f'false alarm: {keyword}, {key}, {confi}')
+
+                # false_reject_rate = num_false_reject / keyword_num
+                true_detect_rate = num_true_detect / keyword_num
+
+                num_false_alarm = max(num_false_alarm, 1e-6)
+                false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
+                false_alarm_rate = num_false_alarm / filler_num
+
+                fout.write('{:.3f} {:.6f} {:.6f} {:.6f}\n'.format(
+                    threshold, true_detect_rate, false_alarm_rate,
+                    false_alarm_per_hour))
+                threshold += args.step
+
+        stats_files[keyword] = stats_file
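+
+# Example (sketch) invocation; paths and keywords are hypothetical. --score_file is the
+# detection output written during inference (one line per utterance:
+# "<key> detected <kw> <score>" or "<key> rejected"). Each produced stats.<keyword>.txt
+# line is "<threshold> <recall> <false_alarm_rate> <false_alarms_per_hour>".
+#
+#   python funasr/utils/compute_det_ctc.py \
+#       --keywords "你好小云,小云小云" \
+#       --test_data data/test/wav.scp \
+#       --trans_data data/test/text \
+#       --score_file exp/kws/detect \
+#       --stats_dir exp/kws/det_stats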
diff --git a/funasr/utils/kws_utils.py b/funasr/utils/kws_utils.py
new file mode 100644
index 0000000..4935040
--- /dev/null
+++ b/funasr/utils/kws_utils.py
@@ -0,0 +1,284 @@
+import re
+import logging
+
+import torch
+import math
+from collections import defaultdict
+from typing import List, Optional, Tuple
+
+
+symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@，。?★、…【】《》？“”‘’！[\\]^_`{|}~\s]+'
+
+
+def split_mixed_label(input_str):
+    tokens = []
+    s = input_str.lower()
+    while len(s) > 0:
+        match = re.match(r'[A-Za-z!?,<>()\']+', s)
+        if match is not None:
+            word = match.group(0)
+        else:
+            word = s[0:1]
+        tokens.append(word)
+        s = s.replace(word, '', 1).strip(' ')
+    return tokens
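+
+# Example: latin-letter runs stay whole words while everything else is split per character,
+# e.g. split_mixed_label("play一首music") -> ['play', '一', '首', 'music']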
+
+
+def query_token_set(txt, symbol_table, lexicon_table):
+    tokens_str = tuple()
+    tokens_idx = tuple()
+
+    if txt in symbol_table:
+        tokens_str = tokens_str + (txt, )
+        tokens_idx = tokens_idx + (symbol_table[txt], )
+        return tokens_str, tokens_idx
+
+    parts = split_mixed_label(txt)
+    for part in parts:
+        if part == '!sil' or part == '(sil)' or part == '<sil>':
+            tokens_str = tokens_str + ('!sil', )
+        elif part == '<blank>':
+            tokens_str = tokens_str + ('<blank>', )
+        elif part == '(noise)' or part == 'noise)' or part == '(noise' or part == '<noise>':
+            tokens_str = tokens_str + ('<unk>', )
+        elif part in symbol_table:
+            tokens_str = tokens_str + (part, )
+        elif part in lexicon_table:
+            for ch in lexicon_table[part]:
+                tokens_str = tokens_str + (ch, )
+        else:
+            part = re.sub(symbol_str, '', part)
+            for ch in part:
+                tokens_str = tokens_str + (ch, )
+
+    for ch in tokens_str:
+        if ch in symbol_table:
+            tokens_idx = tokens_idx + (symbol_table[ch], )
+        elif ch == '!sil':
+            if 'sil' in symbol_table:
+                tokens_idx = tokens_idx + (symbol_table['sil'], )
+            else:
+                tokens_idx = tokens_idx + (symbol_table['<blank>'], )
+        elif ch == '<unk>':
+            if '<unk>' in symbol_table:
+                tokens_idx = tokens_idx + (symbol_table['<unk>'], )
+            else:
+                tokens_idx = tokens_idx + (symbol_table['<blank>'], )
+        else:
+            if '<unk>' in symbol_table:
+                tokens_idx = tokens_idx + (symbol_table['<unk>'], )
+                logging.info(f'\'{ch}\' is not in token set, replace with <unk>')
+            else:
+                tokens_idx = tokens_idx + (symbol_table['<blank>'], )
+                logging.info(f'\'{ch}\' is not in token set, replace with <blank>')
+
+    return tokens_str, tokens_idx
+
+
+class KwsCtcPrefixDecoder():
+    """Decoder interface wrapper for CTCPrefixDecode."""
+
+    def __init__(
+        self,
+        ctc: torch.nn.Module,
+        keywords: str,
+        token_list: list,
+        seg_dict: dict,
+    ):
+        """Initialize class.
+
+        Args:
+            ctc (torch.nn.Module): The CTC implementation.
+                For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
+
+        """
+        self.ctc = ctc
+        self.token_list = token_list
+
+        token_table = {token: idx for idx, token in enumerate(token_list)}
+
+        self.keywords_idxset = {0}
+        self.keywords_token = {}
+        self.keywords_str = keywords
+        keywords_list = self.keywords_str.strip().replace(' ', '').split(',')
+        for keyword in keywords_list:
+            strs, indexs = query_token_set(keyword, token_table, seg_dict)
+            self.keywords_token[keyword] = {}
+            self.keywords_token[keyword]['token_id'] = indexs
+            self.keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i) for i in indexs)
+            self.keywords_idxset.update(indexs)
+
+    def beam_search(
+        self,
+        logits: torch.Tensor,
+        logits_lengths: torch.Tensor,
+        keywords_tokenset: set = None,
+        score_beam_size: int = 3,
+        path_beam_size: int = 20,
+    ) -> List[Tuple]:
+        """CTC prefix beam search inner implementation.
+
+        Args:
+            logits (torch.Tensor): CTC posteriors, (max_len, vocab_size)
+            logits_lengths (torch.Tensor): (1, )
+            keywords_tokenset (set): token set used to filter candidate tokens
+            score_beam_size (int): beam size for per-frame token pruning
+            path_beam_size (int): beam size for prefix path pruning
+
+        Returns:
+            List of (prefix token ids, score, per-token nodes) tuples, best first.
+        """
+
+        maxlen = logits.size(0)
+        ctc_probs = logits
+        cur_hyps = [(tuple(), (1.0, 0.0, []))]
+
+        # CTC beam search step by step
+        for t in range(0, maxlen):
+            probs = ctc_probs[t]  # (vocab_size,)
+            # key: prefix, value (pb, pnb), default value(-inf, -inf)
+            next_hyps = defaultdict(lambda: (0.0, 0.0, []))
+
+            # 2.1 First beam prune: select topk best
+            top_k_probs, top_k_index = probs.topk(
+                score_beam_size)  # (score_beam_size,)
+
+            # filter prob score that is too small
+            filter_probs = []
+            filter_index = []
+            for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
+                if keywords_tokenset is not None:
+                    if prob > 0.05 and idx in keywords_tokenset:
+                        filter_probs.append(prob)
+                        filter_index.append(idx)
+                else:
+                    if prob > 0.05:
+                        filter_probs.append(prob)
+                        filter_index.append(idx)
+
+            if len(filter_index) == 0:
+                continue
+
+            for s in filter_index:
+                ps = probs[s].item()
+                # print(f'frame:{t}, token:{s}, score:{ps}')
+
+                for prefix, (pb, pnb, cur_nodes) in cur_hyps:
+                    last = prefix[-1] if len(prefix) > 0 else None
+                    if s == 0:  # blank
+                        n_pb, n_pnb, nodes = next_hyps[prefix]
+                        n_pb = n_pb + pb * ps + pnb * ps
+                        nodes = cur_nodes.copy()
+                        next_hyps[prefix] = (n_pb, n_pnb, nodes)
+                    elif s == last:
+                        if not math.isclose(pnb, 0.0, abs_tol=0.000001):
+                            # Update *ss -> *s;
+                            n_pb, n_pnb, nodes = next_hyps[prefix]
+                            n_pnb = n_pnb + pnb * ps
+                            nodes = cur_nodes.copy()
+                            if ps > nodes[-1]['prob']:  # update frame and prob
+                                nodes[-1]['prob'] = ps
+                                nodes[-1]['frame'] = t
+                            next_hyps[prefix] = (n_pb, n_pnb, nodes)
+
+                        if not math.isclose(pb, 0.0, abs_tol=0.000001):
+                            # Update *s-s -> *ss, - is for blank
+                            n_prefix = prefix + (s, )
+                            n_pb, n_pnb, nodes = next_hyps[n_prefix]
+                            n_pnb = n_pnb + pb * ps
+                            nodes = cur_nodes.copy()
+                            nodes.append(dict(token=s, frame=t,
+                                              prob=ps))  # to record token prob
+                            next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
+                    else:
+                        n_prefix = prefix + (s, )
+                        n_pb, n_pnb, nodes = next_hyps[n_prefix]
+                        if nodes:
+                            if ps > nodes[-1]['prob']:  # update frame and prob
+                                nodes[-1]['prob'] = ps
+                                nodes[-1]['frame'] = t
+                        else:
+                            nodes = cur_nodes.copy()
+                            nodes.append(dict(token=s, frame=t,
+                                              prob=ps))  # to record token prob
+                        n_pnb = n_pnb + pb * ps + pnb * ps
+                        next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
+
+            # 2.2 Second beam prune
+            next_hyps = sorted(next_hyps.items(),
+                               key=lambda x: (x[1][0] + x[1][1]),
+                               reverse=True)
+
+            cur_hyps = next_hyps[:path_beam_size]
+
+        hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
+        return hyps
+
+
+    def is_sublist(self, main_list, check_list):
+        if len(main_list) < len(check_list):
+            return -1
+
+        if len(main_list) == len(check_list):
+            return 0 if main_list == check_list else -1
+
+        for i in range(len(main_list) - len(check_list) + 1):
+            if main_list[i] == check_list[0]:
+                for j in range(len(check_list)):
+                    if main_list[i + j] != check_list[j]:
+                        break
+                else:
+                    return i
+        else:
+            return -1
+
+
+    def _decode_inside(
+        self,
+        logits: torch.Tensor,
+        logits_lengths: torch.Tensor,
+    ):
+        hyps = self.beam_search(logits, logits_lengths, self.keywords_idxset)
+
+        hit_keyword = None
+        hit_score = 1.0
+        # start = 0; end = 0
+        for one_hyp in hyps:
+            prefix_ids = one_hyp[0]
+            # path_score = one_hyp[1]
+            prefix_nodes = one_hyp[2]
+            assert len(prefix_ids) == len(prefix_nodes)
+            for word in self.keywords_token.keys():
+                lab = self.keywords_token[word]['token_id']
+                offset = self.is_sublist(prefix_ids, lab)
+                if offset != -1:
+                    hit_keyword = word
+                    for idx in range(offset, offset + len(lab)):
+                        hit_score *= prefix_nodes[idx]['prob']
+                    break
+            if hit_keyword is not None:
+                hit_score = math.sqrt(hit_score)
+                break
+
+        if hit_keyword is not None:
+            return True, hit_keyword, hit_score
+        else:
+            return False, None, None
+
+
+    def decode(self, x: torch.Tensor):
+        """Get an initial state for decoding.
+
+        Args:
+            x (torch.Tensor): The encoded feature tensor
+
+        Returns: decode result
+
+        """
+
+        raw_logp = self.ctc.softmax(x.unsqueeze(0)).detach().squeeze(0).cpu()
+        xlen = torch.tensor([raw_logp.size(1)])
+
+        return self._decode_inside(raw_logp, xlen)
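+
+
+# Example (sketch): standalone use of the decoder; the names below are placeholders.
+# "model" stands for any FunASR model exposing a CTC head, "tokenizer" for a CharTokenizer
+# with token_list/seg_dict, and "encoder_out" for a (time, dim) encoder output of one utterance.
+#
+#   decoder = KwsCtcPrefixDecoder(
+#       ctc=model.ctc,
+#       keywords="你好小云,小云小云",
+#       token_list=tokenizer.token_list,
+#       seg_dict=tokenizer.seg_dict,
+#   )
+#   is_detected, keyword, score = decoder.decode(encoder_out)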
diff --git a/funasr/utils/types.py b/funasr/utils/type_utils.py
similarity index 100%
rename from funasr/utils/types.py
rename to funasr/utils/type_utils.py

--
Gitblit v1.9.1