From 59a791121fccd3c9ca177c4f6d33105a82d23ef3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 10 Feb 2023 15:30:49 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/bin/lm_inference_launch.py                                                                             |  130 ++
 egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md |   25 
 funasr/bin/asr_inference_paraformer.py                                                                        |   61 +
 funasr/bin/asr_inference_uniasr_vad.py                                                                        |    2 
 funasr/bin/lm_calc_perplexity.py                                                                              |    3 
 funasr/datasets/preprocessor.py                                                                               |   73 +
 egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py                   |   87 +
 funasr/bin/lm_inference.py                                                                                    |  406 ++++++++
 egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py    |   52 +
 funasr/bin/asr_inference_paraformer_vad_punc.py                                                               |   38 
 egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md |   23 
 funasr/bin/lm_train.py                                                                                        |   50 
 funasr/models/e2e_asr_paraformer.py                                                                           |  496 +++++++++-
 egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py                |   37 
 funasr/models/decoder/contextual_decoder.py                                                                   |  776 ++++++++++++++++
 egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md   |   75 +
 funasr/bin/asr_inference_paraformer_vad.py                                                                    |   19 
 funasr/tasks/lm.py                                                                                            |    2 
 funasr/tasks/asr.py                                                                                           |    8 
 funasr/tasks/abs_task.py                                                                                      |   47 
 funasr/utils/postprocess_utils.py                                                                             |    6 
 funasr/lm/espnet_model.py                                                                                     |    4 
 funasr/bin/asr_inference_uniasr.py                                                                            |    2 
 funasr/bin/asr_inference.py                                                                                   |    2 
 funasr/utils/timestamp_tools.py                                                                               |   57 
 funasr/models/predictor/cif.py                                                                                |    5 
 egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md                  |   53 +
 funasr/bin/tokenize_text.py                                                                                   |  283 +++++
 funasr/bin/asr_inference_paraformer_timestamp.py                                                              |    2 
 egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py                               |   17 
 30 files changed, 2,737 insertions(+), 104 deletions(-)

diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md
new file mode 100644
index 0000000..c2e4354
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/README.md
@@ -0,0 +1,53 @@
+# ModelScope Model
+
+## How to finetune and infer using a pretrained Paraformer-large Model
+
+### Finetune
+
+- Modify finetune training related parameters in `finetune.py`
+    - <strong>output_dir:</strong> # result dir
+    - <strong>data_dir:</strong> # the dataset dir needs to include files: `train/wav.scp`, `train/text`; `validation/wav.scp`, `validation/text`
+    - <strong>dataset_type:</strong> # for dataset larger than 1000 hours, set as `large`, otherwise set as `small`
+    - <strong>batch_bins:</strong> # batch size. If dataset_type is `small`, `batch_bins` indicates the feature frames. If dataset_type is `large`, `batch_bins` indicates the duration in ms
+    - <strong>max_epoch:</strong> # number of training epoch
+    - <strong>lr:</strong> # learning rate
+
+- Then you can run the pipeline to finetune with:
+```python
+    python finetune.py
+```
+
+### Inference
+
+After finetuning, you can use the finetuned model for inference directly.
+
+- Setting parameters in `infer.py`
+    - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
+    - <strong>output_dir:</strong> # result dir
+    - <strong>ngpu:</strong> # the number of GPUs for decoding
+    - <strong>njob:</strong> # the number of jobs for each GPU
+
+- Then you can run the pipeline to infer with:
+```python
+    python infer.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/1best_recog/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set. 
+
+### Inference using local finetuned model
+
+- Modify inference related parameters in `infer_after_finetune.py`
+    - <strong>output_dir:</strong> # result dir
+    - <strong>data_dir:</strong> # the dataset dir needs to include `test/wav.scp`. If `test/text` also exists, CER will be computed
+    - <strong>decoding_model_name:</strong> # set the checkpoint name for decoding, e.g., `valid.cer_ctc.ave.pth`
+
+- Then you can run the pipeline to infer with:
+```python
+    python infer_after_finetune.py
+```
+
+- Results
+
+The decoding results can be found in `$output_dir/decoding_results/text.cer`, which includes recognition results of each sample and the CER metric of the whole test set. 
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py
new file mode 100644
index 0000000..a5f1ee4
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/finetune.py
@@ -0,0 +1,37 @@
+import os
+
+from modelscope.metainfo import Trainers
+from modelscope.trainers import build_trainer
+
+from funasr.datasets.ms_dataset import MsDataset
+from funasr.utils.modelscope_param import modelscope_args
+
+
+def modelscope_finetune(params):
+    if not os.path.exists(params.output_dir):
+        os.makedirs(params.output_dir, exist_ok=True)
+    # dataset split ["train", "validation"]
+    ds_dict = MsDataset.load(params.data_path)
+    kwargs = dict(
+        model=params.model,
+        data_dir=ds_dict,
+        dataset_type=params.dataset_type,
+        work_dir=params.output_dir,
+        batch_bins=params.batch_bins,
+        max_epoch=params.max_epoch,
+        lr=params.lr)
+    trainer = build_trainer(Trainers.speech_asr_trainer, default_args=kwargs)
+    trainer.train()
+
+
+if __name__ == '__main__':
+    params = modelscope_args(model="damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k",
+                             data_path="./data")
+    params.output_dir = "./checkpoint"
+    params.data_path = "./example_data/"
+    params.dataset_type = "small"
+    params.batch_bins = 16000
+    params.max_epoch = 50
+    params.lr = 0.00002
+
+    modelscope_finetune(params)
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
new file mode 100644
index 0000000..c016c19
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer.py
@@ -0,0 +1,87 @@
+import os
+import shutil
+from multiprocessing import Pool
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+
+def modelscope_infer_core(output_dir, split_dir, njob, idx):
+    output_dir_job = os.path.join(output_dir, "output.{}".format(idx))
+    gpu_id = (int(idx) - 1) // njob
+    if "CUDA_VISIBLE_DEVICES" in os.environ.keys():
+        gpu_list = os.environ['CUDA_VISIBLE_DEVICES'].split(",")
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_list[gpu_id])
+    else:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
+    inference_pipline = pipeline(
+        task=Tasks.auto_speech_recognition,
+        model="damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k",
+        output_dir=output_dir_job,
+    )
+    audio_in = os.path.join(split_dir, "wav.{}.scp".format(idx))
+    inference_pipline(audio_in=audio_in)
+
+
+def modelscope_infer(params):
+    # prepare for multi-GPU decoding
+    ngpu = params["ngpu"]
+    njob = params["njob"]
+    output_dir = params["output_dir"]
+    if os.path.exists(output_dir):
+        shutil.rmtree(output_dir)
+    os.mkdir(output_dir)
+    split_dir = os.path.join(output_dir, "split")
+    os.mkdir(split_dir)
+    nj = ngpu * njob
+    wav_scp_file = os.path.join(params["data_dir"], "wav.scp")
+    with open(wav_scp_file) as f:
+        lines = f.readlines()
+        num_lines = len(lines)
+        num_job_lines = num_lines // nj
+    start = 0
+    for i in range(nj):
+        end = start + num_job_lines
+        file = os.path.join(split_dir, "wav.{}.scp".format(str(i + 1)))
+        with open(file, "w") as f:
+            if i == nj - 1:
+                f.writelines(lines[start:])
+            else:
+                f.writelines(lines[start:end])
+        start = end
+
+    p = Pool(nj)
+    for i in range(nj):
+        p.apply_async(modelscope_infer_core,
+                      args=(output_dir, split_dir, njob, str(i + 1)))
+    p.close()
+    p.join()
+
+    # combine decoding results
+    best_recog_path = os.path.join(output_dir, "1best_recog")
+    os.mkdir(best_recog_path)
+    files = ["text", "token", "score"]
+    for file in files:
+        with open(os.path.join(best_recog_path, file), "w") as f:
+            for i in range(nj):
+                job_file = os.path.join(output_dir, "output.{}/1best_recog".format(str(i + 1)), file)
+                with open(job_file) as f_job:
+                    lines = f_job.readlines()
+                f.writelines(lines)
+
+    # If text exists, compute CER
+    text_in = os.path.join(params["data_dir"], "text")
+    if os.path.exists(text_in):
+        text_proc_file = os.path.join(best_recog_path, "token")
+        compute_wer(text_in, text_proc_file, os.path.join(best_recog_path, "text.cer"))
+
+
+if __name__ == "__main__":
+    params = {}
+    params["data_dir"] = "./data/test"
+    params["output_dir"] = "./results"
+    params["ngpu"] = 2
+    params["njob"] = 5
+    modelscope_infer(params)
diff --git a/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py
new file mode 100644
index 0000000..56c282c
--- /dev/null
+++ b/egs_modelscope/asr/data2vec/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k/infer_after_finetune.py
@@ -0,0 +1,52 @@
+import json
+import os
+import shutil
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+from funasr.utils.compute_wer import compute_wer
+
+
+def modelscope_infer_after_finetune(params):
+    # prepare for decoding
+    pretrained_model_path = os.path.join(os.environ["HOME"], ".cache/modelscope/hub", params["modelscope_model_name"])
+    for file_name in params["required_files"]:
+        if file_name == "configuration.json":
+            with open(os.path.join(pretrained_model_path, file_name)) as f:
+                config_dict = json.load(f)
+                config_dict["model"]["am_model_name"] = params["decoding_model_name"]
+            with open(os.path.join(params["output_dir"], "configuration.json"), "w") as f:
+                json.dump(config_dict, f, indent=4, separators=(',', ': '))
+        else:
+            shutil.copy(os.path.join(pretrained_model_path, file_name),
+                        os.path.join(params["output_dir"], file_name))
+    decoding_path = os.path.join(params["output_dir"], "decode_results")
+    if os.path.exists(decoding_path):
+        shutil.rmtree(decoding_path)
+    os.mkdir(decoding_path)
+
+    # decoding
+    inference_pipeline = pipeline(
+        task=Tasks.auto_speech_recognition,
+        model=params["output_dir"],
+        output_dir=decoding_path,
+    )
+    audio_in = os.path.join(params["data_dir"], "wav.scp")
+    inference_pipeline(audio_in=audio_in)
+
+    # compute CER if GT text is set
+    text_in = os.path.join(params["data_dir"], "text")
+    if os.path.exists(text_in):
+        text_proc_file = os.path.join(decoding_path, "1best_recog/token")
+        compute_wer(text_in, text_proc_file, os.path.join(decoding_path, "text.cer"))
+
+
+if __name__ == '__main__':
+    params = {}
+    params["modelscope_model_name"] = "damo/speech_data2vec_pretrain-paraformer-zh-cn-aishell2-16k"
+    params["required_files"] = ["am.mvn", "decoding.yaml", "configuration.json"]
+    params["output_dir"] = "./checkpoint"
+    params["data_dir"] = "./data/test"
+    params["decoding_model_name"] = "valid.cer_ctc.ave.pth"
+    modelscope_infer_after_finetune(params)
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md
new file mode 100644
index 0000000..5eeae37
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/RESULTS.md
@@ -0,0 +1,23 @@
+# Paraformer-Large
+- Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-aishell1-vocab8404-pytorch/summary>
+- Model size: 220M
+
+# Environments
+- date: `Fri Feb 10 13:34:24 CST 2023`
+- python version: `3.7.12`
+- FunASR version: `0.1.6`
+- pytorch version: `pytorch 1.7.0`
+- Git hash: ``
+- Commit date: ``
+
+# Benchmark Results
+
+## AISHELL-1
+- Decode config:
+  - Decode without CTC
+  - Decode without LM
+
+| testset CER(%) | base model|finetune model |
+|:--------------:|:---------:|:-------------:|
+| dev            | 1.75      |1.62           |
+| test           | 1.95      |1.78           |
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md
new file mode 100644
index 0000000..71d9fee
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/RESULTS.md
@@ -0,0 +1,25 @@
+# Paraformer-Large
+- Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-aishell2-vocab8404-pytorch/summary>
+- Model size: 220M
+
+# Environments
+- date: `Fri Feb 10 13:34:24 CST 2023`
+- python version: `3.7.12`
+- FunASR version: `0.1.6`
+- pytorch version: `pytorch 1.7.0`
+- Git hash: ``
+- Commit date: ``
+
+# Benchmark Results
+
+## AISHELL-2
+- Decode config: 
+  - Decode without CTC
+  - Decode without LM
+
+| testset      | base model|finetune model|
+|:------------:|:---------:|:------------:|
+| dev_ios      | 2.80      |2.60          |
+| test_android | 3.13      |2.84          |
+| test_ios     | 2.85      |2.82          |
+| test_mic     | 3.06      |2.88          |
diff --git a/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
new file mode 100644
index 0000000..ec95be3
--- /dev/null
+++ b/egs_modelscope/asr/paraformer/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/RESULTS.md
@@ -0,0 +1,75 @@
+# Paraformer-Large
+- Model link: <https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary>
+- Model size: 220M
+
+# Environments
+- date: `Tue Nov 22 18:48:39 CST 2022`
+- python version: `3.7.12`
+- FunASR version: `0.1.0`
+- pytorch version: `pytorch 1.7.0`
+- Git hash: ``
+- Commit date: ``
+
+# Benchmark Results
+
+## AISHELL-1
+- Decode config: 
+  - Decode without CTC
+  - Decode without LM
+
+| testset   | CER(%)|
+|:---------:|:-----:|
+| dev       | 1.75  |
+| test      | 1.95  |
+
+## AISHELL-2
+- Decode config: 
+  - Decode without CTC
+  - Decode without LM
+
+| testset      | CER(%)|
+|:------------:|:-----:|
+| dev_ios      | 2.80  |
+| test_android | 3.13  |
+| test_ios     | 2.85  |
+| test_mic     | 3.06  |
+
+## Wenetspeech
+- Decode config: 
+  - Decode without CTC
+  - Decode without LM
+
+| testset   | CER(%)|
+|:---------:|:-----:|
+| dev       | 3.57  |
+| test      | 6.97  |
+| test_net  | 6.74  |
+
+## SpeechIO TIOBE
+- Decode config 1:
+  - Decode without CTC
+  - Decode without LM
+  - With text norm
+- Decode config 2:
+  - Decode without CTC
+  - Decode with Transformer-LM
+  - LM weight: 0.15
+  - With text norm
+
+| testset | w/o LM | w/ LM |
+|:------------------:|:----:|:----:|
+|SPEECHIO_ASR_ZH00001| 0.49 | 0.35 |
+|SPEECHIO_ASR_ZH00002| 3.23 | 2.86 |
+|SPEECHIO_ASR_ZH00003| 1.13 | 0.80 |
+|SPEECHIO_ASR_ZH00004| 1.33 | 1.10 |
+|SPEECHIO_ASR_ZH00005| 1.41 | 1.18 |
+|SPEECHIO_ASR_ZH00006| 5.25 | 4.85 |
+|SPEECHIO_ASR_ZH00007| 5.51 | 4.97 |
+|SPEECHIO_ASR_ZH00008| 3.69 | 3.18 |
+|SPEECHIO_ASR_ZH00009| 3.02 | 2.78 |
+|SPEECHIO_ASR_ZH000010| 3.35 | 2.99 |
+|SPEECHIO_ASR_ZH000011| 1.54 | 1.25 |
+|SPEECHIO_ASR_ZH000012| 2.06 | 1.68 |
+|SPEECHIO_ASR_ZH000013| 2.57 | 2.25 |
+|SPEECHIO_ASR_ZH000014| 3.86 | 3.08 |
+|SPEECHIO_ASR_ZH000015| 3.34 | 2.67 |
diff --git a/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py b/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py
new file mode 100644
index 0000000..ed3b7e2
--- /dev/null
+++ b/egs_modelscope/lm/speech_transformer_lm_zh-cn-common-vocab8404-pytorch/infer.py
@@ -0,0 +1,17 @@
+
+
+################## text (binary) input data #####################
+inputs = "hello 大 家 好 呀"
+
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipline = pipeline(
+    task=Tasks.language_model,
+    model='damo/speech_transformer_lm_zh-cn-common-vocab8404-pytorch',
+    output_dir="./tmp/"
+)
+
+rec_result = inference_pipline(text_in=inputs)
+print(rec_result)
+
diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index 16fa3e5..ca8f2bc 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -453,7 +453,7 @@
                     ibest_writer["score"][key] = str(hyp.score)
                 
                 if text is not None:
-                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                     item = {'key': key, 'value': text_postprocessed}
                     asr_result_list.append(item)
                     finish_count += 1
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 3769b6c..6c5acfc 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -3,6 +3,9 @@
 import logging
 import sys
 import time
+import copy
+import os
+import codecs
 from pathlib import Path
 from typing import Optional
 from typing import Sequence
@@ -35,6 +38,8 @@
 from funasr.utils.types import str_or_none
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
+
 
 header_colors = '\033[95m'
 end_colors = '\033[0m'
@@ -78,6 +83,7 @@
             penalty: float = 0.0,
             nbest: int = 1,
             frontend_conf: dict = None,
+            hotword_list_or_file: str = None,
             **kwargs,
     ):
         assert check_argument_types()
@@ -168,6 +174,34 @@
         self.asr_train_args = asr_train_args
         self.converter = converter
         self.tokenizer = tokenizer
+
+        # 6. [Optional] Build hotword list from file or str
+        if hotword_list_or_file is None:
+            self.hotword_list = None
+        elif os.path.exists(hotword_list_or_file):
+            self.hotword_list = []
+            hotword_str_list = []
+            with codecs.open(hotword_list_or_file, 'r') as fin:
+                for line in fin.readlines():
+                    hw = line.strip()
+                    hotword_str_list.append(hw)
+                    self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+                self.hotword_list.append([1])
+                hotword_str_list.append('<s>')
+            logging.info("Initialized hotword list from file: {}, hotword list: {}."
+                .format(hotword_list_or_file, hotword_str_list))
+        else:
+            logging.info("Attempting to parse hotwords as str...")
+            self.hotword_list = []
+            hotword_str_list = []
+            for hw in hotword_list_or_file.strip().split():
+                hotword_str_list.append(hw)
+                self.hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+            self.hotword_list.append([1])
+            hotword_str_list.append('<s>')
+            logging.info("Hotword list: {}.".format(hotword_str_list))
+
+
         is_use_lm = lm_weight != 0.0 and lm_file is not None
         if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
             beam_search = None
@@ -229,8 +263,14 @@
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-        decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
-        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+        if not isinstance(self.asr_model, ContextualParaformer):
+            if self.hotword_list:
+                logging.warning("Hotword is given but asr model is not a ContextualParaformer.")
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
+        else:
+            decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length, hw_list=self.hotword_list)
+            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
         results = []
         b, n, d = decoder_out.size()
@@ -388,6 +428,11 @@
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
     )
 
+    if param_dict is not None:
+        hotword_list_or_file = param_dict.get('hotword')
+    else:
+        hotword_list_or_file = None
+
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
     else:
@@ -416,6 +461,7 @@
         ngram_weight=ngram_weight,
         penalty=penalty,
         nbest=nbest,
+        hotword_list_or_file=hotword_list_or_file,
     )
     speech2text = Speech2Text(**speech2text_kwargs)
 
@@ -497,7 +543,7 @@
                         ibest_writer["rtf"][key] = rtf_cur
 
                     if text is not None:
-                        text_postprocessed = postprocess_utils.sentence_postprocess(token)
+                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                         item = {'key': key, 'value': text_postprocessed}
                         asr_result_list.append(item)
                         finish_count += 1
@@ -551,7 +597,12 @@
         default=1,
         help="The number of workers used for DataLoader",
     )
-
+    parser.add_argument(
+        "--hotword",
+        type=str_or_none,
+        default=None,
+        help="hotword file path or hotwords seperated by space"
+    )
     group = parser.add_argument_group("Input data related")
     group.add_argument(
         "--data_path_and_name_and_type",
@@ -679,8 +730,10 @@
     print(get_commandline_args(), file=sys.stderr)
     parser = get_parser()
     args = parser.parse_args(cmd)
+    param_dict = {'hotword': args.hotword}
     kwargs = vars(args)
     kwargs.pop("config", None)
+    kwargs['param_dict'] = param_dict
     inference(**kwargs)
 
 
diff --git a/funasr/bin/asr_inference_paraformer_timestamp.py b/funasr/bin/asr_inference_paraformer_timestamp.py
index 7e2e414..7da48e2 100644
--- a/funasr/bin/asr_inference_paraformer_timestamp.py
+++ b/funasr/bin/asr_inference_paraformer_timestamp.py
@@ -436,7 +436,7 @@
                     ibest_writer["score"][key] = str(hyp.score)
     
                 if text is not None:
-                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                     item = {'key': key, 'value': text_postprocessed}
                     asr_result_list.append(item)
                     finish_count += 1
diff --git a/funasr/bin/asr_inference_paraformer_vad.py b/funasr/bin/asr_inference_paraformer_vad.py
index 2832504..dbb2719 100644
--- a/funasr/bin/asr_inference_paraformer_vad.py
+++ b/funasr/bin/asr_inference_paraformer_vad.py
@@ -241,6 +241,11 @@
             allow_variable_data_keys=allow_variable_data_keys,
             inference=True,
         )
+
+        if param_dict is not None:
+            use_timestamp = param_dict.get('use_timestamp', True)
+        else:
+            use_timestamp = True
         
         finish_count = 0
         file_count = 1
@@ -284,8 +289,10 @@
                 text, token, token_int = result[0], result[1], result[2]
                 time_stamp = None if len(result) < 4 else result[3]
                
-                
-                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+                if use_timestamp and time_stamp is not None:
+                    postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+                else:
+                    postprocessed_result = postprocess_utils.sentence_postprocess(token)
                 text_postprocessed = ""
                 time_stamp_postprocessed = ""
                 text_postprocessed_punc = postprocessed_result
@@ -293,9 +300,11 @@
                     text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
                                                                                postprocessed_result[1], \
                                                                                postprocessed_result[2]
-                    text_postprocessed_punc = text_postprocessed
-                    if len(word_lists) > 0 and text2punc is not None:
-                        text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+                else:
+                    text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+                text_postprocessed_punc = text_postprocessed
+                if len(word_lists) > 0 and text2punc is not None:
+                    text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
 
                 
                 item = {'key': key, 'value': text_postprocessed_punc}
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 1d09c79..c4bb61b 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -14,6 +14,7 @@
 from typing import Any
 from typing import List
 import math
+import copy
 import numpy as np
 import torch
 from typeguard import check_argument_types
@@ -38,8 +39,9 @@
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.tasks.vad import VADTask
-from funasr.utils.timestamp_tools import time_stamp_lfr6
+from funasr.utils.timestamp_tools import time_stamp_lfr6, time_stamp_lfr6_pl
 from funasr.bin.punctuation_infer import Text2Punc
+from funasr.models.e2e_asr_paraformer import BiCifParaformer
 
 header_colors = '\033[95m'
 end_colors = '\033[0m'
@@ -234,6 +236,10 @@
         decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
         decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
+        if isinstance(self.asr_model, BiCifParaformer):
+            _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
+                                                                                   pre_token_length)  # test no bias cif2
+
         results = []
         b, n, d = decoder_out.size()
         for i in range(b):
@@ -276,9 +282,12 @@
                 else:
                     text = None
 
-                time_stamp = time_stamp_lfr6(alphas[i:i+1,], enc_len[i:i+1,], token, begin_time, end_time)
-    
-                results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
+                if isinstance(self.asr_model, BiCifParaformer):
+                    timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+                    results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
+                else:
+                    time_stamp = time_stamp_lfr6(alphas[i:i + 1, ], enc_len[i:i + 1, ], copy.copy(token), begin_time, end_time)
+                    results.append((text, token, token_int, time_stamp, enc_len_batch_total, lfr_factor))
 
         # assert check_return_type(results)
         return results
@@ -561,6 +570,11 @@
             allow_variable_data_keys=allow_variable_data_keys,
             inference=True,
         )
+
+        if param_dict is not None:
+            use_timestamp = param_dict.get('use_timestamp', True)
+        else:
+            use_timestamp = True
     
         finish_count = 0
         file_count = 1
@@ -603,8 +617,11 @@
                 result = result_segments[0]
                 text, token, token_int = result[0], result[1], result[2]
                 time_stamp = None if len(result) < 4 else result[3]
-    
-                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+   
+                if use_timestamp and time_stamp is not None: 
+                    postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+                else:
+                    postprocessed_result = postprocess_utils.sentence_postprocess(token)
                 text_postprocessed = ""
                 time_stamp_postprocessed = ""
                 text_postprocessed_punc = postprocessed_result
@@ -612,9 +629,12 @@
                     text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
                                                                                postprocessed_result[1], \
                                                                                postprocessed_result[2]
-                    text_postprocessed_punc = text_postprocessed
-                    if len(word_lists) > 0 and text2punc is not None:
-                        text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+                else:
+                    text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+
+                text_postprocessed_punc = text_postprocessed
+                if len(word_lists) > 0 and text2punc is not None:
+                    text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
     
                 item = {'key': key, 'value': text_postprocessed_punc}
                 if text_postprocessed != "":
diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index cfec9a0..0a5824c 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -492,7 +492,7 @@
                     ibest_writer["score"][key] = str(hyp.score)
     
                 if text is not None:
-                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                     item = {'key': key, 'value': text_postprocessed}
                     asr_result_list.append(item)
                     finish_count += 1
diff --git a/funasr/bin/asr_inference_uniasr_vad.py b/funasr/bin/asr_inference_uniasr_vad.py
index cfec9a0..0a5824c 100644
--- a/funasr/bin/asr_inference_uniasr_vad.py
+++ b/funasr/bin/asr_inference_uniasr_vad.py
@@ -492,7 +492,7 @@
                     ibest_writer["score"][key] = str(hyp.score)
     
                 if text is not None:
-                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
+                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                     item = {'key': key, 'value': text_postprocessed}
                     asr_result_list.append(item)
                     finish_count += 1
diff --git a/funasr/bin/lm_calc_perplexity.py b/funasr/bin/lm_calc_perplexity.py
index 27a8a71..198d578 100755
--- a/funasr/bin/lm_calc_perplexity.py
+++ b/funasr/bin/lm_calc_perplexity.py
@@ -56,7 +56,7 @@
     set_all_random_seed(seed)
 
     # 2. Build LM
-    model, train_args = LMTask.build_model_from_file(train_config, model_file, device)
+    model, train_args = LMTask.build_model_from_file(config_file=train_config, model_file=model_file, device=device)
     # Wrape model to make model.nll() data-parallel
     wrapped_model = ForwardAdaptor(model, "nll")
     wrapped_model.to(dtype=getattr(torch, dtype)).eval()
@@ -111,6 +111,7 @@
                     utt_ppl = log_base ** (_nll / ntoken / np.log(log_base))
 
                 # Write PPL of each utts for debugging or analysis
+                writer["utt2nll"][key] = str(-_nll)
                 writer["utt2ppl"][key] = str(utt_ppl)
                 writer["utt2ntokens"][key] = str(ntoken)
 
diff --git a/funasr/bin/lm_inference.py b/funasr/bin/lm_inference.py
new file mode 100644
index 0000000..909cb02
--- /dev/null
+++ b/funasr/bin/lm_inference.py
@@ -0,0 +1,406 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+from pathlib import Path
+import sys
+import os
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+from typing import Any
+from typing import List
+
+import numpy as np
+import torch
+from torch.nn.parallel import data_parallel
+from typeguard import check_argument_types
+
+from funasr.tasks.lm import LMTask
+from funasr.datasets.preprocessor import LMPreprocessor
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.forward_adaptor import ForwardAdaptor
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import float_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+
+def inference(
+    output_dir: str,
+    batch_size: int,
+    dtype: str,
+    ngpu: int,
+    seed: int,
+    num_workers: int,
+    log_level: Union[int, str],
+    train_config: Optional[str],
+    model_file: Optional[str],
+    log_base: Optional[float],
+    key_file: Optional[str] = None,
+    allow_variable_data_keys: bool = False,
+    split_with_space: Optional[bool] = False,
+    seg_dict_file: Optional[str] = None,
+    data_path_and_name_and_type: Sequence[Tuple[str, str, str]] = None,
+    raw_inputs: Union[List[Any], bytes, str] = None,
+    **kwargs,
+):
+    inference_pipeline = inference_modelscope(
+        output_dir=output_dir,
+        raw_inputs=raw_inputs,
+        batch_size=batch_size,
+        dtype=dtype,
+        ngpu=ngpu,
+        seed=seed,
+        num_workers=num_workers,
+        log_level=log_level,
+        key_file=key_file,
+        train_config=train_config,
+        model_file=model_file,
+        log_base=log_base,
+        allow_variable_data_keys=allow_variable_data_keys,
+        split_with_space=split_with_space,
+        seg_dict_file=seg_dict_file,
+        **kwargs,
+    )
+    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+
+def inference_modelscope(
+    batch_size: int,
+    dtype: str,
+    ngpu: int,
+    seed: int,
+    num_workers: int,
+    log_level: Union[int, str],
+    key_file: Optional[str],
+    train_config: Optional[str],
+    model_file: Optional[str],
+    log_base: Optional[float] = 10,
+    allow_variable_data_keys: bool = False,
+    split_with_space: Optional[bool] = False,
+    seg_dict_file: Optional[str] = None,
+    output_dir: Optional[str] = None,
+    param_dict: dict = None,
+    **kwargs,
+):
+    assert check_argument_types()
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build Model
+    model, train_args = LMTask.build_model_from_file(
+        train_config, model_file, device)
+    wrapped_model = ForwardAdaptor(model, "nll")
+    wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
+    logging.info(f"Model:\n{model}")
+
+    preprocessor = LMPreprocessor(
+        train=False,
+        token_type=train_args.token_type,
+        token_list=train_args.token_list,
+        bpemodel=train_args.bpemodel,
+        text_cleaner=train_args.cleaner,
+        g2p_type=train_args.g2p,
+        text_name="text",
+        non_linguistic_symbols=train_args.non_linguistic_symbols,
+        split_with_space=split_with_space,
+        seg_dict_file=seg_dict_file
+    )
+
+    def _forward(
+        data_path_and_name_and_type,
+        raw_inputs: Union[List[Any], bytes, str] = None,
+        output_dir_v2: Optional[str] = None,
+        param_dict: dict = None,
+    ):
+        results = []
+        if output_dir_v2 is not None:
+            writer = DatadirWriter(output_dir_v2)
+        else:
+            writer = None
+
+        if raw_inputs is not None:
+            line = raw_inputs.strip()
+            key = "lm demo"
+            if line == "":
+                item = {'key': key, 'value': ""}
+                results.append(item)
+                return results
+            batch = {}
+            batch['text'] = line
+            if preprocessor is not None:
+                batch = preprocessor(key, batch)
+            
+            #  Force data-precision
+            for name in batch:
+                value = batch[name]
+                if not isinstance(value, np.ndarray):
+                    raise RuntimeError(
+                        f"All values must be converted to np.ndarray object "
+                        f'by preprocessing, but "{name}" is still {type(value)}.'
+                    )
+                # Cast to desired type
+                if value.dtype.kind == "f":
+                    value = value.astype("float32")
+                elif value.dtype.kind == "i":
+                    value = value.astype("int64")
+                else:
+                    raise NotImplementedError(f"Not supported dtype: {value.dtype}")
+                batch[name] = value
+            
+            batch["text_lengths"] = torch.from_numpy(
+                np.array([len(batch["text"])], dtype='int32'))
+            batch["text"] = np.expand_dims(batch["text"], axis=0)
+
+            with torch.no_grad():
+                batch = to_device(batch, device)
+                if ngpu <= 1:
+                    nll, lengths = wrapped_model(**batch)
+                else:
+                    nll, lengths = data_parallel(
+                        wrapped_model, (), range(ngpu), module_kwargs=batch
+                    )
+                ## compute ppl
+                ppl_out_batch = ""
+                ids2tokens = preprocessor.token_id_converter.ids2tokens
+                for sent_ids, sent_nll in zip(batch['text'], nll):
+                    pre_word = "<s>"
+                    cur_word = None
+                    sent_lst = ids2tokens(sent_ids) + ['</s>']
+                    ppl_out = " ".join(sent_lst) + "\n"
+                    for word, word_nll in zip(sent_lst, sent_nll):
+                        cur_word = word
+                        word_nll = -word_nll.cpu()
+                        if log_base is None:
+                            word_prob = np.exp(word_nll)
+                        else:
+                            word_prob = log_base ** (word_nll / np.log(log_base))
+                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+                            cur=cur_word, 
+                            pre=pre_word, 
+                            prob=round(word_prob.item(), 8),
+                            word_nll=round(word_nll.item(), 8)
+                            )
+                        pre_word = cur_word
+                    
+                    sent_nll_mean = sent_nll.mean().cpu().numpy()
+                    sent_nll_sum = sent_nll.sum().cpu().numpy()
+                    if log_base is None:
+                        sent_ppl = np.exp(sent_nll_mean)
+                    else:
+                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+                        sent_nll=round(-sent_nll_sum.item(), 4),
+                        sent_ppl=round(sent_ppl.item(), 4)
+                        )
+                    ppl_out_batch += ppl_out
+                    item = {'key': key, 'value': ppl_out}
+                    if writer is not None:
+                        writer["ppl"][key+":\n"] = ppl_out
+                    results.append(item)
+
+            return results
+                
+        # 3. Build data-iterator
+        loader = LMTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=preprocessor,
+            collate_fn=LMTask.build_collate_fn(train_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+
+        # 4. Start for-loop
+        total_nll = 0.0
+        total_ntokens = 0
+        ppl_out_all = ""
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+
+            ppl_out_batch = ""
+            with torch.no_grad():
+                batch = to_device(batch, device)
+                if ngpu <= 1:
+                    # NOTE(kamo): data_parallel also should work with ngpu=1,
+                    # but for debuggability it's better to keep this block.
+                    nll, lengths = wrapped_model(**batch)
+                else:
+                    nll, lengths = data_parallel(
+                        wrapped_model, (), range(ngpu), module_kwargs=batch
+                    )
+                ## print ppl
+                ids2tokens = preprocessor.token_id_converter.ids2tokens
+                for key, sent_ids, sent_nll in zip(keys, batch['text'], nll):
+                    pre_word = "<s>"
+                    cur_word = None
+                    sent_lst = ids2tokens(sent_ids) + ['</s>']
+                    ppl_out = " ".join(sent_lst) + "\n"
+                    for word, word_nll in zip(sent_lst, sent_nll):
+                        cur_word = word
+                        word_nll = -word_nll.cpu()
+                        if log_base is None:
+                            word_prob = np.exp(word_nll)
+                        else:
+                            word_prob = log_base ** (word_nll / np.log(log_base))
+                        ppl_out += '    p( {cur} | {pre} ) = {prob} [ {word_nll} ]\n'.format(
+                            cur=cur_word, 
+                            pre=pre_word, 
+                            prob=round(word_prob.item(), 8),
+                            word_nll=round(word_nll.item(), 8)
+                            )
+                        pre_word = cur_word
+                    
+                    sent_nll_mean = sent_nll.mean().cpu().numpy()
+                    sent_nll_sum = sent_nll.sum().cpu().numpy()
+                    if log_base is None:
+                        sent_ppl = np.exp(sent_nll_mean)
+                    else:
+                        sent_ppl = log_base ** (sent_nll_mean / np.log(log_base))
+                    ppl_out += 'logprob= {sent_nll} ppl= {sent_ppl}\n\n'.format(
+                        sent_nll=round(-sent_nll_sum.item(), 4),
+                        sent_ppl=round(sent_ppl.item(), 4)
+                        )
+                    ppl_out_batch += ppl_out
+                    utt2nll = round(-sent_nll_sum.item(), 5)
+                    item = {'key': key, 'value': ppl_out}
+                    if writer is not None:
+                        writer["ppl"][key+":\n"] = ppl_out
+                        writer["utt2nll"][key] = str(utt2nll)
+                    results.append(item)
+
+            ppl_out_all += ppl_out_batch
+            
+            assert _bs == len(nll) == len(lengths), (_bs, len(nll), len(lengths))
+            # nll: (B, L) -> (B,)
+            nll = nll.detach().cpu().numpy().sum(1)
+            # lengths: (B,)
+            lengths = lengths.detach().cpu().numpy()
+            total_nll += nll.sum()
+            total_ntokens += lengths.sum()
+
+        if log_base is None:
+            ppl = np.exp(total_nll / total_ntokens)
+        else:
+            ppl = log_base ** (total_nll / total_ntokens / np.log(log_base))
+
+        avg_ppl = 'logprob= {total_nll} ppl= {total_ppl}\n'.format(
+            total_nll=round(-total_nll.item(), 4),
+            total_ppl=round(ppl.item(), 4)
+            )
+        item = {'key': 'AVG PPL', 'value': avg_ppl}
+        ppl_out_all += avg_ppl
+        if writer is not None:
+            writer["ppl"]["AVG PPL : "] = avg_ppl
+        results.append(item)
+
+        return results
+
+    return _forward
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="Calc perplexity",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument("--output_dir", type=str, required=False)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    parser.add_argument(
+        "--log_base",
+        type=float_or_none,
+        default=10,
+        help="The base of logarithm for Perplexity. "
+             "If None, Napier's constant is used.",
+        required=False
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        action="append",
+        required=False
+    )
+    group.add_argument(
+        "--raw_inputs",
+        type=str,
+        required=False
+    )
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group.add_argument("--split_with_space", type=str2bool, default=False)
+    group.add_argument("--seg_dict_file", type=str_or_none)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument("--train_config", type=str)
+    group.add_argument("--model_file", type=str)
+
+    return parser
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    inference(**kwargs)
+
+if __name__ == "__main__":
+    main()
+
diff --git a/funasr/bin/lm_inference_launch.py b/funasr/bin/lm_inference_launch.py
new file mode 100644
index 0000000..492ebab
--- /dev/null
+++ b/funasr/bin/lm_inference_launch.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+import argparse
+import logging
+import os
+import sys
+from typing import Union, Dict, Any
+
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.utils.types import float_or_none
+
+
+def get_parser():
+    parser = config_argparse.ArgumentParser(
+        description="Calc perplexity",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+    parser.add_argument("--output_dir", type=str, required=True)
+    parser.add_argument("--gpuid_list", type=str, required=True)
+    parser.add_argument(
+        "--ngpu",
+        type=int,
+        default=0,
+        help="The number of gpus. 0 indicates CPU mode",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument("--njob", type=int, default=1, help="The number of jobs for each gpu")
+    parser.add_argument(
+        "--dtype",
+        default="float32",
+        choices=["float16", "float32", "float64"],
+        help="Data type",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=1,
+        help="The number of workers used for DataLoader",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="The batch size for inference",
+    )
+    parser.add_argument(
+        "--log_base",
+        type=float_or_none,
+        default=10,
+        help="The base of logarithm for Perplexity. "
+             "If None, Napier's constant is used.",
+        required=False
+    )
+
+    group = parser.add_argument_group("Input data related")
+    group.add_argument(
+        "--data_path_and_name_and_type",
+        type=str2triple_str,
+        action="append",
+        required=False
+    )
+    group.add_argument(
+        "--raw_inputs",
+        type=str,
+        required=False
+    )
+    group.add_argument("--key_file", type=str_or_none)
+    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    group.add_argument("--split_with_space", type=str2bool, default=False)
+    group.add_argument("--seg_dict_file", type=str_or_none)
+
+    group = parser.add_argument_group("The model configuration related")
+    group.add_argument("--train_config", type=str)
+    group.add_argument("--model_file", type=str)
+    group.add_argument("--mode", type=str, default="lm")
+    return parser
+
+def inference_launch(mode, **kwargs):
+    if mode in ("lm", "transformer"):
+        from funasr.bin.lm_inference import inference_modelscope
+        return inference_modelscope(**kwargs)
+    else:
+        logging.error("Unknown decoding mode: {}".format(mode))
+        return None
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    kwargs.pop("config", None)
+
+    # set logging messages
+    logging.basicConfig(
+        level=args.log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    logging.info("Decoding args: {}".format(kwargs))
+
+    # gpu setting
+    if args.ngpu > 0:
+        jobid = int(args.output_dir.split(".")[-1])
+        gpuid = args.gpuid_list.split(",")[(jobid - 1) // args.njob]
+        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+        os.environ["CUDA_VISIBLE_DEVICES"] = gpuid
+
+    kwargs.pop("gpuid_list", None)
+    kwargs.pop("njob", None)
+    results = inference_launch(**kwargs)
+
+
+if __name__ == "__main__":
+    main()
+
diff --git a/funasr/bin/lm_train.py b/funasr/bin/lm_train.py
index faa7a45..8641465 100755
--- a/funasr/bin/lm_train.py
+++ b/funasr/bin/lm_train.py
@@ -1,22 +1,46 @@
 #!/usr/bin/env python3
+
+import os
+
 from funasr.tasks.lm import LMTask
 
 
-def get_parser():
+# for LM Training
+def parse_args():
     parser = LMTask.get_parser()
-    return parser
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+    args = parser.parse_args()
+    return args
 
 
-def main(cmd=None):
-    """LM training.
-
-    Example:
-
-        % python lm_train.py asr --print_config --optim adadelta
-        % python lm_train.py --config conf/train_asr.yaml
-    """
-    LMTask.main(cmd=cmd)
+def main(args=None, cmd=None):
+    # for LM Training
+    LMTask.main(args=args, cmd=cmd)
 
 
-if __name__ == "__main__":
-    main()
+if __name__ == '__main__':
+    args = parse_args()
+
+    # setup local gpu_id
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+    assert args.num_worker_count == 1
+
+    # re-compute batch size: when dataset type is small
+    if args.dataset_type == "small" and args.ngpu != 0:
+        if args.batch_size is not None:
+            args.batch_size = args.batch_size * args.ngpu
+        if args.batch_bins is not None:
+            args.batch_bins = args.batch_bins * args.ngpu
+
+    main(args=args)
diff --git a/funasr/bin/tokenize_text.py b/funasr/bin/tokenize_text.py
new file mode 100755
index 0000000..dc565d0
--- /dev/null
+++ b/funasr/bin/tokenize_text.py
@@ -0,0 +1,283 @@
+#!/usr/bin/env python3
+import argparse
+from collections import Counter
+import logging
+from pathlib import Path
+import sys
+from typing import List
+from typing import Optional
+
+from typeguard import check_argument_types
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.text.build_tokenizer import build_tokenizer
+from funasr.text.cleaner import TextCleaner
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+
+def field2slice(field: Optional[str]) -> slice:
+    """Convert field string to slice
+
+    Note that field string accepts 1-based integer.
+
+    Examples:
+        >>> field2slice("1-")
+        slice(0, None, None)
+        >>> field2slice("1-3")
+        slice(0, 3, None)
+        >>> field2slice("-3")
+        slice(None, 3, None)
+    """
+    field = field.strip()
+    try:
+        if "-" in field:
+            # e.g. "2-" or "2-5" or "-7"
+            s1, s2 = field.split("-", maxsplit=1)
+            if s1.strip() == "":
+                s1 = None
+            else:
+                s1 = int(s1)
+                if s1 == 0:
+                    raise ValueError("1-based string")
+            if s2.strip() == "":
+                s2 = None
+            else:
+                s2 = int(s2)
+        else:
+            # e.g. "2"
+            s1 = int(field)
+            s2 = s1 + 1
+            if s1 == 0:
+                raise ValueError("must be 1 or more value")
+    except ValueError:
+        raise RuntimeError(f"Format error: e.g. '2-', '2-5', or '-5': {field}")
+
+    if s1 is None:
+        slic = slice(None, s2)
+    else:
+        # -1 because of 1-based integer following "cut" command
+        # e.g "1-3" -> slice(0, 3)
+        slic = slice(s1 - 1, s2)
+    return slic
+
+
+def tokenize(
+    input: str,
+    output: str,
+    field: Optional[str],
+    delimiter: Optional[str],
+    token_type: str,
+    space_symbol: str,
+    non_linguistic_symbols: Optional[str],
+    bpemodel: Optional[str],
+    log_level: str,
+    write_vocabulary: bool,
+    vocabulary_size: int,
+    remove_non_linguistic_symbols: bool,
+    cutoff: int,
+    add_symbol: List[str],
+    cleaner: Optional[str],
+    g2p: Optional[str],
+):
+    assert check_argument_types()
+
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+    if input == "-":
+        fin = sys.stdin
+    else:
+        fin = Path(input).open("r", encoding="utf-8")
+    if output == "-":
+        fout = sys.stdout
+    else:
+        p = Path(output)
+        p.parent.mkdir(parents=True, exist_ok=True)
+        fout = p.open("w", encoding="utf-8")
+
+    cleaner = TextCleaner(cleaner)
+    tokenizer = build_tokenizer(
+        token_type=token_type,
+        bpemodel=bpemodel,
+        delimiter=delimiter,
+        space_symbol=space_symbol,
+        non_linguistic_symbols=non_linguistic_symbols,
+        remove_non_linguistic_symbols=remove_non_linguistic_symbols,
+        g2p_type=g2p,
+    )
+
+    counter = Counter()
+    if field is not None:
+        field = field2slice(field)
+
+    for line in fin:
+        line = line.rstrip()
+        if field is not None:
+            # e.g. field="2-"
+            # uttidA hello world!! -> hello world!!
+            tokens = line.split(delimiter)
+            tokens = tokens[field]
+            if delimiter is None:
+                line = " ".join(tokens)
+            else:
+                line = delimiter.join(tokens)
+
+        line = cleaner(line)
+        tokens = tokenizer.text2tokens(line)
+        if not write_vocabulary:
+            fout.write(" ".join(tokens) + "\n")
+        else:
+            for t in tokens:
+                counter[t] += 1
+
+    if not write_vocabulary:
+        return
+    
+    ## FIXME
+    ## del duplicate add_symbols in counter
+    for symbol_and_id in add_symbol:
+        # e.g symbol="<blank>:0"
+        try:
+            symbol, idx = symbol_and_id.split(":")
+        except ValueError:
+            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
+        symbol = symbol.strip()
+        if symbol in counter:
+            del counter[symbol]
+
+    # ======= write_vocabulary mode from here =======
+    # Sort by the number of occurrences in descending order
+    # and filter lower frequency words than cutoff value
+    words_and_counts = list(
+        filter(lambda x: x[1] > cutoff, sorted(counter.items(), key=lambda x: -x[1]))
+    )
+    # Restrict the vocabulary size
+    if vocabulary_size > 0:
+        if vocabulary_size < len(add_symbol):
+            raise RuntimeError(f"vocabulary_size is too small: {vocabulary_size}")
+        words_and_counts = words_and_counts[: vocabulary_size - len(add_symbol)]
+
+    # Parse the values of --add_symbol
+    for symbol_and_id in add_symbol:
+        # e.g symbol="<blank>:0"
+        try:
+            symbol, idx = symbol_and_id.split(":")
+            idx = int(idx)
+        except ValueError:
+            raise RuntimeError(f"Format error: e.g. '<blank>:0': {symbol_and_id}")
+        symbol = symbol.strip()
+
+        # e.g. idx=0  -> append as the first symbol
+        # e.g. idx=-1 -> append as the last symbol
+        if idx < 0:
+            idx = len(words_and_counts) + 1 + idx
+        words_and_counts.insert(idx, (symbol, None))
+
+    # Write words
+    for w, c in words_and_counts:
+        fout.write(w + "\n")
+
+    # Logging
+    total_count = sum(counter.values())
+    invocab_count = sum(c for w, c in words_and_counts if c is not None)
+    logging.info(f"OOV rate = {(total_count - invocab_count) / total_count * 100} %")
+
+
+def get_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description="Tokenize texts",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--log_level",
+        type=lambda x: x.upper(),
+        default="INFO",
+        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
+        help="The verbose level of logging",
+    )
+
+    parser.add_argument(
+        "--input", "-i", required=True, help="Input text. - indicates sys.stdin"
+    )
+    parser.add_argument(
+        "--output", "-o", required=True, help="Output text. - indicates sys.stdout"
+    )
+    parser.add_argument(
+        "--field",
+        "-f",
+        help="The target columns of the input text as 1-based integer. e.g 2-",
+    )
+    parser.add_argument(
+        "--token_type",
+        "-t",
+        default="char",
+        choices=["char", "bpe", "word", "phn"],
+        help="Token type",
+    )
+    parser.add_argument("--delimiter", "-d", default=None, help="The delimiter")
+    parser.add_argument("--space_symbol", default="<space>", help="The space symbol")
+    parser.add_argument("--bpemodel", default=None, help="The bpemodel file path")
+    parser.add_argument(
+        "--non_linguistic_symbols",
+        type=str_or_none,
+        help="non_linguistic_symbols file path",
+    )
+    parser.add_argument(
+        "--remove_non_linguistic_symbols",
+        type=str2bool,
+        default=False,
+        help="Remove non-language-symbols from tokens",
+    )
+    parser.add_argument(
+        "--cleaner",
+        type=str_or_none,
+        choices=[None, "tacotron", "jaconv", "vietnamese", "korean_cleaner"],
+        default=None,
+        help="Apply text cleaning",
+    )
+    parser.add_argument(
+        "--g2p",
+        type=str_or_none,
+        choices=g2p_choices,
+        default=None,
+        help="Specify g2p method if --token_type=phn",
+    )
+
+    group = parser.add_argument_group("write_vocabulary mode related")
+    group.add_argument(
+        "--write_vocabulary",
+        type=str2bool,
+        default=False,
+        help="Write tokens list instead of tokenized text per line",
+    )
+    group.add_argument("--vocabulary_size", type=int, default=0, help="Vocabulary size")
+    group.add_argument(
+        "--cutoff",
+        default=0,
+        type=int,
+        help="cut-off frequency used for write-vocabulary mode",
+    )
+    group.add_argument(
+        "--add_symbol",
+        type=str,
+        default=[],
+        action="append",
+        help="Append symbol e.g. --add_symbol '<blank>:0' --add_symbol '<unk>:1'",
+    )
+
+    return parser
+
+
+def main(cmd=None):
+    print(get_commandline_args(), file=sys.stderr)
+    parser = get_parser()
+    args = parser.parse_args(cmd)
+    kwargs = vars(args)
+    tokenize(**kwargs)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 10fbccb..79540c1 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -58,6 +58,15 @@
             continue
     return out_txt.strip().split()
 
def seg_tokenize_wo_pattern(txt, seg_dict):
    """Replace each word with its segmentation from ``seg_dict``.

    Words absent from ``seg_dict`` become "<unk>".  Replacements are joined
    with spaces and re-split, so a multi-token segmentation value such as
    "a b c" expands into several output tokens.

    Args:
        txt: iterable of input words.
        seg_dict: mapping from word to its space-separated segmentation.

    Returns:
        List of output tokens.
    """
    segments = [seg_dict.get(word, "<unk>") for word in txt]
    return " ".join(segments).split()
+
 
 def framing(
         x,
@@ -372,6 +381,70 @@
         data = self._text_process(data)
         return data
 
# TODO(review): consider folding this into CommonPreprocessor behind a flag.
class LMPreprocessor(CommonPreprocessor):
    """Preprocessor variant for language-model training data.

    Configuration is identical to CommonPreprocessor.  Only the text step
    differs: when ``split_with_space`` is set, whitespace-separated words are
    looked up in ``seg_dict`` and out-of-vocabulary words are replaced by
    "<unk>" (no pattern-based segmentation).
    """

    def __init__(
            self,
            train: bool,
            token_type: str = None,
            token_list: Union[Path, str, Iterable[str]] = None,
            bpemodel: Union[Path, str, Iterable[str]] = None,
            text_cleaner: Collection[str] = None,
            g2p_type: str = None,
            unk_symbol: str = "<unk>",
            space_symbol: str = "<space>",
            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
            delimiter: str = None,
            rir_scp: str = None,
            rir_apply_prob: float = 1.0,
            noise_scp: str = None,
            noise_apply_prob: float = 1.0,
            noise_db_range: str = "3_10",
            speech_volume_normalize: float = None,
            speech_name: str = "speech",
            text_name: str = "text",
            split_with_space: bool = False,
            seg_dict_file: str = None,
    ):
        # Forward everything positionally to the parent; this subclass adds
        # no state of its own, it only overrides _text_process.
        super().__init__(
            train, token_type, token_list, bpemodel, text_cleaner, g2p_type,
            unk_symbol, space_symbol, non_linguistic_symbols, delimiter,
            rir_scp, rir_apply_prob, noise_scp, noise_apply_prob,
            noise_db_range, speech_volume_normalize, speech_name, text_name,
            split_with_space, seg_dict_file,
        )

    def _text_process(
            self, data: Dict[str, Union[str, np.ndarray]]
    ) -> Dict[str, np.ndarray]:
        """Clean, tokenize and integerize ``data[self.text_name]`` in place."""
        if self.text_name in data and self.tokenizer is not None:
            cleaned = self.text_cleaner(data[self.text_name])
            if self.split_with_space:
                tokens = cleaned.strip().split(" ")
                if self.seg_dict is not None:
                    # OOV words map straight to "<unk>" for LM training.
                    tokens = seg_tokenize_wo_pattern(tokens, self.seg_dict)
            else:
                tokens = self.tokenizer.text2tokens(cleaned)
            data[self.text_name] = np.array(
                self.token_id_converter.tokens2ids(tokens), dtype=np.int64
            )
        assert check_return_type(data)
        return data
+
 
 class CommonPreprocessor_multi(AbsPreprocessor):
     def __init__(
diff --git a/funasr/lm/espnet_model.py b/funasr/lm/espnet_model.py
index 4fc3b49..db11b67 100644
--- a/funasr/lm/espnet_model.py
+++ b/funasr/lm/espnet_model.py
@@ -46,10 +46,10 @@
 
         # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
         # text: (Batch, Length) -> x, y: (Batch, Length + 1)
-        x = F.pad(text, [1, 0], "constant", self.eos)
+        x = F.pad(text, [1, 0], "constant", self.sos)
         t = F.pad(text, [0, 1], "constant", self.ignore_id)
         for i, l in enumerate(text_lengths):
-            t[i, l] = self.sos
+            t[i, l] = self.eos
         x_lengths = text_lengths + 1
 
         # 2. Forward Language model
diff --git a/funasr/models/decoder/contextual_decoder.py b/funasr/models/decoder/contextual_decoder.py
new file mode 100644
index 0000000..32f550a
--- /dev/null
+++ b/funasr/models/decoder/contextual_decoder.py
@@ -0,0 +1,776 @@
+from typing import List
+from typing import Tuple
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.modules.streaming_utils import utils as myutils
+from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
+from typeguard import check_argument_types
+
+from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.modules.embedding import PositionalEncoding
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.modules.repeat import repeat
+from funasr.models.decoder.sanm_decoder import DecoderLayerSANM, ParaformerSANMDecoder
+
+
class ContextualDecoderLayer(nn.Module):
    """Decoder layer for the contextual (hotword-biasing) Paraformer decoder.

    Mirrors the SANM decoder layer but additionally returns the intermediate
    self-attention and cross-attention outputs, which the contextual bias
    decoder consumes downstream.
    """

    def __init__(
        self,
        size,
        self_attn,
        src_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
    ):
        """Construct a ContextualDecoderLayer object.

        Args:
            size: hidden dimension of the layer.
            self_attn: FSMN-style self-attention module returning (out, cache).
            src_attn: encoder-decoder cross-attention module.
            feed_forward: position-wise feed-forward module.
            dropout_rate: dropout probability on the residual branches.
            normalize_before: apply LayerNorm before each sub-block (pre-norm).
            concat_after: keep the concat projection layers (compatibility).
        """
        super(ContextualDecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        if self_attn is not None:
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)

    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Run one decoder layer.

        Args:
            tgt: decoder input (or a tuple whose first element is the input).
            tgt_mask: target-side mask, passed through unchanged.
            memory: encoder output attended by ``src_attn``.
            memory_mask: encoder-side mask.
            cache: FSMN cache; discarded while training.

        Returns:
            Tuple of (output, tgt_mask, self-attention output,
            cross-attention output).
        """
        # Some upstream callers pass (tensor, cache) tuples; unwrap them.
        if isinstance(tgt, Tuple):
            tgt, _ = tgt
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)

        # FIX: removed a dead `x = tgt` assignment here — `x` was always
        # overwritten by the self_attn call below before being read.
        if self.normalize_before:
            tgt = self.norm2(tgt)
        if self.training:
            # Caches are only meaningful for incremental decoding.
            cache = None
        x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
        x = residual + self.dropout(x)
        x_self_attn = x

        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = self.src_attn(x, memory, memory_mask)
        x_src_attn = x

        x = residual + self.dropout(x)
        return x, tgt_mask, x_self_attn, x_src_attn
+
+
class ContexutalBiasDecoder(nn.Module):
    """Cross-attention block biasing decoder states toward contextual embeddings.

    NOTE: the class name (including its historical spelling) is kept as-is
    because checkpoints and callers reference it.
    """

    def __init__(self, size, src_attn, dropout_rate, normalize_before=True):
        """Construct a ContexutalBiasDecoder object.

        Args:
            size: hidden dimension.
            src_attn: cross-attention module (may be None to disable biasing).
            dropout_rate: dropout probability on the attention output.
            normalize_before: apply LayerNorm before the attention (pre-norm).
        """
        super(ContexutalBiasDecoder, self).__init__()
        self.size = size
        self.src_attn = src_attn
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before

    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Attend ``tgt`` to ``memory``; other arguments pass through unchanged."""
        x = tgt
        if self.src_attn is not None:
            x = self.norm3(x) if self.normalize_before else x
            x = self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache
+
+
+class ContextualParaformerDecoder(ParaformerSANMDecoder):
+    """
+    author: Speech Lab, Alibaba Group, China
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2006.01713
+    """
+    def __init__(
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        att_layer_num: int = 6,
+        kernel_size: int = 21,
+        sanm_shfit: int = 0,
+    ):
+        assert check_argument_types()
+        super().__init__(
+            vocab_size=vocab_size,
+            encoder_output_size=encoder_output_size,
+            dropout_rate=dropout_rate,
+            positional_dropout_rate=positional_dropout_rate,
+            input_layer=input_layer,
+            use_output_layer=use_output_layer,
+            pos_enc_class=pos_enc_class,
+            normalize_before=normalize_before,
+        )
+
+        attention_dim = encoder_output_size
+        if input_layer == 'none':
+            self.embed = None
+        if input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(vocab_size, attention_dim),
+                # pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        elif input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(vocab_size, attention_dim),
+                torch.nn.LayerNorm(attention_dim),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(attention_dim, positional_dropout_rate),
+            )
+        else:
+            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
+
+        self.normalize_before = normalize_before
+        if self.normalize_before:
+            self.after_norm = LayerNorm(attention_dim)
+        if use_output_layer:
+            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
+        else:
+            self.output_layer = None
+
+        self.att_layer_num = att_layer_num
+        self.num_blocks = num_blocks
+        if sanm_shfit is None:
+            sanm_shfit = (kernel_size - 1) // 2
+        self.decoders = repeat(
+            att_layer_num - 1,
+            lambda lnum: DecoderLayerSANM(
+                attention_dim,
+                MultiHeadedAttentionSANMDecoder(
+                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                ),
+                MultiHeadedAttentionCrossAtt(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        self.dropout = nn.Dropout(dropout_rate)
+        self.bias_decoder = ContexutalBiasDecoder(
+            size=attention_dim,
+            src_attn=MultiHeadedAttentionCrossAtt(
+                attention_heads, attention_dim, src_attention_dropout_rate
+            ),
+            dropout_rate=dropout_rate,
+            normalize_before=True,
+        )
+        self.bias_output = torch.nn.Conv1d(attention_dim*2, attention_dim, 1, bias=False)
+        self.last_decoder = ContextualDecoderLayer(
+                attention_dim,
+                MultiHeadedAttentionSANMDecoder(
+                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                ),
+                MultiHeadedAttentionCrossAtt(
+                    attention_heads, attention_dim, src_attention_dropout_rate
+                ),
+                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            )
+        if num_blocks - att_layer_num <= 0:
+            self.decoders2 = None
+        else:
+            self.decoders2 = repeat(
+                num_blocks - att_layer_num,
+                lambda lnum: DecoderLayerSANM(
+                    attention_dim,
+                    MultiHeadedAttentionSANMDecoder(
+                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+                    ),
+                    None,
+                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                ),
+            )
+
+        self.decoders3 = repeat(
+            1,
+            lambda lnum: DecoderLayerSANM(
+                attention_dim,
+                None,
+                None,
+                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        contextual_info: torch.Tensor,
+        return_hidden: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward decoder.
+
+        Args:
+            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
+            hlens: (batch)
+            ys_in_pad:
+                input token ids, int64 (batch, maxlen_out)
+                if input_layer == "embed"
+                input tensor (batch, maxlen_out, #mels) in the other cases
+            ys_in_lens: (batch)
+        Returns:
+            (tuple): tuple containing:
+
+            x: decoded token score before softmax (batch, maxlen_out, token)
+                if use_output_layer is True,
+            olens: (batch, )
+        """
+        tgt = ys_in_pad
+        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+        x = tgt
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+        _, _, x_self_attn, x_src_attn = self.last_decoder(
+            x, tgt_mask, memory, memory_mask
+        )
+
+        # contextual paraformer related
+        contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
+        contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
+        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
+
+        if self.bias_output is not None:
+            x = torch.cat([x_src_attn, cx], dim=2)
+            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
+            x = x_self_attn + self.dropout(x)
+
+        if self.decoders2 is not None:
+            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
+                x, tgt_mask, memory, memory_mask
+            )
+
+        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
+            x, tgt_mask, memory, memory_mask
+        )
+        if self.normalize_before:
+            x = self.after_norm(x)
+        olens = tgt_mask.sum(1)
+        if self.output_layer is not None and return_hidden is False:
+            x = self.output_layer(x)
+        return x, olens
+
    def gen_tf2torch_map_dict(self):
        """Build the name map for importing a TensorFlow checkpoint into torch.

        Keys are torch state-dict tensor names with the literal placeholder
        ``layeridx`` standing in for the concrete layer index; values give the
        matching TF tensor name (same placeholder convention) plus the
        ``squeeze`` axis and ``transpose`` permutation needed to convert the
        TF tensor layout into the torch one (None means no-op).  The trailing
        ``# (torch shape),(tf shape)`` comments document the expected shapes.
        Consumed by ``convert_tf2torch``.
        """

        # NOTE(review): both prefix attributes are expected to be provided by
        # the class/parent — they are not assigned in the visible __init__;
        # confirm upstream.
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {

            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
                    tensor_name_prefix_tf),
                    "squeeze": None,
                    "transpose": None,
                },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
                    tensor_name_prefix_tf),
                    "squeeze": None,
                    "transpose": None,
                },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
                    tensor_name_prefix_tf),
                    "squeeze": 0,
                    "transpose": (1, 2, 0),
                },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # embed_concat_ffn
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
                {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,1024),(1,1024,256)

            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch):
                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)

            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch):
                {"name": "{}/w_embs".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (4235,256),(4235,256)

            # out layer
            # Two TF sources per torch tensor: the dense projection plus the
            # tied word-embedding / bias variant.
            "{}.output_layer.weight".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
                 "squeeze": [None, None],
                 "transpose": [(1, 0), None],
                 },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch):
                {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                          "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                 "squeeze": [None, None],
                 "transpose": [None, None],
                 },  # (4235,),(4235,)

            ## clas decoder
            # src att
            # NOTE(review): the bias decoder maps to the fixed TF layer index
            # 15 — presumably tied to a 16-layer checkpoint; confirm.
            "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (1024,256),(1,256,1024)
            "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (1024,),(1024,)
            "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
                 "squeeze": 0,
                 "transpose": (1, 0),
                 },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": None,
                 },  # (256,),(256,)
            # dnn
            "{}.bias_output.weight".format(tensor_name_prefix_torch):
                {"name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
                 "squeeze": None,
                 "transpose": (2, 1, 0),
                 },  # (1024,256),(1,256,1024)

        }
        return map_dict_local
+
+    def convert_tf2torch(self,  # build a torch state_dict update from a TF checkpoint's variable dict
+                         var_dict_tf,  # {tf_variable_name: numpy array}
+                         var_dict_torch,  # this module's state_dict; its names and sizes drive the mapping
+                         ):
+        map_dict = self.gen_tf2torch_map_dict()  # torch name (with "layeridx" placeholder) -> TF name + squeeze/transpose spec
+        var_dict_torch_update = dict()
+        decoder_layeridx_sets = set()  # layer indices converted so far; offsets the "decoders2" numbering below
+        for name in sorted(var_dict_torch.keys(), reverse=False):
+            names = name.split('.')
+            if names[0] == self.tf2torch_tensor_name_prefix_torch:  # only convert tensors under this module's prefix
+                if names[1] == "decoders":
+                    layeridx = int(names[2])  # e.g. "decoders.3.xxx" -> 3
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")  # generic lookup key used in map_dict
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))  # concrete TF variable name
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:  # optional axis squeeze to match torch layout
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:  # optional permutation to match torch layout
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+                elif names[1] == "last_decoder":  # extra final decoder layer; stored as TF decoder layer 15
+                    layeridx = 15
+                    name_q = name.replace("last_decoder", "decoders.layeridx")
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+
+
+                elif names[1] == "decoders2":  # second decoder stack; TF numbering continues after "decoders"
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+                    name_q = name_q.replace("decoders2", "decoders")  # reuses the "decoders" map entries
+                    layeridx_bias = len(decoder_layeridx_sets)  # offset by the number of layers already converted
+
+                    layeridx += layeridx_bias
+                    if "decoders." in name:
+                        decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+
+                elif names[1] == "decoders3":  # third decoder stack; keeps its own TF layer numbering (bias 0)
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    if "decoders." in name:
+                        decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+                elif names[1] == "bias_decoder":  # contextual-bias decoder tensors map 1:1 (no layer index)
+                    name_q = name
+
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+
+
+                elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":  # embedding / projection tensors
+                    name_tf = map_dict[name]["name"]
+                    if isinstance(name_tf, list):  # several candidate TF names; pick the one present in the checkpoint
+                        idx_list = 0
+                        if name_tf[idx_list] in var_dict_tf.keys():
+                            pass
+                        else:
+                            idx_list = 1  # fall back to the alternate TF variable name
+                        data_tf = var_dict_tf[name_tf[idx_list]]
+                        if map_dict[name]["squeeze"][idx_list] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
+                        if map_dict[name]["transpose"][idx_list] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
+                                                                                                   name_tf[idx_list],
+                                                                                                   var_dict_tf[name_tf[
+                                                                                                       idx_list]].shape))
+
+                    else:
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+                        if map_dict[name]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+                                                                                          var_dict_tf[name_tf].shape))
+
+                elif names[1] == "after_norm":  # final LayerNorm params: copied as-is (no squeeze/transpose)
+                    name_tf = map_dict[name]["name"]
+                    data_tf = var_dict_tf[name_tf]
+                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                    var_dict_torch_update[name] = data_tf
+                    logging.info(
+                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+                                                                                      var_dict_tf[name_tf].shape))
+
+                elif names[1] == "embed_concat_ffn":  # per-layer FFN after embedding concat; same conversion pattern
+                    layeridx = int(names[2])
+                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
+
+                    layeridx_bias = 0
+                    layeridx += layeridx_bias
+                    if "decoders." in name:
+                        decoder_layeridx_sets.add(layeridx)
+                    if name_q in map_dict.keys():
+                        name_v = map_dict[name_q]["name"]
+                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
+                        data_tf = var_dict_tf[name_tf]
+                        if map_dict[name_q]["squeeze"] is not None:
+                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                        if map_dict[name_q]["transpose"] is not None:
+                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                        var_dict_torch[
+                                                                                                            name].size(),
+                                                                                                        data_tf.size())
+                        var_dict_torch_update[name] = data_tf
+                        logging.info(
+                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                          var_dict_tf[name_tf].shape))
+
+        return var_dict_torch_update  # partial state_dict: only the tensors that matched a map entry
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 7596896..5786bc4 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -8,6 +8,8 @@
 from typing import Union
 
 import torch
+import random
+import numpy as np
 from typeguard import check_argument_types
 
 from funasr.layers.abs_normalize import AbsNormalize
@@ -24,7 +26,7 @@
 from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.modules.add_sos_eos import add_sos_eos
-from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.nets_utils import make_pad_mask, pad_list
 from funasr.modules.nets_utils import th_accuracy
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.train.abs_espnet_model import AbsESPnetModel
@@ -824,7 +826,10 @@
 
 class BiCifParaformer(Paraformer):
 
-    """CTC-attention hybrid Encoder-Decoder model"""
+    """
+    Paraformer model with an extra CIF predictor
+    for accurate timestamp prediction.
+    """
 
     def __init__(
         self,
@@ -891,7 +896,7 @@
         )
         assert isinstance(self.predictor, CifPredictorV3), "BiCifParaformer should use CIFPredictorV3"
 
-    def _calc_att_loss(
+    def _calc_pre2_loss(
             self,
             encoder_out: torch.Tensor,
             encoder_out_lens: torch.Tensor,
@@ -903,47 +908,12 @@
         if self.predictor_bias == 1:
             _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = ys_pad_lens + self.predictor_bias
-        pre_acoustic_embeds, pre_token_length, _, pre_peak_index, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask,
-                                                                                  ignore_id=self.ignore_id)
+        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, ys_pad, encoder_out_mask, ignore_id=self.ignore_id)
 
-        # 0. sampler
-        decoder_out_1st = None
-        if self.sampling_ratio > 0.0:
-            if self.step_cur < 2:
-                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
-            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
-                                                           pre_acoustic_embeds)
-        else:
-            if self.step_cur < 2:
-                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
-            sematic_embeds = pre_acoustic_embeds
+        # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+        loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length2), pre_token_length2)
 
-        # 1. Forward decoder
-        decoder_outs = self.decoder(
-            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
-        )
-        decoder_out, _ = decoder_outs[0], decoder_outs[1]
-
-        if decoder_out_1st is None:
-            decoder_out_1st = decoder_out
-        # 2. Compute attention loss
-        loss_att = self.criterion_att(decoder_out, ys_pad)
-        acc_att = th_accuracy(
-            decoder_out_1st.view(-1, self.vocab_size),
-            ys_pad,
-            ignore_label=self.ignore_id,
-        )
-        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
-        loss_pre2 = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length2)
-
-        # Compute cer/wer using attention-decoder
-        if self.training or self.error_calculator is None:
-            cer_att, wer_att = None, None
-        else:
-            ys_hat = decoder_out_1st.argmax(dim=-1)
-            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
-        return loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2
+        return loss_pre2
     
     def calc_predictor(self, encoder_out, encoder_out_lens):
 
@@ -956,10 +926,154 @@
     def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-        ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out, None, encoder_out_mask, token_num=token_num,
-                                                                                  ignore_id=self.ignore_id)
-        import pdb; pdb.set_trace()
+        ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,  # predictor API is positional: (enc, mask, token_num)
+                                                                                               encoder_out_mask,
+                                                                                               token_num)  # down/up-sampled CIF alphas and fire peaks
         return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+
+    def forward(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + CIF predictor length loss (no decoder pass here)
+
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        # Check that batch_size is unified
+        assert (
+                speech.shape[0]
+                == speech_lengths.shape[0]
+                == text.shape[0]
+                == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+        batch_size = speech.shape[0]
+        self.step_cur += 1  # training-step counter inherited from Paraformer — presumably gates one-time logging; confirm
+        # for data-parallel
+        text = text[:, : text_lengths.max()]
+        speech = speech[:, :speech_lengths.max()]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        stats = dict()
+
+        loss_pre2 = self._calc_pre2_loss(  # length loss on the refined (second) CIF token count
+            encoder_out, encoder_out_lens, text, text_lengths
+        )
+
+        loss = loss_pre2  # total loss is just the refined length loss
+
+        stats["loss_pre2"] = loss_pre2.detach().cpu()
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+class ContextualParaformer(Paraformer):
+    """
+    Paraformer variant with contextual hotword biasing (hotword embedding + LSTM bias encoder).
+    """
+
+    def __init__(
+            self,
+            vocab_size: int,
+            token_list: Union[Tuple[str, ...], List[str]],
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            preencoder: Optional[AbsPreEncoder],
+            encoder: AbsEncoder,
+            postencoder: Optional[AbsPostEncoder],
+            decoder: AbsDecoder,
+            ctc: CTC,
+            ctc_weight: float = 0.5,
+            interctc_weight: float = 0.0,
+            ignore_id: int = -1,
+            blank_id: int = 0,
+            sos: int = 1,
+            eos: int = 2,
+            lsm_weight: float = 0.0,
+            length_normalized_loss: bool = False,
+            report_cer: bool = True,
+            report_wer: bool = True,
+            sym_space: str = "<space>",
+            sym_blank: str = "<blank>",
+            extract_feats_in_collect_stats: bool = True,
+            predictor=None,
+            predictor_weight: float = 0.0,
+            predictor_bias: int = 0,
+            sampling_ratio: float = 0.2,
+            min_hw_length: int = 2,  # min sampled-hotword length in tokens
+            max_hw_length: int = 4,  # max sampled-hotword length in tokens
+            sample_rate: float = 0.6,  # probability of sampling a hotword from an utterance
+            batch_rate: float = 0.5,
+            double_rate: float = -1.0,  # probability of sampling two hotwords; negative disables
+            target_buffer_length: int = -1,  # >0 enables the rolling hotword buffer below
+            inner_dim: int = 256,  # hotword embedding / bias-encoder hidden size
+            bias_encoder_type: str = 'lstm',  # only 'lstm' is supported
+            label_bracket: bool = False,  # NOTE(review): accepted but never stored/used in __init__ — confirm intended
+    ):
+        assert check_argument_types()
+        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
+        assert 0.0 <= interctc_weight < 1.0, interctc_weight
+
+        super().__init__(
+            vocab_size=vocab_size,
+            token_list=token_list,
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+            postencoder=postencoder,
+            decoder=decoder,
+            ctc=ctc,
+            ctc_weight=ctc_weight,
+            interctc_weight=interctc_weight,
+            ignore_id=ignore_id,
+            blank_id=blank_id,
+            sos=sos,
+            eos=eos,
+            lsm_weight=lsm_weight,
+            length_normalized_loss=length_normalized_loss,
+            report_cer=report_cer,
+            report_wer=report_wer,
+            sym_space=sym_space,
+            sym_blank=sym_blank,
+            extract_feats_in_collect_stats=extract_feats_in_collect_stats,
+            predictor=predictor,
+            predictor_weight=predictor_weight,
+            predictor_bias=predictor_bias,
+            sampling_ratio=sampling_ratio,
+        )
+
+        if bias_encoder_type == 'lstm':
+            logging.warning("enable bias encoder sampling and contextual training")
+            self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=0)  # encodes hotword token sequences
+            self.bias_embed = torch.nn.Embedding(vocab_size, inner_dim)  # hotword token embedding table
+        else:
+            raise ValueError("Unsupported bias encoder type: {}".format(bias_encoder_type))  # fail fast: logging and continuing left bias_encoder undefined
+
+        self.min_hw_length = min_hw_length
+        self.max_hw_length = max_hw_length
+        self.sample_rate = sample_rate
+        self.batch_rate = batch_rate
+        self.target_buffer_length = target_buffer_length
+        self.double_rate = double_rate
+
+        if self.target_buffer_length > 0:  # optional rolling buffer of previously sampled hotwords
+            self.hotword_buffer = None
+            self.length_record = []
+            self.current_buffer_length = 0
 
     def forward(
             self,
@@ -1038,17 +1152,17 @@
 
         # 2b. Attention decoder branch
         if self.ctc_weight != 1.0:
-            loss_att, acc_att, cer_att, wer_att, loss_pre, loss_pre2 = self._calc_att_loss(
+            loss_att, acc_att, cer_att, wer_att, loss_pre = self._calc_att_loss(
                 encoder_out, encoder_out_lens, text, text_lengths
             )
 
         # 3. CTC-Att loss definition
         if self.ctc_weight == 0.0:
-            loss = loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight
+            loss = loss_att + loss_pre * self.predictor_weight
         elif self.ctc_weight == 1.0:
             loss = loss_ctc
         else:
-            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight + loss_pre2 * self.predictor_weight
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att + loss_pre * self.predictor_weight
 
         # Collect Attn branch stats
         stats["loss_att"] = loss_att.detach() if loss_att is not None else None
@@ -1056,10 +1170,292 @@
         stats["cer"] = cer_att
         stats["wer"] = wer_att
         stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None
-        stats["loss_pre2"] = loss_pre2.detach().cpu() if loss_pre is not None else None
 
         stats["loss"] = torch.clone(loss.detach())
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-        return loss, stats, weight
\ No newline at end of file
+        return loss, stats, weight
+
+    def _sample_hot_word(self, ys_pad, ys_pad_lens):
+        # Randomly carve hotword token spans out of the batch's padded target
+        # transcripts, embed them with the decoder embedding, and run them
+        # through the bias LSTM.  Returns a (B, num_hotwords, D) tensor of
+        # hotword encodings, optionally merged with a rolling buffer so that
+        # hotwords sampled in earlier steps also appear as distractors.
+        # NOTE(review): index 0 is a dummy entry so the list is never empty —
+        # presumably token id 0 is <blank>/<pad>; confirm against the vocab.
+        hw_list = [torch.Tensor([0]).long().to(ys_pad.device)]
+        hw_lengths = [0]  # this length is actually for indice, so -1
+        for i, length in enumerate(ys_pad_lens):
+            if length < 2:
+                continue
+            if length > self.min_hw_length + self.max_hw_length + 2 and random.random() < self.double_rate:
+                # sample double hotword
+                _max_hw_length = min(self.max_hw_length, length // 2)
+                # first hotword
+                start1 = random.randint(0, length // 3)
+                end1 = random.randint(start1 + self.min_hw_length - 1, start1 + _max_hw_length - 1)
+                hw_tokens1 = ys_pad[i][start1:end1 + 1]
+                hw_lengths.append(len(hw_tokens1) - 1)
+                hw_list.append(hw_tokens1)
+                # second hotword
+                start2 = random.randint(end1 + 1, length - self.min_hw_length)
+                end2 = random.randint(min(length - 1, start2 + self.min_hw_length - 1),
+                                      min(length - 1, start2 + self.max_hw_length - 1))
+                hw_tokens2 = ys_pad[i][start2:end2 + 1]
+                hw_lengths.append(len(hw_tokens2) - 1)
+                hw_list.append(hw_tokens2)
+                continue
+            if random.random() < self.sample_rate:
+                # sample one hotword span from this utterance
+                if length == 2:
+                    hw_tokens = ys_pad[i][:2]
+                    hw_lengths.append(1)
+                    hw_list.append(hw_tokens)
+                else:
+                    start = random.randint(0, length - self.min_hw_length)
+                    end = random.randint(min(length - 1, start + self.min_hw_length - 1),
+                                         min(length - 1, start + self.max_hw_length - 1)) + 1
+                    # print(start, end)
+                    hw_tokens = ys_pad[i][start:end]
+                    hw_lengths.append(len(hw_tokens) - 1)
+                    hw_list.append(hw_tokens)
+        # padding
+        hw_list_pad = pad_list(hw_list, 0)
+        hw_embed = self.decoder.embed(hw_list_pad)
+        hw_embed, (_, _) = self.bias_encoder(hw_embed)
+        _ind = np.arange(0, len(hw_list)).tolist()
+        # pick, for each hotword, the LSTM output at its last real token
+        # update self.hotword_buffer, throw a part if oversize
+        selected = hw_embed[_ind, hw_lengths]
+        if self.target_buffer_length > 0:
+            _b = selected.shape[0]
+            if self.hotword_buffer is None:
+                # first step: start the buffer with this batch's hotwords
+                self.hotword_buffer = selected
+                self.length_record.append(selected.shape[0])
+                self.current_buffer_length = _b
+            elif self.current_buffer_length + _b < self.target_buffer_length:
+                # room left: append (old entries detached so no stale grads)
+                self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
+                self.current_buffer_length += _b
+                selected = self.hotword_buffer
+            else:
+                # buffer full: append, then keep only a random-sized tail
+                self.hotword_buffer = torch.cat([self.hotword_buffer.detach(), selected], dim=0)
+                random_throw = random.randint(self.target_buffer_length // 2, self.target_buffer_length) + 10
+                self.hotword_buffer = self.hotword_buffer[-1 * random_throw:]
+                selected = self.hotword_buffer
+                self.current_buffer_length = selected.shape[0]
+        # broadcast the same hotword set to every utterance in the batch
+        return selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
+
+    def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds, contextual_info):
+        # Glancing-style sampler: run the decoder once without grad on the CIF
+        # acoustic embeddings, count correctly predicted tokens, then replace a
+        # proportional share (sampling_ratio * #errors) of the acoustic
+        # embeddings with ground-truth token embeddings for the training pass.
+        # Returns (mixed embeddings, first-pass decoder output), both masked
+        # to zero at padding positions.
+
+        tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
+        ys_pad = ys_pad * tgt_mask[:, :, 0]  # zero out padded target ids
+        if self.share_embedding:
+            # tied input/output embedding: reuse the output projection weights
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
+        else:
+            ys_pad_embed = self.decoder.embed(ys_pad)
+        with torch.no_grad():
+            decoder_outs = self.decoder(
+                encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens, contextual_info=contextual_info
+            )
+            decoder_out, _ = decoder_outs[0], decoder_outs[1]
+            pred_tokens = decoder_out.argmax(-1)
+            nonpad_positions = ys_pad.ne(self.ignore_id)
+            seq_lens = (nonpad_positions).sum(1)
+            same_num = ((pred_tokens == ys_pad) & nonpad_positions).sum(1)
+            input_mask = torch.ones_like(nonpad_positions)
+            bsz, seq_len = ys_pad.size()
+            for li in range(bsz):
+                # more errors -> more ground-truth embeddings mixed in
+                target_num = (((seq_lens[li] - same_num[li].sum()).float()) * self.sampling_ratio).long()
+                if target_num > 0:
+                    # NOTE(review): hard-coded .cuda() would fail on CPU-only
+                    # training — presumably should be .to(ys_pad.device).
+                    input_mask[li].scatter_(dim=0, index=torch.randperm(seq_lens[li])[:target_num].cuda(), value=0)
+            input_mask = input_mask.eq(1)
+            input_mask = input_mask.masked_fill(~nonpad_positions, False)
+            input_mask_expand_dim = input_mask.unsqueeze(2).to(pre_acoustic_embeds.device)
+
+        # keep acoustic embeds where mask==True, ground-truth embeds elsewhere
+        sematic_embeds = pre_acoustic_embeds.masked_fill(~input_mask_expand_dim, 0) + ys_pad_embed.masked_fill(
+            input_mask_expand_dim, 0)
+        return sematic_embeds * tgt_mask, decoder_out * tgt_mask
+
+    def _calc_att_loss(
+            self,
+            encoder_out: torch.Tensor,
+            encoder_out_lens: torch.Tensor,
+            ys_pad: torch.Tensor,
+            ys_pad_lens: torch.Tensor,
+    ):
+        # Attention-branch loss for the contextual Paraformer: run the CIF
+        # predictor, sample hotwords from the targets, optionally mix
+        # ground-truth embeddings in via the sampler, decode once with the
+        # contextual info, and return CE loss, accuracy, CER/WER (eval only)
+        # and the predictor token-count loss.
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        if self.predictor_bias == 1:
+            # append <eos> to the targets and lengthen them accordingly
+            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+            ys_pad_lens = ys_pad_lens + self.predictor_bias
+        pre_acoustic_embeds, pre_token_length, _, pre_peak_index = self.predictor(encoder_out, ys_pad,
+                                                                                  encoder_out_mask,
+                                                                                  ignore_id=self.ignore_id)
+
+        # sample hot word
+        contextual_info = self._sample_hot_word(ys_pad, ys_pad_lens)
+
+        # 0. sampler
+        decoder_out_1st = None
+        if self.sampling_ratio > 0.0:
+            if self.step_cur < 2:
+                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            sematic_embeds, decoder_out_1st = self.sampler(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens,
+                                                           pre_acoustic_embeds, contextual_info)
+        else:
+            if self.step_cur < 2:
+                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
+            sematic_embeds = pre_acoustic_embeds
+
+        # 1. Forward decoder
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
+        )
+        decoder_out, _ = decoder_outs[0], decoder_outs[1]
+
+        if decoder_out_1st is None:
+            decoder_out_1st = decoder_out
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_pad)
+        # accuracy is reported on the first-pass (pre-sampler) output
+        acc_att = th_accuracy(
+            decoder_out_1st.view(-1, self.vocab_size),
+            ys_pad,
+            ignore_label=self.ignore_id,
+        )
+        loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length)
+
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out_1st.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+        return loss_att, acc_att, cer_att, wer_att, loss_pre
+
+    def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, hw_list=None):
+        # Inference-time decoding with an optional user-supplied hotword list.
+        # Each hotword is embedded and encoded by the bias LSTM; the final
+        # hidden state of each hotword forms the contextual_info fed to the
+        # contextual decoder.  Returns (log-softmax decoder output, lengths).
+        # NOTE(review): this path uses self.bias_embed while training
+        # (_sample_hot_word) uses self.decoder.embed — confirm the two
+        # embeddings are intentionally different between train and inference.
+        if hw_list is None:
+            # default hotword list
+            hw_list = [torch.Tensor([self.sos]).long().to(encoder_out.device)]  # empty hotword list
+            hw_list_pad = pad_list(hw_list, 0)
+            hw_embed = self.bias_embed(hw_list_pad)
+            _, (h_n, _) = self.bias_encoder(hw_embed)
+            contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
+        else:
+            hw_lengths = [len(i) for i in hw_list]
+            hw_list_pad = pad_list([torch.Tensor(i).long() for i in hw_list], 0).to(encoder_out.device)
+            hw_embed = self.bias_embed(hw_list_pad)
+            # pack so each hotword's final LSTM state reflects its true length
+            hw_embed = torch.nn.utils.rnn.pack_padded_sequence(hw_embed, hw_lengths, batch_first=True,
+                                                               enforce_sorted=False)
+            _, (h_n, _) = self.bias_encoder(hw_embed)
+            # hw_embed, _ = torch.nn.utils.rnn.pad_packed_sequence(hw_embed, batch_first=True)
+            contextual_info = h_n.squeeze(0).repeat(encoder_out.shape[0], 1, 1)
+        decoder_outs = self.decoder(
+            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
+        )
+        decoder_out = decoder_outs[0]
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        return decoder_out, ys_pad_lens
+
+    def gen_clas_tf2torch_map_dict(self):
+        # Mapping table for importing the TF CLAS checkpoint into this model:
+        # torch parameter name -> TF variable name plus the transform steps
+        # (slice/transpose/scale and LSTM gate re-ordering via unit_k/unit_b)
+        # that clas_convert_tf2torch applies to the raw TF array.
+        tensor_name_prefix_torch = "bias_encoder"
+        tensor_name_prefix_tf = "seq2seq/clas_charrnn"
+
+        tensor_name_prefix_torch_emb = "bias_embed"
+        tensor_name_prefix_tf_emb = "seq2seq"
+
+        map_dict_local = {
+            # in lstm
+            "{}.weight_ih_l0".format(tensor_name_prefix_torch):
+                {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": (1, 0),
+                 "slice": (0, 512),
+                 "unit_k": 512,
+                 },  # (1024, 2048),(2048,512)
+            "{}.weight_hh_l0".format(tensor_name_prefix_torch):
+                {"name": "{}/rnn/lstm_cell/kernel".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": (1, 0),
+                 "slice": (512, 1024),
+                 "unit_k": 512,
+                 },  # (1024, 2048),(2048,512)
+            # TF fuses one bias vector; torch splits it into ih/hh biases, so
+            # the same TF bias is loaded into both halves scaled by 0.5.
+            "{}.bias_ih_l0".format(tensor_name_prefix_torch):
+                {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 "scale": 0.5,
+                 "unit_b": 512,
+                 },  # (2048,),(2048,)
+            "{}.bias_hh_l0".format(tensor_name_prefix_torch):
+                {"name": "{}/rnn/lstm_cell/bias".format(tensor_name_prefix_tf),
+                 "squeeze": None,
+                 "transpose": None,
+                 "scale": 0.5,
+                 "unit_b": 512,
+                 },  # (2048,),(2048,)
+
+            # in embed
+            "{}.weight".format(tensor_name_prefix_torch_emb):
+                {"name": "{}/contextual_encoder/w_char_embs".format(tensor_name_prefix_tf_emb),
+                 "squeeze": None,
+                 "transpose": None,
+                 },  # (4235,256),(4235,256)
+        }
+        return map_dict_local
+
+    def clas_convert_tf2torch(self,
+                              var_dict_tf,
+                              var_dict_torch):
+        # Convert the TF checkpoint's CLAS bias-encoder / bias-embedding
+        # variables into torch tensors keyed by this model's parameter names.
+        # Only parameters whose names start with "bias_encoder" or
+        # "bias_embed" are handled; returns the dict of converted tensors.
+        map_dict = self.gen_clas_tf2torch_map_dict()
+        var_dict_torch_update = dict()
+        for name in sorted(var_dict_torch.keys(), reverse=False):
+            names = name.split('.')
+            if names[0] == "bias_encoder":
+                name_q = name
+                if name_q in map_dict.keys():
+                    name_v = map_dict[name_q]["name"]
+                    name_tf = name_v
+                    data_tf = var_dict_tf[name_tf]
+                    if map_dict[name_q].get("unit_k") is not None:
+                        # re-order the fused LSTM kernel gates from the TF
+                        # layout into the layout torch expects.
+                        # NOTE(review): the slices are labelled i,f,o,g but are
+                        # reassembled as [i, o, f, g] — verify against the TF
+                        # cell's actual gate ordering before trusting this.
+                        dim = map_dict[name_q]["unit_k"]
+                        i = data_tf[:, 0:dim].copy()
+                        f = data_tf[:, dim:2 * dim].copy()
+                        o = data_tf[:, 2 * dim:3 * dim].copy()
+                        g = data_tf[:, 3 * dim:4 * dim].copy()
+                        data_tf = np.concatenate([i, o, f, g], axis=1)
+                    if map_dict[name_q].get("unit_b") is not None:
+                        # same gate re-ordering for the fused bias vector
+                        dim = map_dict[name_q]["unit_b"]
+                        i = data_tf[0:dim].copy()
+                        f = data_tf[dim:2 * dim].copy()
+                        o = data_tf[2 * dim:3 * dim].copy()
+                        g = data_tf[3 * dim:4 * dim].copy()
+                        data_tf = np.concatenate([i, o, f, g], axis=0)
+                    if map_dict[name_q]["squeeze"] is not None:
+                        data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
+                    if map_dict[name_q].get("slice") is not None:
+                        data_tf = data_tf[map_dict[name_q]["slice"][0]:map_dict[name_q]["slice"][1]]
+                    if map_dict[name_q].get("scale") is not None:
+                        data_tf = data_tf * map_dict[name_q]["scale"]
+                    if map_dict[name_q]["transpose"] is not None:
+                        data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
+                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                    assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                    var_dict_torch[
+                                                                                                        name].size(),
+                                                                                                    data_tf.size())
+                    var_dict_torch_update[name] = data_tf
+                    logging.info(
+                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
+                                                                                      var_dict_tf[name_tf].shape))
+            elif names[0] == "bias_embed":
+                # embedding table: straight copy (no gate re-ordering needed)
+                name_tf = map_dict[name]["name"]
+                data_tf = var_dict_tf[name_tf]
+                if map_dict[name]["squeeze"] is not None:
+                    data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
+                if map_dict[name]["transpose"] is not None:
+                    data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
+                data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
+                assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
+                                                                                                var_dict_torch[
+                                                                                                    name].size(),
+                                                                                                data_tf.size())
+                var_dict_torch_update[name] = data_tf
+                logging.info(
+                    "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
+                                                                                  var_dict_tf[name_tf].shape))
+
+        return var_dict_torch_update
\ No newline at end of file
diff --git a/funasr/models/predictor/cif.py b/funasr/models/predictor/cif.py
index c34759d..5615373 100644
--- a/funasr/models/predictor/cif.py
+++ b/funasr/models/predictor/cif.py
@@ -544,9 +544,8 @@
             token_num_int = torch.max(token_num).type(torch.int32).item()
             acoustic_embeds = acoustic_embeds[:, :token_num_int, :]
         return acoustic_embeds, token_num, alphas, cif_peak, token_num2
-    
-    def get_upsample_timestamp(self, hidden, target_label=None, mask=None, ignore_id=-1, mask_chunk_predictor=None,
-                target_label_length=None, token_num=None):
+
+    def get_upsample_timestamp(self, hidden, mask=None, token_num=None):
         h = hidden
         b = hidden.shape[0]
         context = h.transpose(1, 2)
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 7899400..02311fd 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -43,6 +43,7 @@
 from funasr.iterators.chunk_iter_factory import ChunkIterFactory
 from funasr.iterators.multiple_iter_factory import MultipleIterFactory
 from funasr.iterators.sequence_iter_factory import SequenceIterFactory
+from funasr.main_funcs.collect_stats import collect_stats
 from funasr.optimizers.sgd import SGD
 from funasr.optimizers.fairseq_adam import FairseqAdam
 from funasr.samplers.build_batch_sampler import BATCH_TYPES
@@ -1272,6 +1273,52 @@
 
         if args.dry_run:
             pass
+        elif args.collect_stats:
+            # Perform on collect_stats mode. This mode has two roles
+            # - Derive the length and dimension of all input data
+            # - Accumulate feats, square values, and the length for whitening
+
+            if args.valid_batch_size is None:
+                args.valid_batch_size = args.batch_size
+
+            if len(args.train_shape_file) != 0:
+                train_key_file = args.train_shape_file[0]
+            else:
+                train_key_file = None
+            if len(args.valid_shape_file) != 0:
+                valid_key_file = args.valid_shape_file[0]
+            else:
+                valid_key_file = None
+
+            collect_stats(
+                model=model,
+                train_iter=cls.build_streaming_iterator(
+                    data_path_and_name_and_type=args.train_data_path_and_name_and_type,
+                    key_file=train_key_file,
+                    batch_size=args.batch_size,
+                    dtype=args.train_dtype,
+                    num_workers=args.num_workers,
+                    allow_variable_data_keys=args.allow_variable_data_keys,
+                    ngpu=args.ngpu,
+                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
+                    collate_fn=cls.build_collate_fn(args, train=False),
+                ),
+                valid_iter=cls.build_streaming_iterator(
+                    data_path_and_name_and_type=args.valid_data_path_and_name_and_type,
+                    key_file=valid_key_file,
+                    batch_size=args.valid_batch_size,
+                    dtype=args.train_dtype,
+                    num_workers=args.num_workers,
+                    allow_variable_data_keys=args.allow_variable_data_keys,
+                    ngpu=args.ngpu,
+                    preprocess_fn=cls.build_preprocess_fn(args, train=False),
+                    collate_fn=cls.build_collate_fn(args, train=False),
+                ),
+                output_dir=output_dir,
+                ngpu=args.ngpu,
+                log_interval=args.log_interval,
+                write_collected_feats=args.write_collected_feats,
+            )
         else:
             logging.info("Training args: {}".format(args))
             # 6. Loads pre-trained model
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 1b7f152..e62a748 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -37,8 +37,9 @@
 )
 from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
 from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
 from funasr.models.e2e_asr import ESPnetASRModel
-from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer
+from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
 from funasr.models.e2e_uni_asr import UniASR
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.encoder.conformer_encoder import ConformerEncoder
@@ -117,6 +118,7 @@
         paraformer=Paraformer,
         paraformer_bert=ParaformerBert,
         bicif_paraformer=BiCifParaformer,
+        contextual_paraformer=ContextualParaformer,
     ),
     type_check=AbsESPnetModel,
     default="asr",
@@ -177,6 +179,7 @@
         fsmn_scama_opt=FsmnDecoderSCAMAOpt,
         paraformer_decoder_sanm=ParaformerSANMDecoder,
         paraformer_decoder_san=ParaformerDecoderSAN,
+        contextual_paraformer_decoder=ContextualParaformerDecoder,
     ),
     type_check=AbsDecoder,
     default="rnn",
@@ -1098,5 +1101,8 @@
         # decoder
         var_dict_torch_update_local = model.decoder.convert_tf2torch(var_dict_tf, var_dict_torch)
         var_dict_torch_update.update(var_dict_torch_update_local)
+        # bias_encoder
+        var_dict_torch_update_local = model.clas_convert_tf2torch(var_dict_tf, var_dict_torch)
+        var_dict_torch_update.update(var_dict_torch_update_local)
 
         return var_dict_torch_update
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 46b9fe0..608c1d3 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -58,7 +58,7 @@
         # NOTE(kamo): add_arguments(..., required=True) can't be used
         # to provide --print_config mode. Instead of it, do as
         required = parser.get_default("required")
-        required += ["token_list"]
+        # required += ["token_list"]
 
         group.add_argument(
             "--token_list",
diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py
index 4da0d59..575fb90 100644
--- a/funasr/utils/postprocess_utils.py
+++ b/funasr/utils/postprocess_utils.py
@@ -232,5 +232,9 @@
         return sentence, ts_lists, real_word_lists
     else:
         word_lists = abbr_dispose(word_lists)
+        real_word_lists = []
+        for ch in word_lists:
+            if ch != ' ':
+                real_word_lists.append(ch)
         sentence = ''.join(word_lists).strip()
-        return sentence
+        return sentence, real_word_lists
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 3afaa40..33d1255 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -86,14 +86,51 @@
     else:
         return time_stamp_list
 
-
-def time_stamp_lfr6_advance(tst: List, text: str):
-    # advanced timestamp prediction for BiCIF_Paraformer using upsampled alphas
-    ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = tst
-    if text.endswith('</s>'):
-        text = text[:-4]
+def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
+    # Derive per-token timestamps from the upsampled CIF peaks of the
+    # BiCIF-Paraformer: each token spans the frames between two consecutive
+    # peaks; <sil> pseudo-tokens are inserted for leading/trailing silence,
+    # and all times are offset by begin_time (ms) when running after VAD.
+    # Returns a list of [start_ms, end_ms] pairs for the real tokens only.
+    # NOTE(review): end_time is accepted but never used — confirm intent.
+    START_END_THRESHOLD = 5
+    TIME_RATE = 10.0 * 6 / 1000 / 3  #  3 times upsampled
+    if len(us_alphas.shape) == 3:
+        alphas, cif_peak = us_alphas[0], us_cif_peak[0]  # support inference batch_size=1 only
     else:
-        text = text[:-1]
-        logging.warning("found text does not end with </s>")
-    assert int(ds_alphas.sum() + 1e-4) - 1 == len(text)
-    
+        alphas, cif_peak = us_alphas, us_cif_peak
+    num_frames = cif_peak.shape[0]
+    if char_list[-1] == '</s>':
+        char_list = char_list[:-1]  # drop the sentence-end token
+    # char_list = [i for i in text]
+    timestamp_list = []
+    # for bicif model trained with large data, cif2 actually fires when a character starts
+    # so treat the frames between two peaks as the duration of the former token
+    fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 1.5
+    num_peak = len(fire_place)
+    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
+    # begin silence
+    if fire_place[0] > START_END_THRESHOLD:
+        char_list.insert(0, '<sil>')
+        timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
+    # tokens timestamp
+    for i in range(len(fire_place)-1):
+        # the peak is always a little ahead of the start time
+        # timestamp_list.append([(fire_place[i]-1.2)*TIME_RATE, fire_place[i+1]*TIME_RATE])
+        timestamp_list.append([(fire_place[i])*TIME_RATE, fire_place[i+1]*TIME_RATE])
+        # cut the duration to token and sil of the 0-weight frames last long
+    # tail token and end silence
+    if num_frames - fire_place[-1] > START_END_THRESHOLD:
+        # enough trailing frames: split them between the last token and <sil>
+        _end = (num_frames + fire_place[-1]) / 2
+        timestamp_list[-1][1] = _end*TIME_RATE
+        timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
+        char_list.append("<sil>")
+    else:
+        timestamp_list[-1][1] = num_frames*TIME_RATE
+    if begin_time:  # add offset time in model with vad
+        for i in range(len(timestamp_list)):
+            timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0
+            timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0
+    res_txt = ""
+    for char, timestamp in zip(char_list, timestamp_list):
+        res_txt += "{} {} {};".format(char, timestamp[0], timestamp[1])
+    res = []
+    for char, timestamp in zip(char_list, timestamp_list):
+        if char != '<sil>':
+            # emit [start_ms, end_ms] for real tokens only
+            res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
+    return res
+

--
Gitblit v1.9.1