From 0cc7af2c894a31f1b6d95d8af7e4efa414e2a11b Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 07 四月 2023 15:54:21 +0800
Subject: [PATCH] Merge pull request #328 from alibaba-damo-academy/dev_cmz2

---
 funasr/bin/punctuation_infer.py                                |    2 
 funasr/export/models/encoder/fsmn_encoder.py                   |  296 +++++++
 funasr/train/abs_model.py                                      |   54 +
 funasr/export/models/e2e_vad.py                                |   60 +
 funasr/export/test/test_onnx.py                                |    0 
 funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py      |  249 ++++++
 funasr/tasks/vad.py                                            |   25 
 funasr/export/export_model.py                                  |   52 
 funasr/export/models/__init__.py                               |   18 
 funasr/tasks/punctuation.py                                    |   14 
 funasr/export/test/test_onnx_vad.py                            |   26 
 funasr/models/target_delay_transformer.py                      |    5 
 funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py   |   36 
 funasr/datasets/preprocessor.py                                |   14 
 funasr/runtime/python/onnxruntime/demo.py                      |    2 
 funasr/runtime/python/onnxruntime/demo_punc_online.py          |   15 
 funasr/bin/punctuation_infer_vadrealtime.py                    |    2 
 funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py      |    3 
 funasr/export/test/test_torchscripts.py                        |    0 
 funasr/lm/abs_model.py                                         |  129 +++
 funasr/runtime/python/onnxruntime/setup.py                     |    2 
 funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py       |  143 +++
 funasr/runtime/python/onnxruntime/demo_punc_offline.py         |    8 
 funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py |  607 +++++++++++++++
 funasr/models/encoder/sanm_encoder.py                          |  232 +++++
 funasr/export/models/encoder/sanm_encoder.py                   |  104 ++
 funasr/export/test/test_onnx_punc.py                           |   18 
 funasr/tasks/lm.py                                             |    8 
 funasr/export/models/target_delay_transformer.py               |  154 +++
 funasr/export/test/test_onnx_punc_vadrealtime.py               |   22 
 funasr/models/vad_realtime_transformer.py                      |    4 
 funasr/export/test/__init__.py                                 |    0 
 /dev/null                                                      |   12 
 funasr/bin/asr_inference_uniasr.py                             |    3 
 funasr/runtime/python/onnxruntime/demo_vad.py                  |   30 
 funasr/models/e2e_vad.py                                       |    6 
 funasr/runtime/python/libtorch/funasr_torch/utils/utils.py     |    2 
 37 files changed, 2,283 insertions(+), 74 deletions(-)

diff --git a/funasr/bin/asr_inference_uniasr.py b/funasr/bin/asr_inference_uniasr.py
index 2e5b6f5..4aea720 100644
--- a/funasr/bin/asr_inference_uniasr.py
+++ b/funasr/bin/asr_inference_uniasr.py
@@ -37,9 +37,6 @@
 from funasr.models.frontend.wav_frontend import WavFrontend
 
 
-header_colors = '\033[95m'
-end_colors = '\033[0m'
-
 
 class Speech2Text:
     """Speech2Text class
diff --git a/funasr/bin/punctuation_infer.py b/funasr/bin/punctuation_infer.py
index a801ee8..dd28ef8 100644
--- a/funasr/bin/punctuation_infer.py
+++ b/funasr/bin/punctuation_infer.py
@@ -23,7 +23,7 @@
 from funasr.utils import config_argparse
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
 
 
 class Text2Punc:
diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
index ce1cee8..81f9d7a 100644
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -23,7 +23,7 @@
 from funasr.utils import config_argparse
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
 
 
 class Text2Punc:
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 98cca1d..afeff4e 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -800,3 +800,17 @@
                     data[self.vad_name] = np.array([vad], dtype=np.int64)
                 text_ints = self.token_id_converter[i].tokens2ids(tokens)
                 data[text_name] = np.array(text_ints, dtype=np.int64)
+
+
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+    assert word_limit > 1
+    if len(words) <= word_limit:
+        return [words]
+    sentences = []
+    length = len(words)
+    sentence_len = length // word_limit
+    for i in range(sentence_len):
+        sentences.append(words[i * word_limit:(i + 1) * word_limit])
+    if length % word_limit > 0:
+        sentences.append(words[sentence_len * word_limit:])
+    return sentences
\ No newline at end of file
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index d3d119c..b69eeee 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -167,31 +167,57 @@
     
     def export(self,
                tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
-               mode: str = 'paraformer',
+               mode: str = None,
                ):
         
         model_dir = tag_name
-        if model_dir.startswith('damo/'):
+        if model_dir.startswith('damo'):
             from modelscope.hub.snapshot_download import snapshot_download
             model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir)
-        asr_train_config = os.path.join(model_dir, 'config.yaml')
-        asr_model_file = os.path.join(model_dir, 'model.pb')
-        cmvn_file = os.path.join(model_dir, 'am.mvn')
-        json_file = os.path.join(model_dir, 'configuration.json')
+
         if mode is None:
             import json
+            json_file = os.path.join(model_dir, 'configuration.json')
             with open(json_file, 'r') as f:
                 config_data = json.load(f)
-                mode = config_data['model']['model_config']['mode']
+                if config_data['task'] == "punctuation":
+                    mode = config_data['model']['punc_model_config']['mode']
+                else:
+                    mode = config_data['model']['model_config']['mode']
         if mode.startswith('paraformer'):
             from funasr.tasks.asr import ASRTaskParaformer as ASRTask
-        elif mode.startswith('uniasr'):
-            from funasr.tasks.asr import ASRTaskUniASR as ASRTask
+            config = os.path.join(model_dir, 'config.yaml')
+            model_file = os.path.join(model_dir, 'model.pb')
+            cmvn_file = os.path.join(model_dir, 'am.mvn')
+            model, asr_train_args = ASRTask.build_model_from_file(
+                config, model_file, cmvn_file, 'cpu'
+            )
+            self.frontend = model.frontend
+        elif mode.startswith('offline'):
+            from funasr.tasks.vad import VADTask
+            config = os.path.join(model_dir, 'vad.yaml')
+            model_file = os.path.join(model_dir, 'vad.pb')
+            cmvn_file = os.path.join(model_dir, 'vad.mvn')
             
-        model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, cmvn_file, 'cpu'
-        )
-        self.frontend = model.frontend
+            model, vad_infer_args = VADTask.build_model_from_file(
+                config, model_file, cmvn_file=cmvn_file, device='cpu'
+            )
+            self.export_config["feats_dim"] = 400
+            self.frontend = model.frontend
+        elif mode.startswith('punc'):
+            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
+            punc_train_config = os.path.join(model_dir, 'config.yaml')
+            punc_model_file = os.path.join(model_dir, 'punc.pb')
+            model, punc_train_args = PUNCTask.build_model_from_file(
+                punc_train_config, punc_model_file, 'cpu'
+            )
+        elif mode.startswith('punc_VadRealtime'):
+            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
+            punc_train_config = os.path.join(model_dir, 'config.yaml')
+            punc_model_file = os.path.join(model_dir, 'punc.pb')
+            model, punc_train_args = PUNCTask.build_model_from_file(
+                punc_train_config, punc_model_file, 'cpu'
+            )
         self._export(model, tag_name)
             
 
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index 0012377..f81ff64 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -1,13 +1,25 @@
 from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
 from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
 from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
-from funasr.models.e2e_uni_asr import UniASR
-
+from funasr.models.e2e_vad import E2EVadModel
+from funasr.export.models.e2e_vad import E2EVadModel as E2EVadModel_export
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.export.models.target_delay_transformer import CT_Transformer as CT_Transformer_export
+from funasr.train.abs_model import PunctuationModel
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.export.models.target_delay_transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
 
 def get_model(model, export_config=None):
     if isinstance(model, BiCifParaformer):
         return BiCifParaformer_export(model, **export_config)
     elif isinstance(model, Paraformer):
         return Paraformer_export(model, **export_config)
+    elif isinstance(model, E2EVadModel):
+        return E2EVadModel_export(model, **export_config)
+    elif isinstance(model, PunctuationModel):
+        if isinstance(model.punc_model, TargetDelayTransformer):
+            return CT_Transformer_export(model.punc_model, **export_config)
+        elif isinstance(model.punc_model, VadRealtimeTransformer):
+            return CT_Transformer_VadRealtime_export(model.punc_model, **export_config)
     else:
-        raise "Funasr does not support the given model type currently."
\ No newline at end of file
+        raise "Funasr does not support the given model type currently."
diff --git a/funasr/export/models/e2e_vad.py b/funasr/export/models/e2e_vad.py
new file mode 100644
index 0000000..d3e8f30
--- /dev/null
+++ b/funasr/export/models/e2e_vad.py
@@ -0,0 +1,60 @@
+from enum import Enum
+from typing import List, Tuple, Dict, Any
+
+import torch
+from torch import nn
+import math
+
+from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.export.models.encoder.fsmn_encoder import FSMN as FSMN_export
+
+class E2EVadModel(nn.Module):
+    def __init__(self, model,
+                max_seq_len=512,
+                feats_dim=400,
+                model_name='model',
+                **kwargs,):
+        super(E2EVadModel, self).__init__()
+        self.feats_dim = feats_dim
+        self.max_seq_len = max_seq_len
+        self.model_name = model_name
+        if isinstance(model.encoder, FSMN):
+            self.encoder = FSMN_export(model.encoder)
+        else:
+            raise "unsupported encoder"
+        
+
+    def forward(self, feats: torch.Tensor, *args, ):
+
+        scores, out_caches = self.encoder(feats, *args)
+        return scores, out_caches
+
+    def get_dummy_inputs(self, frame=30):
+        speech = torch.randn(1, frame, self.feats_dim)
+        in_cache0 = torch.randn(1, 128, 19, 1)
+        in_cache1 = torch.randn(1, 128, 19, 1)
+        in_cache2 = torch.randn(1, 128, 19, 1)
+        in_cache3 = torch.randn(1, 128, 19, 1)
+        
+        return (speech, in_cache0, in_cache1, in_cache2, in_cache3)
+
+    # def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
+    #     import numpy as np
+    #     fbank = np.loadtxt(txt_file)
+    #     fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
+    #     speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
+    #     speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
+    #     return (speech, speech_lengths)
+
+    def get_input_names(self):
+        return ['speech', 'in_cache0', 'in_cache1', 'in_cache2', 'in_cache3']
+
+    def get_output_names(self):
+        return ['logits', 'out_cache0', 'out_cache1', 'out_cache2', 'out_cache3']
+
+    def get_dynamic_axes(self):
+        return {
+            'speech': {
+                1: 'feats_length'
+            },
+        }
diff --git a/funasr/export/models/encoder/fsmn_encoder.py b/funasr/export/models/encoder/fsmn_encoder.py
new file mode 100755
index 0000000..b8e6433
--- /dev/null
+++ b/funasr/export/models/encoder/fsmn_encoder.py
@@ -0,0 +1,296 @@
+from typing import Tuple, Dict
+import copy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.encoder.fsmn_encoder import BasicBlock
+
+class LinearTransform(nn.Module):
+
+    def __init__(self, input_dim, output_dim):
+        super(LinearTransform, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.linear = nn.Linear(input_dim, output_dim, bias=False)
+
+    def forward(self, input):
+        output = self.linear(input)
+
+        return output
+
+
+class AffineTransform(nn.Module):
+
+    def __init__(self, input_dim, output_dim):
+        super(AffineTransform, self).__init__()
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.linear = nn.Linear(input_dim, output_dim)
+
+    def forward(self, input):
+        output = self.linear(input)
+
+        return output
+
+
+class RectifiedLinear(nn.Module):
+
+    def __init__(self, input_dim, output_dim):
+        super(RectifiedLinear, self).__init__()
+        self.dim = input_dim
+        self.relu = nn.ReLU()
+        self.dropout = nn.Dropout(0.1)
+
+    def forward(self, input):
+        out = self.relu(input)
+        return out
+
+
+class FSMNBlock(nn.Module):
+
+    def __init__(
+            self,
+            input_dim: int,
+            output_dim: int,
+            lorder=None,
+            rorder=None,
+            lstride=1,
+            rstride=1,
+    ):
+        super(FSMNBlock, self).__init__()
+
+        self.dim = input_dim
+
+        if lorder is None:
+            return
+
+        self.lorder = lorder
+        self.rorder = rorder
+        self.lstride = lstride
+        self.rstride = rstride
+
+        self.conv_left = nn.Conv2d(
+            self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
+
+        if self.rorder > 0:
+            self.conv_right = nn.Conv2d(
+                self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
+        else:
+            self.conv_right = None
+
+    def forward(self, input: torch.Tensor, cache: torch.Tensor):
+        x = torch.unsqueeze(input, 1)
+        x_per = x.permute(0, 3, 2, 1)  # B D T C
+        
+        cache = cache.to(x_per.device)
+        y_left = torch.cat((cache, x_per), dim=2)
+        cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
+        y_left = self.conv_left(y_left)
+        out = x_per + y_left
+
+        if self.conv_right is not None:
+            # maybe need to check
+            y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
+            y_right = y_right[:, :, self.rstride:, :]
+            y_right = self.conv_right(y_right)
+            out += y_right
+
+        out_per = out.permute(0, 3, 2, 1)
+        output = out_per.squeeze(1)
+
+        return output, cache
+
+
+class BasicBlock_export(nn.Module):
+    def __init__(self,
+                 model,
+                 ):
+        super(BasicBlock_export, self).__init__()
+        self.linear = model.linear
+        self.fsmn_block = model.fsmn_block
+        self.affine = model.affine
+        self.relu = model.relu
+
+    def forward(self, input: torch.Tensor, in_cache: torch.Tensor):
+        x = self.linear(input)  # B T D
+        # cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+        # if cache_layer_name not in in_cache:
+        #     in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+        x, out_cache = self.fsmn_block(x, in_cache)
+        x = self.affine(x)
+        x = self.relu(x)
+        return x, out_cache
+
+
+# class FsmnStack(nn.Sequential):
+#     def __init__(self, *args):
+#         super(FsmnStack, self).__init__(*args)
+#
+#     def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
+#         x = input
+#         for module in self._modules.values():
+#             x = module(x, in_cache)
+#         return x
+
+
+'''
+FSMN net for keyword spotting
+input_dim:              input dimension
+linear_dim:             fsmn input dimensionll
+proj_dim:               fsmn projection dimension
+lorder:                 fsmn left order
+rorder:                 fsmn right order
+num_syn:                output dimension
+fsmn_layers:            no. of sequential fsmn layers
+'''
+
+
+class FSMN(nn.Module):
+    def __init__(
+            self, model,
+    ):
+        super(FSMN, self).__init__()
+        
+        # self.input_dim = input_dim
+        # self.input_affine_dim = input_affine_dim
+        # self.fsmn_layers = fsmn_layers
+        # self.linear_dim = linear_dim
+        # self.proj_dim = proj_dim
+        # self.output_affine_dim = output_affine_dim
+        # self.output_dim = output_dim
+        #
+        # self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+        # self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+        # self.relu = RectifiedLinear(linear_dim, linear_dim)
+        # self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+        #                         range(fsmn_layers)])
+        # self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+        # self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+        # self.softmax = nn.Softmax(dim=-1)
+        self.in_linear1 = model.in_linear1
+        self.in_linear2 = model.in_linear2
+        self.relu = model.relu
+        # self.fsmn = model.fsmn
+        self.out_linear1 = model.out_linear1
+        self.out_linear2 = model.out_linear2
+        self.softmax = model.softmax
+        self.fsmn = model.fsmn
+        for i, d in enumerate(model.fsmn):
+            if isinstance(d, BasicBlock):
+                self.fsmn[i] = BasicBlock_export(d)
+
+    def fuse_modules(self):
+        pass
+
+    def forward(
+            self,
+            input: torch.Tensor,
+            *args,
+    ):
+        """
+        Args:
+            input (torch.Tensor): Input tensor (B, T, D)
+            in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs,
+            {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
+        """
+
+        x = self.in_linear1(input)
+        x = self.in_linear2(x)
+        x = self.relu(x)
+        # x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
+        out_caches = list()
+        for i, d in enumerate(self.fsmn):
+            in_cache = args[i]
+            x, out_cache = d(x, in_cache)
+            out_caches.append(out_cache)
+        x = self.out_linear1(x)
+        x = self.out_linear2(x)
+        x = self.softmax(x)
+
+        return x, out_caches
+
+
+'''
+one deep fsmn layer
+dimproj:                projection dimension, input and output dimension of memory blocks
+dimlinear:              dimension of mapping layer
+lorder:                 left order
+rorder:                 right order
+lstride:                left stride
+rstride:                right stride
+'''
+
+
+class DFSMN(nn.Module):
+
+    def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
+        super(DFSMN, self).__init__()
+
+        self.lorder = lorder
+        self.rorder = rorder
+        self.lstride = lstride
+        self.rstride = rstride
+
+        self.expand = AffineTransform(dimproj, dimlinear)
+        self.shrink = LinearTransform(dimlinear, dimproj)
+
+        self.conv_left = nn.Conv2d(
+            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
+
+        if rorder > 0:
+            self.conv_right = nn.Conv2d(
+                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
+        else:
+            self.conv_right = None
+
+    def forward(self, input):
+        f1 = F.relu(self.expand(input))
+        p1 = self.shrink(f1)
+
+        x = torch.unsqueeze(p1, 1)
+        x_per = x.permute(0, 3, 2, 1)
+
+        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
+
+        if self.conv_right is not None:
+            y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
+            y_right = y_right[:, :, self.rstride:, :]
+            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
+        else:
+            out = x_per + self.conv_left(y_left)
+
+        out1 = out.permute(0, 3, 2, 1)
+        output = input + out1.squeeze(1)
+
+        return output
+
+
+'''
+build stacked dfsmn layers
+'''
+
+
+def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
+    repeats = [
+        nn.Sequential(
+            DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
+        for i in range(fsmn_layers)
+    ]
+
+    return nn.Sequential(*repeats)
+
+
+if __name__ == '__main__':
+    fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
+    print(fsmn)
+
+    num_params = sum(p.numel() for p in fsmn.parameters())
+    print('the number of model params: {}'.format(num_params))
+    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
+    y, _ = fsmn(x)  # batch-size * time * dim
+    print('input shape: {}'.format(x.shape))
+    print('output shape: {}'.format(y.shape))
+
+    print(fsmn.to_kaldi_net())
diff --git a/funasr/export/models/encoder/sanm_encoder.py b/funasr/export/models/encoder/sanm_encoder.py
index 8a50538..f583f56 100644
--- a/funasr/export/models/encoder/sanm_encoder.py
+++ b/funasr/export/models/encoder/sanm_encoder.py
@@ -9,6 +9,7 @@
 from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
 from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
 
+
 class SANMEncoder(nn.Module):
     def __init__(
         self,
@@ -107,3 +108,106 @@
             }
 
         }
+
+
+class SANMVadEncoder(nn.Module):
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        feats_dim=560,
+        model_name='encoder',
+        onnx: bool = True,
+    ):
+        super().__init__()
+        self.embed = model.embed
+        self.model = model
+        self.feats_dim = feats_dim
+        self._output_size = model._output_size
+        
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+        
+        if hasattr(model, 'encoders0'):
+            for i, d in enumerate(self.model.encoders0):
+                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+                    d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
+                if isinstance(d.feed_forward, PositionwiseFeedForward):
+                    d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
+                self.model.encoders0[i] = EncoderLayerSANM_export(d)
+        
+        for i, d in enumerate(self.model.encoders):
+            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+                d.self_attn = MultiHeadedAttentionSANM_export(d.self_attn)
+            if isinstance(d.feed_forward, PositionwiseFeedForward):
+                d.feed_forward = PositionwiseFeedForward_export(d.feed_forward)
+            self.model.encoders[i] = EncoderLayerSANM_export(d)
+        
+        self.model_name = model_name
+        self.num_heads = model.encoders[0].self_attn.h
+        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
+    
+    def prepare_mask(self, mask, sub_masks):
+        mask_3d_btd = mask[:, :, None]
+        mask_4d_bhlt = (1 - sub_masks) * -10000.0
+        
+        return mask_3d_btd, mask_4d_bhlt
+    
+    def forward(self,
+                speech: torch.Tensor,
+                speech_lengths: torch.Tensor,
+                vad_masks: torch.Tensor,
+                sub_masks: torch.Tensor,
+                ):
+        speech = speech * self._output_size ** 0.5
+        mask = self.make_pad_mask(speech_lengths)
+        vad_masks = self.prepare_mask(mask, vad_masks)
+        mask = self.prepare_mask(mask, sub_masks)
+        
+        if self.embed is None:
+            xs_pad = speech
+        else:
+            xs_pad = self.embed(speech)
+        
+        encoder_outs = self.model.encoders0(xs_pad, mask)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        
+        # encoder_outs = self.model.encoders(xs_pad, mask)
+        for layer_idx, encoder_layer in enumerate(self.model.encoders):
+            if layer_idx == len(self.model.encoders) - 1:
+                mask = vad_masks
+            encoder_outs = encoder_layer(xs_pad, mask)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        
+        xs_pad = self.model.after_norm(xs_pad)
+        
+        return xs_pad, speech_lengths
+    
+    def get_output_size(self):
+        return self.model.encoders[0].size
+    
+    # def get_dummy_inputs(self):
+    #     feats = torch.randn(1, 100, self.feats_dim)
+    #     return (feats)
+    #
+    # def get_input_names(self):
+    #     return ['feats']
+    #
+    # def get_output_names(self):
+    #     return ['encoder_out', 'encoder_out_lens', 'predictor_weight']
+    #
+    # def get_dynamic_axes(self):
+    #     return {
+    #         'feats': {
+    #             1: 'feats_length'
+    #         },
+    #         'encoder_out': {
+    #             1: 'enc_out_length'
+    #         },
+    #         'predictor_weight': {
+    #             1: 'pre_out_length'
+    #         }
+    #
+    #     }
diff --git a/funasr/export/models/target_delay_transformer.py b/funasr/export/models/target_delay_transformer.py
new file mode 100644
index 0000000..2780d82
--- /dev/null
+++ b/funasr/export/models/target_delay_transformer.py
@@ -0,0 +1,154 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from funasr.models.encoder.sanm_encoder import SANMEncoder
+from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
+from funasr.models.encoder.sanm_encoder import SANMVadEncoder
+from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadEncoder_export
+
+class CT_Transformer(nn.Module):
+
+    def __init__(
+            self,
+            model,
+            max_seq_len=512,
+            model_name='punc_model',
+            **kwargs,
+    ):
+        super().__init__()
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
+        self.embed = model.embed
+        self.decoder = model.decoder
+        # self.model = model
+        self.feats_dim = self.embed.embedding_dim
+        self.num_embeddings = self.embed.num_embeddings
+        self.model_name = model_name
+
+        if isinstance(model.encoder, SANMEncoder):
+            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
+        else:
+            assert False, "Only support samn encode."
+
+    def forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        """Compute loss value from buffer sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            hidden (torch.Tensor): Target ids. (batch, len)
+
+        """
+        x = self.embed(inputs)
+        # mask = self._target_mask(input)
+        h, _ = self.encoder(x, text_lengths)
+        y = self.decoder(h)
+        return y
+
+    def get_dummy_inputs(self):
+        length = 120
+        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
+        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
+        return (text_indexes, text_lengths)
+
+    def get_input_names(self):
+        return ['inputs', 'text_lengths']
+
+    def get_output_names(self):
+        return ['logits']
+
+    def get_dynamic_axes(self):
+        return {
+            'inputs': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'text_lengths': {
+                0: 'batch_size',
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+        }
+
+
+class CT_Transformer_VadRealtime(nn.Module):
+
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        model_name='punc_model',
+        **kwargs,
+    ):
+        super().__init__()
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
+
+        self.embed = model.embed
+        if isinstance(model.encoder, SANMVadEncoder):
+            self.encoder = SANMVadEncoder_export(model.encoder, onnx=onnx)
+        else:
+            assert False, "Only support samn encode."
+        self.decoder = model.decoder
+        self.model_name = model_name
+
+
+
+    def forward(self, inputs: torch.Tensor,
+                text_lengths: torch.Tensor,
+                vad_indexes: torch.Tensor,
+                sub_masks: torch.Tensor,
+                ) -> Tuple[torch.Tensor, None]:
+        """Compute loss value from buffer sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            hidden (torch.Tensor): Target ids. (batch, len)
+
+        """
+        x = self.embed(inputs)
+        # mask = self._target_mask(input)
+        h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
+        y = self.decoder(h)
+        return y
+
+    def with_vad(self):
+        return True
+
+    def get_dummy_inputs(self):
+        length = 120
+        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
+        text_lengths = torch.tensor([length], dtype=torch.int32)
+        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
+        sub_masks = torch.ones(length, length, dtype=torch.float32)
+        sub_masks = torch.tril(sub_masks).type(torch.float32)
+        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
+
+    def get_input_names(self):
+        return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
+
+    def get_output_names(self):
+        return ['logits']
+
+    def get_dynamic_axes(self):
+        return {
+            'inputs': {
+                1: 'feats_length'
+            },
+            'vad_masks': {
+                2: 'feats_length1',
+                3: 'feats_length2'
+            },
+            'sub_masks': {
+                2: 'feats_length1',
+                3: 'feats_length2'
+            },
+            'logits': {
+                1: 'logits_length'
+            },
+        }
diff --git a/funasr/punctuation/__init__.py b/funasr/export/test/__init__.py
similarity index 100%
rename from funasr/punctuation/__init__.py
rename to funasr/export/test/__init__.py
diff --git a/funasr/export/test_onnx.py b/funasr/export/test/test_onnx.py
similarity index 100%
rename from funasr/export/test_onnx.py
rename to funasr/export/test/test_onnx.py
diff --git a/funasr/export/test/test_onnx_punc.py b/funasr/export/test/test_onnx_punc.py
new file mode 100644
index 0000000..39f85f4
--- /dev/null
+++ b/funasr/export/test/test_onnx_punc.py
@@ -0,0 +1,18 @@
+import onnxruntime
+import numpy as np
+
+
+if __name__ == '__main__':
+    onnx_path = "../damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.onnx"
+    sess = onnxruntime.InferenceSession(onnx_path)
+    input_name = [nd.name for nd in sess.get_inputs()]
+    output_name = [nd.name for nd in sess.get_outputs()]
+
+    def _get_feed_dict(text_length):
+        return {'inputs': np.ones((1, text_length), dtype=np.int64), 'text_lengths': np.array([text_length,], dtype=np.int32)}
+
+    def _run(feed_dict):
+        output = sess.run(output_name, input_feed=feed_dict)
+        for name, value in zip(output_name, output):
+            print('{}: {}'.format(name, value))
+    _run(_get_feed_dict(10))
diff --git a/funasr/export/test/test_onnx_punc_vadrealtime.py b/funasr/export/test/test_onnx_punc_vadrealtime.py
new file mode 100644
index 0000000..86be026
--- /dev/null
+++ b/funasr/export/test/test_onnx_punc_vadrealtime.py
@@ -0,0 +1,22 @@
+import onnxruntime
+import numpy as np
+
+
+if __name__ == '__main__':
+    onnx_path = "./export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727/model.onnx"
+    sess = onnxruntime.InferenceSession(onnx_path)
+    input_name = [nd.name for nd in sess.get_inputs()]
+    output_name = [nd.name for nd in sess.get_outputs()]
+
+    def _get_feed_dict(text_length):
+        return {'inputs': np.ones((1, text_length), dtype=np.int64),
+                'text_lengths': np.array([text_length,], dtype=np.int32),
+                'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
+                'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
+                }
+
+    def _run(feed_dict):
+        output = sess.run(output_name, input_feed=feed_dict)
+        for name, value in zip(output_name, output):
+            print('{}: {}'.format(name, value))
+    _run(_get_feed_dict(10))
diff --git a/funasr/export/test/test_onnx_vad.py b/funasr/export/test/test_onnx_vad.py
new file mode 100644
index 0000000..12f058f
--- /dev/null
+++ b/funasr/export/test/test_onnx_vad.py
@@ -0,0 +1,26 @@
+import onnxruntime
+import numpy as np
+
+
+if __name__ == '__main__':
+    onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"
+    sess = onnxruntime.InferenceSession(onnx_path)
+    input_name = [nd.name for nd in sess.get_inputs()]
+    output_name = [nd.name for nd in sess.get_outputs()]
+
+    def _get_feed_dict(feats_length):
+        
+        return {'speech': np.random.rand(1, feats_length, 400).astype(np.float32),
+                'in_cache0': np.random.rand(1, 128, 19, 1).astype(np.float32),
+                'in_cache1': np.random.rand(1, 128, 19, 1).astype(np.float32),
+                'in_cache2': np.random.rand(1, 128, 19, 1).astype(np.float32),
+                'in_cache3': np.random.rand(1, 128, 19, 1).astype(np.float32),
+                }
+
+    def _run(feed_dict):
+        output = sess.run(output_name, input_feed=feed_dict)
+        for name, value in zip(output_name, output):
+            print('{}: {}'.format(name, value.shape))
+
+    _run(_get_feed_dict(100))
+    _run(_get_feed_dict(200))
\ No newline at end of file
diff --git a/funasr/export/test_torchscripts.py b/funasr/export/test/test_torchscripts.py
similarity index 100%
rename from funasr/export/test_torchscripts.py
rename to funasr/export/test/test_torchscripts.py
diff --git a/funasr/lm/abs_model.py b/funasr/lm/abs_model.py
index 0ad1e71..1f3c8a7 100644
--- a/funasr/lm/abs_model.py
+++ b/funasr/lm/abs_model.py
@@ -5,7 +5,17 @@
 import torch
 
 from funasr.modules.scorers.scorer_interface import BatchScorerInterface
+from typing import Dict
+from typing import Optional
+from typing import Tuple
 
+import torch
+import torch.nn.functional as F
+from typeguard import check_argument_types
+
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
 
 class AbsLM(torch.nn.Module, BatchScorerInterface, ABC):
     """The abstract LM class
@@ -27,3 +37,122 @@
         self, input: torch.Tensor, hidden: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
+
+
+class LanguageModel(AbsESPnetModel):
+    def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
+        assert check_argument_types()
+        super().__init__()
+        self.lm = lm
+        self.sos = 1
+        self.eos = 2
+
+        # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
+        self.ignore_id = ignore_id
+
+    def nll(
+        self,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        max_length: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute negative log likelihood(nll)
+
+        Normally, this function is called in batchify_nll.
+        Args:
+            text: (Batch, Length)
+            text_lengths: (Batch,)
+            max_lengths: int
+        """
+        batch_size = text.size(0)
+        # For data parallel
+        if max_length is None:
+            text = text[:, : text_lengths.max()]
+        else:
+            text = text[:, :max_length]
+
+        # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
+        # text: (Batch, Length) -> x, y: (Batch, Length + 1)
+        x = F.pad(text, [1, 0], "constant", self.sos)
+        t = F.pad(text, [0, 1], "constant", self.ignore_id)
+        for i, l in enumerate(text_lengths):
+            t[i, l] = self.eos
+        x_lengths = text_lengths + 1
+
+        # 2. Forward Language model
+        # x: (Batch, Length) -> y: (Batch, Length, NVocab)
+        y, _ = self.lm(x, None)
+
+        # 3. Calc negative log likelihood
+        # nll: (BxL,)
+        nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
+        # nll: (BxL,) -> (BxL,)
+        if max_length is None:
+            nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
+        else:
+            nll.masked_fill_(
+                make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
+                0.0,
+            )
+        # nll: (BxL,) -> (B, L)
+        nll = nll.view(batch_size, -1)
+        return nll, x_lengths
+
+    def batchify_nll(
+        self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute negative log likelihood(nll) from transformer language model
+
+        To avoid OOM, this fuction seperate the input into batches.
+        Then call nll for each batch and combine and return results.
+        Args:
+            text: (Batch, Length)
+            text_lengths: (Batch,)
+            batch_size: int, samples each batch contain when computing nll,
+                        you may change this to avoid OOM or increase
+
+        """
+        total_num = text.size(0)
+        if total_num <= batch_size:
+            nll, x_lengths = self.nll(text, text_lengths)
+        else:
+            nlls = []
+            x_lengths = []
+            max_length = text_lengths.max()
+
+            start_idx = 0
+            while True:
+                end_idx = min(start_idx + batch_size, total_num)
+                batch_text = text[start_idx:end_idx, :]
+                batch_text_lengths = text_lengths[start_idx:end_idx]
+                # batch_nll: [B * T]
+                batch_nll, batch_x_lengths = self.nll(
+                    batch_text, batch_text_lengths, max_length=max_length
+                )
+                nlls.append(batch_nll)
+                x_lengths.append(batch_x_lengths)
+                start_idx = end_idx
+                if start_idx == total_num:
+                    break
+            nll = torch.cat(nlls)
+            x_lengths = torch.cat(x_lengths)
+        assert nll.size(0) == total_num
+        assert x_lengths.size(0) == total_num
+        return nll, x_lengths
+
+    def forward(
+        self, text: torch.Tensor, text_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        nll, y_lengths = self.nll(text, text_lengths)
+        ntokens = y_lengths.sum()
+        loss = nll.sum() / ntokens
+        stats = dict(loss=loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
+        return loss, stats, weight
+
+    def collect_feats(
+        self, text: torch.Tensor, text_lengths: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
+        return {}
diff --git a/funasr/lm/espnet_model.py b/funasr/lm/espnet_model.py
deleted file mode 100644
index db11b67..0000000
--- a/funasr/lm/espnet_model.py
+++ /dev/null
@@ -1,131 +0,0 @@
-from typing import Dict
-from typing import Optional
-from typing import Tuple
-
-import torch
-import torch.nn.functional as F
-from typeguard import check_argument_types
-
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.lm.abs_model import AbsLM
-from funasr.torch_utils.device_funcs import force_gatherable
-from funasr.train.abs_espnet_model import AbsESPnetModel
-
-
-class ESPnetLanguageModel(AbsESPnetModel):
-    def __init__(self, lm: AbsLM, vocab_size: int, ignore_id: int = 0):
-        assert check_argument_types()
-        super().__init__()
-        self.lm = lm
-        self.sos = 1
-        self.eos = 2
-
-        # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
-        self.ignore_id = ignore_id
-
-    def nll(
-        self,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
-        max_length: Optional[int] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Compute negative log likelihood(nll)
-
-        Normally, this function is called in batchify_nll.
-        Args:
-            text: (Batch, Length)
-            text_lengths: (Batch,)
-            max_lengths: int
-        """
-        batch_size = text.size(0)
-        # For data parallel
-        if max_length is None:
-            text = text[:, : text_lengths.max()]
-        else:
-            text = text[:, :max_length]
-
-        # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
-        # text: (Batch, Length) -> x, y: (Batch, Length + 1)
-        x = F.pad(text, [1, 0], "constant", self.sos)
-        t = F.pad(text, [0, 1], "constant", self.ignore_id)
-        for i, l in enumerate(text_lengths):
-            t[i, l] = self.eos
-        x_lengths = text_lengths + 1
-
-        # 2. Forward Language model
-        # x: (Batch, Length) -> y: (Batch, Length, NVocab)
-        y, _ = self.lm(x, None)
-
-        # 3. Calc negative log likelihood
-        # nll: (BxL,)
-        nll = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
-        # nll: (BxL,) -> (BxL,)
-        if max_length is None:
-            nll.masked_fill_(make_pad_mask(x_lengths).to(nll.device).view(-1), 0.0)
-        else:
-            nll.masked_fill_(
-                make_pad_mask(x_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
-                0.0,
-            )
-        # nll: (BxL,) -> (B, L)
-        nll = nll.view(batch_size, -1)
-        return nll, x_lengths
-
-    def batchify_nll(
-        self, text: torch.Tensor, text_lengths: torch.Tensor, batch_size: int = 100
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Compute negative log likelihood(nll) from transformer language model
-
-        To avoid OOM, this fuction seperate the input into batches.
-        Then call nll for each batch and combine and return results.
-        Args:
-            text: (Batch, Length)
-            text_lengths: (Batch,)
-            batch_size: int, samples each batch contain when computing nll,
-                        you may change this to avoid OOM or increase
-
-        """
-        total_num = text.size(0)
-        if total_num <= batch_size:
-            nll, x_lengths = self.nll(text, text_lengths)
-        else:
-            nlls = []
-            x_lengths = []
-            max_length = text_lengths.max()
-
-            start_idx = 0
-            while True:
-                end_idx = min(start_idx + batch_size, total_num)
-                batch_text = text[start_idx:end_idx, :]
-                batch_text_lengths = text_lengths[start_idx:end_idx]
-                # batch_nll: [B * T]
-                batch_nll, batch_x_lengths = self.nll(
-                    batch_text, batch_text_lengths, max_length=max_length
-                )
-                nlls.append(batch_nll)
-                x_lengths.append(batch_x_lengths)
-                start_idx = end_idx
-                if start_idx == total_num:
-                    break
-            nll = torch.cat(nlls)
-            x_lengths = torch.cat(x_lengths)
-        assert nll.size(0) == total_num
-        assert x_lengths.size(0) == total_num
-        return nll, x_lengths
-
-    def forward(
-        self, text: torch.Tensor, text_lengths: torch.Tensor
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        nll, y_lengths = self.nll(text, text_lengths)
-        ntokens = y_lengths.sum()
-        loss = nll.sum() / ntokens
-        stats = dict(loss=loss.detach())
-
-        # force_gatherable: to-device and to-tensor if scalar for DataParallel
-        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
-        return loss, stats, weight
-
-    def collect_feats(
-        self, text: torch.Tensor, text_lengths: torch.Tensor
-    ) -> Dict[str, torch.Tensor]:
-        return {}
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index e6cd7c0..ff37429 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -192,7 +192,7 @@
 
 
 class E2EVadModel(nn.Module):
-    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
+    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
         super(E2EVadModel, self).__init__()
         self.vad_opts = VADXOptions(**vad_post_args)
         self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@@ -229,6 +229,7 @@
         self.data_buf_all = None
         self.waveform = None
         self.ResetDetection()
+        self.frontend = frontend
 
     def AllResetDetection(self):
         self.is_final = False
@@ -477,8 +478,9 @@
                 ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
         self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
         self.waveform = waveform  # compute decibel for each frame
-        self.ComputeDecibel()
+        
         self.ComputeScores(feats, in_cache)
+        self.ComputeDecibel()
         if not is_final:
             self.DetectCommonFrames()
         else:
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 57890ef..2a3a353 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -10,7 +10,7 @@
 from typeguard import check_argument_types
 import numpy as np
 from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
+from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
 from funasr.modules.embedding import SinusoidalPositionEncoder
 from funasr.modules.layer_norm import LayerNorm
 from funasr.modules.multi_layer_conv import Conv1dLinear
@@ -27,7 +27,7 @@
 from funasr.modules.subsampling import check_short_utt
 from funasr.models.ctc import CTC
 from funasr.models.encoder.abs_encoder import AbsEncoder
-
+from funasr.modules.mask import subsequent_mask, vad_mask
 
 class EncoderLayerSANM(nn.Module):
     def __init__(
@@ -958,3 +958,231 @@
                                                                                       var_dict_tf[name_tf].shape))
     
         return var_dict_torch_update
+
+
+class SANMVadEncoder(AbsEncoder):
+    """
+    author: Speech Lab, Alibaba Group, China
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
+        kernel_size : int = 11,
+        sanm_shfit : int = 0,
+        selfattention_layer_type: str = "sanm",
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self._output_size = output_size
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                SinusoidalPositionEncoder(),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+
+        elif selfattention_layer_type == "sanm":
+            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.encoders = repeat(
+            num_blocks-1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        vad_indexes: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
+        no_future_masks = masks & sub_masks
+        xs_pad *= self.output_size()**0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
+              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
+                    f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        # xs_pad = self.dropout(xs_pad)
+        mask_tup0 = [masks, no_future_masks]
+        encoder_outs = self.encoders0(xs_pad, mask_tup0)
+        xs_pad, _ = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+
+
+        for layer_idx, encoder_layer in enumerate(self.encoders):
+                if layer_idx + 1 == len(self.encoders):
+                    # This is last layer.
+                    coner_mask = torch.ones(masks.size(0),
+                                            masks.size(-1),
+                                            masks.size(-1),
+                                            device=xs_pad.device,
+                                            dtype=torch.bool)
+                    for word_index, length in enumerate(ilens):
+                        coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
+                                                                vad_indexes[word_index],
+                                                                device=xs_pad.device)
+                    layer_mask = masks & coner_mask
+                else:
+                    layer_mask = no_future_masks
+                mask_tup1 = [masks, layer_mask]
+                encoder_outs = encoder_layer(xs_pad, mask_tup1)
+                xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None
diff --git a/funasr/punctuation/target_delay_transformer.py b/funasr/models/target_delay_transformer.py
similarity index 95%
rename from funasr/punctuation/target_delay_transformer.py
rename to funasr/models/target_delay_transformer.py
index 219af26..84a2e6c 100644
--- a/funasr/punctuation/target_delay_transformer.py
+++ b/funasr/models/target_delay_transformer.py
@@ -5,12 +5,11 @@
 import torch
 import torch.nn as nn
 
-from funasr.modules.embedding import PositionalEncoding
 from funasr.modules.embedding import SinusoidalPositionEncoder
 #from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
-from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
 #from funasr.modules.mask import subsequent_n_mask
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.train.abs_model import AbsPunctuation
 
 
 class TargetDelayTransformer(AbsPunctuation):
diff --git a/funasr/punctuation/vad_realtime_transformer.py b/funasr/models/vad_realtime_transformer.py
similarity index 96%
rename from funasr/punctuation/vad_realtime_transformer.py
rename to funasr/models/vad_realtime_transformer.py
index 35224f9..66f7fad 100644
--- a/funasr/punctuation/vad_realtime_transformer.py
+++ b/funasr/models/vad_realtime_transformer.py
@@ -6,8 +6,8 @@
 import torch.nn as nn
 
 from funasr.modules.embedding import SinusoidalPositionEncoder
-from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
-from funasr.punctuation.abs_model import AbsPunctuation
+from funasr.models.encoder.sanm_encoder import SANMVadEncoder as Encoder
+from funasr.train.abs_model import AbsPunctuation
 
 
 class VadRealtimeTransformer(AbsPunctuation):
diff --git a/funasr/punctuation/abs_model.py b/funasr/punctuation/abs_model.py
deleted file mode 100644
index 404d5e8..0000000
--- a/funasr/punctuation/abs_model.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Tuple
-
-import torch
-
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-
-
-class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
-    """The abstract class
-
-    To share the loss calculation way among different models,
-    We uses delegate pattern here:
-    The instance of this class should be passed to "LanguageModel"
-
-    >>> from funasr.punctuation.abs_model import AbsPunctuation
-    >>> punc = AbsPunctuation()
-    >>> model = ESPnetPunctuationModel(punc=punc)
-
-    This "model" is one of mediator objects for "Task" class.
-
-    """
-
-    @abstractmethod
-    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def with_vad(self) -> bool:
-        raise NotImplementedError
diff --git a/funasr/punctuation/sanm_encoder.py b/funasr/punctuation/sanm_encoder.py
deleted file mode 100644
index 8962093..0000000
--- a/funasr/punctuation/sanm_encoder.py
+++ /dev/null
@@ -1,590 +0,0 @@
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-import logging
-import torch
-import torch.nn as nn
-from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
-import numpy as np
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
-from funasr.modules.embedding import SinusoidalPositionEncoder
-from funasr.modules.layer_norm import LayerNorm
-from funasr.modules.multi_layer_conv import Conv1dLinear
-from funasr.modules.multi_layer_conv import MultiLayeredConv1d
-from funasr.modules.positionwise_feed_forward import (
-    PositionwiseFeedForward,  # noqa: H301
-)
-from funasr.modules.repeat import repeat
-from funasr.modules.subsampling import Conv2dSubsampling
-from funasr.modules.subsampling import Conv2dSubsampling2
-from funasr.modules.subsampling import Conv2dSubsampling6
-from funasr.modules.subsampling import Conv2dSubsampling8
-from funasr.modules.subsampling import TooShortUttError
-from funasr.modules.subsampling import check_short_utt
-from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
-
-from funasr.modules.nets_utils import make_pad_mask
-from funasr.modules.mask import subsequent_mask, vad_mask
-
-class EncoderLayerSANM(nn.Module):
-    def __init__(
-        self,
-        in_size,
-        size,
-        self_attn,
-        feed_forward,
-        dropout_rate,
-        normalize_before=True,
-        concat_after=False,
-        stochastic_depth_rate=0.0,
-    ):
-        """Construct an EncoderLayer object."""
-        super(EncoderLayerSANM, self).__init__()
-        self.self_attn = self_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(in_size)
-        self.norm2 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.in_size = in_size
-        self.size = size
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        if self.concat_after:
-            self.concat_linear = nn.Linear(size + size, size)
-        self.stochastic_depth_rate = stochastic_depth_rate
-        self.dropout_rate = dropout_rate
-
-    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
-        """Compute encoded features.
-
-        Args:
-            x_input (torch.Tensor): Input tensor (#batch, time, size).
-            mask (torch.Tensor): Mask tensor for the input (#batch, time).
-            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, size).
-            torch.Tensor: Mask tensor (#batch, time).
-
-        """
-        skip_layer = False
-        # with stochastic depth, residual connection `x + f(x)` becomes
-        # `x <- x + 1 / (1 - p) * f(x)` at training time.
-        stoch_layer_coeff = 1.0
-        if self.training and self.stochastic_depth_rate > 0:
-            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
-            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
-        if skip_layer:
-            if cache is not None:
-                x = torch.cat([cache, x], dim=1)
-            return x, mask
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm1(x)
-
-        if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
-            if self.in_size == self.size:
-                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
-            else:
-                x = stoch_layer_coeff * self.concat_linear(x_concat)
-        else:
-            if self.in_size == self.size:
-                x = residual + stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
-                )
-            else:
-                x = stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
-                )
-        if not self.normalize_before:
-            x = self.norm1(x)
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm2(x)
-        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm2(x)
-
-
-        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
-
-class SANMEncoder(AbsEncoder):
-    """
-    author: Speech Lab, Alibaba Group, China
-
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        input_layer: Optional[str] = "conv2d",
-        pos_enc_class=SinusoidalPositionEncoder,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        positionwise_layer_type: str = "linear",
-        positionwise_conv_kernel_size: int = 1,
-        padding_idx: int = -1,
-        interctc_layer_idx: List[int] = [],
-        interctc_use_conditioning: bool = False,
-        kernel_size : int = 11,
-        sanm_shfit : int = 0,
-        selfattention_layer_type: str = "sanm",
-    ):
-        assert check_argument_types()
-        super().__init__()
-        self._output_size = output_size
-
-        if input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(input_size, output_size),
-                torch.nn.LayerNorm(output_size),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2d":
-            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d2":
-            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d6":
-            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d8":
-            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
-        elif input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
-                SinusoidalPositionEncoder(),
-            )
-        elif input_layer is None:
-            if input_size == output_size:
-                self.embed = None
-            else:
-                self.embed = torch.nn.Linear(input_size, output_size)
-        elif input_layer == "pe":
-            self.embed = SinusoidalPositionEncoder()
-        else:
-            raise ValueError("unknown input_layer: " + input_layer)
-        self.normalize_before = normalize_before
-        if positionwise_layer_type == "linear":
-            positionwise_layer = PositionwiseFeedForward
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d":
-            positionwise_layer = MultiLayeredConv1d
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d-linear":
-            positionwise_layer = Conv1dLinear
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        else:
-            raise NotImplementedError("Support only linear or conv1d.")
-
-        if selfattention_layer_type == "selfattn":
-            encoder_selfattn_layer = MultiHeadedAttention
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                attention_dropout_rate,
-            )
-
-        elif selfattention_layer_type == "sanm":
-            self.encoder_selfattn_layer = MultiHeadedAttentionSANM
-            encoder_selfattn_layer_args0 = (
-                attention_heads,
-                input_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-
-        self.encoders0 = repeat(
-            1,
-            lambda lnum: EncoderLayerSANM(
-                input_size,
-                output_size,
-                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-
-        self.encoders = repeat(
-            num_blocks-1,
-            lambda lnum: EncoderLayerSANM(
-                output_size,
-                output_size,
-                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-
-        self.interctc_layer_idx = interctc_layer_idx
-        if len(interctc_layer_idx) > 0:
-            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
-        self.interctc_use_conditioning = interctc_use_conditioning
-        self.conditioning_layer = None
-        self.dropout = nn.Dropout(dropout_rate)
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        prev_states: torch.Tensor = None,
-        ctc: CTC = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Embed positions in tensor.
-
-        Args:
-            xs_pad: input tensor (B, L, D)
-            ilens: input length (B)
-            prev_states: Not to be used now.
-        Returns:
-            position embedded tensor and mask
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad *= self.output_size()**0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        elif (
-            isinstance(self.embed, Conv2dSubsampling)
-            or isinstance(self.embed, Conv2dSubsampling2)
-            or isinstance(self.embed, Conv2dSubsampling6)
-            or isinstance(self.embed, Conv2dSubsampling8)
-        ):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
-                    + f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-
-        # xs_pad = self.dropout(xs_pad)
-        encoder_outs = self.encoders0(xs_pad, masks)
-        xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        intermediate_outs = []
-        if len(self.interctc_layer_idx) == 0:
-            encoder_outs = self.encoders(xs_pad, masks)
-            xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        else:
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                encoder_outs = encoder_layer(xs_pad, masks)
-                xs_pad, masks = encoder_outs[0], encoder_outs[1]
-
-                if layer_idx + 1 in self.interctc_layer_idx:
-                    encoder_out = xs_pad
-
-                    # intermediate outputs are also normalized
-                    if self.normalize_before:
-                        encoder_out = self.after_norm(encoder_out)
-
-                    intermediate_outs.append((layer_idx + 1, encoder_out))
-
-                    if self.interctc_use_conditioning:
-                        ctc_out = ctc.softmax(encoder_out)
-                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        olens = masks.squeeze(1).sum(1)
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), olens, None
-        return xs_pad, olens, None
-
-class SANMVadEncoder(AbsEncoder):
-    """
-    author: Speech Lab, Alibaba Group, China
-
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        input_layer: Optional[str] = "conv2d",
-        pos_enc_class=SinusoidalPositionEncoder,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        positionwise_layer_type: str = "linear",
-        positionwise_conv_kernel_size: int = 1,
-        padding_idx: int = -1,
-        interctc_layer_idx: List[int] = [],
-        interctc_use_conditioning: bool = False,
-        kernel_size : int = 11,
-        sanm_shfit : int = 0,
-        selfattention_layer_type: str = "sanm",
-    ):
-        assert check_argument_types()
-        super().__init__()
-        self._output_size = output_size
-
-        if input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(input_size, output_size),
-                torch.nn.LayerNorm(output_size),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2d":
-            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d2":
-            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d6":
-            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d8":
-            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
-        elif input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
-                SinusoidalPositionEncoder(),
-            )
-        elif input_layer is None:
-            if input_size == output_size:
-                self.embed = None
-            else:
-                self.embed = torch.nn.Linear(input_size, output_size)
-        elif input_layer == "pe":
-            self.embed = SinusoidalPositionEncoder()
-        else:
-            raise ValueError("unknown input_layer: " + input_layer)
-        self.normalize_before = normalize_before
-        if positionwise_layer_type == "linear":
-            positionwise_layer = PositionwiseFeedForward
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d":
-            positionwise_layer = MultiLayeredConv1d
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d-linear":
-            positionwise_layer = Conv1dLinear
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        else:
-            raise NotImplementedError("Support only linear or conv1d.")
-
-        if selfattention_layer_type == "selfattn":
-            encoder_selfattn_layer = MultiHeadedAttention
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                attention_dropout_rate,
-            )
-
-        elif selfattention_layer_type == "sanm":
-            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
-            encoder_selfattn_layer_args0 = (
-                attention_heads,
-                input_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-
-        self.encoders0 = repeat(
-            1,
-            lambda lnum: EncoderLayerSANM(
-                input_size,
-                output_size,
-                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-
-        self.encoders = repeat(
-            num_blocks-1,
-            lambda lnum: EncoderLayerSANM(
-                output_size,
-                output_size,
-                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-
-        self.interctc_layer_idx = interctc_layer_idx
-        if len(interctc_layer_idx) > 0:
-            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
-        self.interctc_use_conditioning = interctc_use_conditioning
-        self.conditioning_layer = None
-        self.dropout = nn.Dropout(dropout_rate)
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        vad_indexes: torch.Tensor,
-        prev_states: torch.Tensor = None,
-        ctc: CTC = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Embed positions in tensor.
-
-        Args:
-            xs_pad: input tensor (B, L, D)
-            ilens: input length (B)
-            prev_states: Not to be used now.
-        Returns:
-            position embedded tensor and mask
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
-        no_future_masks = masks & sub_masks
-        xs_pad *= self.output_size()**0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
-              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
-                    f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-
-        # xs_pad = self.dropout(xs_pad)
-        mask_tup0 = [masks, no_future_masks]
-        encoder_outs = self.encoders0(xs_pad, mask_tup0)
-        xs_pad, _ = encoder_outs[0], encoder_outs[1]
-        intermediate_outs = []
-        #if len(self.interctc_layer_idx) == 0:
-        if False:
-            # Here, we should not use the repeat operation to do it for all layers.
-            encoder_outs = self.encoders(xs_pad, masks)
-            xs_pad, masks = encoder_outs[0], encoder_outs[1]
-        else:
-            for layer_idx, encoder_layer in enumerate(self.encoders):
-                if layer_idx + 1 == len(self.encoders):
-                    # This is last layer.
-                    coner_mask = torch.ones(masks.size(0),
-                                            masks.size(-1),
-                                            masks.size(-1),
-                                            device=xs_pad.device,
-                                            dtype=torch.bool)
-                    for word_index, length in enumerate(ilens):
-                        coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
-                                                                vad_indexes[word_index],
-                                                                device=xs_pad.device)
-                    layer_mask = masks & coner_mask
-                else:
-                    layer_mask = no_future_masks
-                mask_tup1 = [masks, layer_mask]
-                encoder_outs = encoder_layer(xs_pad, mask_tup1)
-                xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
-
-                if layer_idx + 1 in self.interctc_layer_idx:
-                    encoder_out = xs_pad
-
-                    # intermediate outputs are also normalized
-                    if self.normalize_before:
-                        encoder_out = self.after_norm(encoder_out)
-
-                    intermediate_outs.append((layer_idx + 1, encoder_out))
-
-                    if self.interctc_use_conditioning:
-                        ctc_out = ctc.softmax(encoder_out)
-                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        olens = masks.squeeze(1).sum(1)
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), olens, None
-        return xs_pad, olens, None
-
diff --git a/funasr/punctuation/text_preprocessor.py b/funasr/punctuation/text_preprocessor.py
deleted file mode 100644
index c9c4bac..0000000
--- a/funasr/punctuation/text_preprocessor.py
+++ /dev/null
@@ -1,12 +0,0 @@
-def split_to_mini_sentence(words: list, word_limit: int = 20):
-    assert word_limit > 1
-    if len(words) <= word_limit:
-        return [words]
-    sentences = []
-    length = len(words)
-    sentence_len = length // word_limit
-    for i in range(sentence_len):
-        sentences.append(words[i * word_limit:(i + 1) * word_limit])
-    if length % word_limit > 0:
-        sentences.append(words[sentence_len * word_limit:])
-    return sentences
diff --git a/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py b/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
index 2f09de8..cafc43b 100644
--- a/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
+++ b/funasr/runtime/python/libtorch/funasr_torch/utils/utils.py
@@ -134,7 +134,7 @@
 
 
 @functools.lru_cache()
-def get_logger(name='torch_paraformer'):
+def get_logger(name='funasr_torch'):
     """Initialize and get a logger by name.
     If the logger has not been initialized, this method will initialize the
     logger by adding one or two handlers, otherwise the initialized logger will
diff --git a/funasr/runtime/python/onnxruntime/demo.py b/funasr/runtime/python/onnxruntime/demo.py
index f0f39d7..8fc82f1 100644
--- a/funasr/runtime/python/onnxruntime/demo.py
+++ b/funasr/runtime/python/onnxruntime/demo.py
@@ -1,5 +1,6 @@
 from funasr_onnx import Paraformer
 
+
 model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 
 model = Paraformer(model_dir, batch_size=2, plot_timestamp_to="./", pred_bias=0)  # cpu
@@ -7,7 +8,6 @@
 
 # when using paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to get figure of alignment besides timestamps
 # model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png")
-
 
 wav_path = "YourPath/xx.wav"
 
diff --git a/funasr/runtime/python/onnxruntime/demo_punc_offline.py b/funasr/runtime/python/onnxruntime/demo_punc_offline.py
new file mode 100644
index 0000000..469adda
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_punc_offline.py
@@ -0,0 +1,8 @@
+from funasr_onnx import CT_Transformer
+
+model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+model = CT_Transformer(model_dir)
+
+text_in="璺ㄥ娌虫祦鏄吇鑲叉部宀镐汉姘戠殑鐢熷懡涔嬫簮闀挎湡浠ユ潵涓哄府鍔╀笅娓稿湴鍖洪槻鐏惧噺鐏句腑鏂规妧鏈汉鍛樺湪涓婃父鍦板尯鏋佷负鎭跺姡鐨勮嚜鐒舵潯浠朵笅鍏嬫湇宸ㄥぇ鍥伴毦鐢氳嚦鍐掔潃鐢熷懡鍗遍櫓鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦闂涓婄殑鍏冲垏鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒跺嚒鏄腑鏂硅兘鍋氱殑鎴戜滑閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛﹁鍒掑拰璁鸿瘉鍏奸【涓婁笅娓哥殑鍒╃泭"
+result = model(text_in)
+print(result[0])
diff --git a/funasr/runtime/python/onnxruntime/demo_punc_online.py b/funasr/runtime/python/onnxruntime/demo_punc_online.py
new file mode 100644
index 0000000..63f2f5e
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_punc_online.py
@@ -0,0 +1,15 @@
+from funasr_onnx import CT_Transformer_VadRealtime
+
+model_dir = "../../../export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+model = CT_Transformer_VadRealtime(model_dir)
+
+text_in  = "璺ㄥ娌虫祦鏄吇鑲叉部宀竱浜烘皯鐨勭敓鍛戒箣婧愰暱鏈熶互鏉ヤ负甯姪涓嬫父鍦板尯闃茬伨鍑忕伨涓柟鎶�鏈汉鍛榺鍦ㄤ笂娓稿湴鍖烘瀬涓烘伓鍔g殑鑷劧鏉′欢涓嬪厠鏈嶅法澶у洶闅剧敋鑷冲啋鐫�鐢熷懡鍗遍櫓|鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦>闂涓婄殑鍏冲垏|鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒秥鍑℃槸|涓柟鑳藉仛鐨勬垜浠瑋閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑|浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛瑙勫垝鍜岃璇佸吋椤句笂涓嬫父鐨勫埄鐩�"
+
+vads = text_in.split("|")
+rec_result_all=""
+param_dict = {"cache": []}
+for vad in vads:
+    result = model(vad, param_dict=param_dict)
+    rec_result_all += result[0]
+
+print(rec_result_all)
diff --git a/funasr/runtime/python/onnxruntime/demo_vad.py b/funasr/runtime/python/onnxruntime/demo_vad.py
new file mode 100644
index 0000000..2e17197
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_vad.py
@@ -0,0 +1,30 @@
+import soundfile
+from funasr_onnx import Fsmn_vad
+
+
+model_dir = "/Users/zhifu/Downloads/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+wav_path = "/Users/zhifu/Downloads/speech_fsmn_vad_zh-cn-16k-common-pytorch/example/vad_example.wav"
+model = Fsmn_vad(model_dir)
+
+#offline vad
+# result = model(wav_path)
+# print(result)
+
+#online vad
+speech, sample_rate = soundfile.read(wav_path)
+speech_length = speech.shape[0]
+
+sample_offset = 0
+step = 160 * 10
+param_dict = {'in_cache': []}
+for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
+    if sample_offset + step >= speech_length - 1:
+        step = speech_length - sample_offset
+        is_final = True
+    else:
+        is_final = False
+    param_dict['is_final'] = is_final
+    segments_result = model(audio_in=speech[sample_offset: sample_offset + step],
+                            param_dict=param_dict)
+    print(segments_result)
+
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
index 647f9fa..86f0e8e 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,2 +1,5 @@
 # -*- encoding: utf-8 -*-
 from .paraformer_bin import Paraformer
+from .vad_bin import Fsmn_vad
+from .punc_bin import CT_Transformer
+from .punc_bin import CT_Transformer_VadRealtime
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
new file mode 100644
index 0000000..0dc728a
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -0,0 +1,249 @@
+# -*- encoding: utf-8 -*-
+
+import os.path
+from pathlib import Path
+from typing import List, Union, Tuple
+import numpy as np
+
+from .utils.utils import (ONNXRuntimeError,
+                          OrtInferSession, get_logger,
+                          read_yaml)
+from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
+logging = get_logger()
+
+
+class CT_Transformer():
+    def __init__(self, model_dir: Union[str, Path] = None,
+                 batch_size: int = 1,
+                 device_id: Union[str, int] = "-1",
+                 quantize: bool = False,
+                 intra_op_num_threads: int = 4
+                 ):
+
+        if not Path(model_dir).exists():
+            raise FileNotFoundError(f'{model_dir} does not exist.')
+
+        model_file = os.path.join(model_dir, 'model.onnx')
+        if quantize:
+            model_file = os.path.join(model_dir, 'model_quant.onnx')
+        config_file = os.path.join(model_dir, 'punc.yaml')
+        config = read_yaml(config_file)
+
+        self.converter = TokenIDConverter(config['token_list'])
+        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
+        self.batch_size = 1
+        self.punc_list = config['punc_list']
+        self.period = 0
+        for i in range(len(self.punc_list)):
+            if self.punc_list[i] == ",":
+                self.punc_list[i] = "锛�"
+            elif self.punc_list[i] == "?":
+                self.punc_list[i] = "锛�"
+            elif self.punc_list[i] == "銆�":
+                self.period = i
+
+    def __call__(self, text: Union[list, str], split_size=20):
+        split_text = code_mix_split_words(text)
+        split_text_id = self.converter.tokens2ids(split_text)
+        mini_sentences = split_to_mini_sentence(split_text, split_size)
+        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
+        assert len(mini_sentences) == len(mini_sentences_id)
+        cache_sent = []
+        cache_sent_id = []
+        new_mini_sentence = ""
+        new_mini_sentence_punc = []
+        cache_pop_trigger_limit = 200
+        for mini_sentence_i in range(len(mini_sentences)):
+            mini_sentence = mini_sentences[mini_sentence_i]
+            mini_sentence_id = mini_sentences_id[mini_sentence_i]
+            mini_sentence = cache_sent + mini_sentence
+            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
+            data = {
+                "text": mini_sentence_id[None,:],
+                "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
+            }
+            try:
+                outputs = self.infer(data['text'], data['text_lengths'])
+                y = outputs[0]
+                punctuations = np.argmax(y,axis=-1)[0]
+                assert punctuations.size == len(mini_sentence)
+            except ONNXRuntimeError:
+                logging.warning("error")
+
+            # Search for the last Period/QuestionMark as cache
+            if mini_sentence_i < len(mini_sentences) - 1:
+                sentenceEnd = -1
+                last_comma_index = -1
+                for i in range(len(punctuations) - 2, 1, -1):
+                    if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+                        sentenceEnd = i
+                        break
+                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
+                        last_comma_index = i
+
+                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence it too long, cut off at a comma.
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = self.period
+                cache_sent = mini_sentence[sentenceEnd + 1:]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1:].tolist()
+                mini_sentence = mini_sentence[0:sentenceEnd + 1]
+                punctuations = punctuations[0:sentenceEnd + 1]
+
+            new_mini_sentence_punc += [int(x) for x in punctuations]
+            words_with_punc = []
+            for i in range(len(mini_sentence)):
+                if i > 0:
+                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
+                words_with_punc.append(mini_sentence[i])
+                if self.punc_list[punctuations[i]] != "_":
+                    words_with_punc.append(self.punc_list[punctuations[i]])
+            new_mini_sentence += "".join(words_with_punc)
+            # Add Period for the end of the sentence
+            new_mini_sentence_out = new_mini_sentence
+            new_mini_sentence_punc_out = new_mini_sentence_punc
+            if mini_sentence_i == len(mini_sentences) - 1:
+                if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�":
+                    new_mini_sentence_out = new_mini_sentence + "銆�"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+        return new_mini_sentence_out, new_mini_sentence_punc_out
+
+    def infer(self, feats: np.ndarray,
+              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        outputs = self.ort_infer([feats, feats_len])
+        return outputs
+
+
+class CT_Transformer_VadRealtime(CT_Transformer):
+    def __init__(self, model_dir: Union[str, Path] = None,
+                 batch_size: int = 1,
+                 device_id: Union[str, int] = "-1",
+                 quantize: bool = False,
+                 intra_op_num_threads: int = 4
+                 ):
+        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)
+
+    def __call__(self, text: str, param_dict: map, split_size=20):
+        cache_key = "cache"
+        assert cache_key in param_dict
+        cache = param_dict[cache_key]
+        if cache is not None and len(cache) > 0:
+            precache = "".join(cache)
+        else:
+            precache = ""
+            cache = []
+        full_text = precache + text
+        split_text = code_mix_split_words(full_text)
+        split_text_id = self.converter.tokens2ids(split_text)
+        mini_sentences = split_to_mini_sentence(split_text, split_size)
+        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
+        new_mini_sentence_punc = []
+        assert len(mini_sentences) == len(mini_sentences_id)
+
+        cache_sent = []
+        cache_sent_id = np.array([], dtype='int32')
+        sentence_punc_list = []
+        sentence_words_list = []
+        cache_pop_trigger_limit = 200
+        skip_num = 0
+        for mini_sentence_i in range(len(mini_sentences)):
+            mini_sentence = mini_sentences[mini_sentence_i]
+            mini_sentence_id = mini_sentences_id[mini_sentence_i]
+            mini_sentence = cache_sent + mini_sentence
+            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            text_length = len(mini_sentence_id)
+            data = {
+                "input": mini_sentence_id[None,:],
+                "text_lengths": np.array([text_length], dtype='int32'),
+                "vad_mask": self.vad_mask(text_length, len(cache) - 1)[None, None, :, :].astype(np.float32),
+                "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
+            }
+            try:
+                outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
+                y = outputs[0]
+                punctuations = np.argmax(y,axis=-1)[0]
+                assert punctuations.size == len(mini_sentence)
+            except ONNXRuntimeError:
+                logging.warning("error")
+
+            # Search for the last Period/QuestionMark as cache
+            if mini_sentence_i < len(mini_sentences) - 1:
+                sentenceEnd = -1
+                last_comma_index = -1
+                for i in range(len(punctuations) - 2, 1, -1):
+                    if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+                        sentenceEnd = i
+                        break
+                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
+                        last_comma_index = i
+
+                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence it too long, cut off at a comma.
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = self.period
+                cache_sent = mini_sentence[sentenceEnd + 1:]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+                mini_sentence = mini_sentence[0:sentenceEnd + 1]
+                punctuations = punctuations[0:sentenceEnd + 1]
+
+            punctuations_np = [int(x) for x in punctuations]
+            new_mini_sentence_punc += punctuations_np
+            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
+            sentence_words_list += mini_sentence
+
+        assert len(sentence_punc_list) == len(sentence_words_list)
+        words_with_punc = []
+        sentence_punc_list_out = []
+        for i in range(0, len(sentence_words_list)):
+            if i > 0:
+                if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
+                    sentence_words_list[i] = " " + sentence_words_list[i]
+            if skip_num < len(cache):
+                skip_num += 1
+            else:
+                words_with_punc.append(sentence_words_list[i])
+            if skip_num >= len(cache):
+                sentence_punc_list_out.append(sentence_punc_list[i])
+                if sentence_punc_list[i] != "_":
+                    words_with_punc.append(sentence_punc_list[i])
+        sentence_out = "".join(words_with_punc)
+
+        sentenceEnd = -1
+        for i in range(len(sentence_punc_list) - 2, 1, -1):
+            if sentence_punc_list[i] == "銆�" or sentence_punc_list[i] == "锛�":
+                sentenceEnd = i
+                break
+        cache_out = sentence_words_list[sentenceEnd + 1:]
+        if sentence_out[-1] in self.punc_list:
+            sentence_out = sentence_out[:-1]
+            sentence_punc_list_out[-1] = "_"
+        param_dict[cache_key] = cache_out
+        return sentence_out, sentence_punc_list_out, cache_out
+
+    def vad_mask(self, size, vad_pos, dtype=np.bool):
+        """Create mask for decoder self-attention.
+
+        :param int size: size of mask
+        :param int vad_pos: index of vad index
+        :param torch.dtype dtype: result dtype
+        :rtype: torch.Tensor (B, Lmax, Lmax)
+        """
+        ret = np.ones((size, size), dtype=dtype)
+        if vad_pos <= 0 or vad_pos >= size:
+            return ret
+        sub_corner = np.zeros(
+            (vad_pos - 1, size - vad_pos), dtype=dtype)
+        ret[0:vad_pos - 1, vad_pos:] = sub_corner
+        return ret
+
+    def infer(self, feats: np.ndarray,
+              feats_len: np.ndarray,
+              vad_mask: np.ndarray,
+              sub_masks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
+        return outputs
+
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
new file mode 100644
index 0000000..3f6c3d1
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
@@ -0,0 +1,607 @@
+from enum import Enum
+from typing import List, Tuple, Dict, Any
+
+import math
+import numpy as np
+
+class VadStateMachine(Enum):
+    kVadInStateStartPointNotDetected = 1
+    kVadInStateInSpeechSegment = 2
+    kVadInStateEndPointDetected = 3
+
+
+class FrameState(Enum):
+    kFrameStateInvalid = -1
+    kFrameStateSpeech = 1
+    kFrameStateSil = 0
+
+
+# final voice/unvoice state per frame
+class AudioChangeState(Enum):
+    kChangeStateSpeech2Speech = 0
+    kChangeStateSpeech2Sil = 1
+    kChangeStateSil2Sil = 2
+    kChangeStateSil2Speech = 3
+    kChangeStateNoBegin = 4
+    kChangeStateInvalid = 5
+
+
+class VadDetectMode(Enum):
+    kVadSingleUtteranceDetectMode = 0
+    kVadMutipleUtteranceDetectMode = 1
+
+
+class VADXOptions:
+    def __init__(
+            self,
+            sample_rate: int = 16000,
+            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+            snr_mode: int = 0,
+            max_end_silence_time: int = 800,
+            max_start_silence_time: int = 3000,
+            do_start_point_detection: bool = True,
+            do_end_point_detection: bool = True,
+            window_size_ms: int = 200,
+            sil_to_speech_time_thres: int = 150,
+            speech_to_sil_time_thres: int = 150,
+            speech_2_noise_ratio: float = 1.0,
+            do_extend: int = 1,
+            lookback_time_start_point: int = 200,
+            lookahead_time_end_point: int = 100,
+            max_single_segment_time: int = 60000,
+            nn_eval_block_size: int = 8,
+            dcd_block_size: int = 4,
+            snr_thres: int = -100.0,
+            noise_frame_num_used_for_snr: int = 100,
+            decibel_thres: int = -100.0,
+            speech_noise_thres: float = 0.6,
+            fe_prior_thres: float = 1e-4,
+            silence_pdf_num: int = 1,
+            sil_pdf_ids: List[int] = [0],
+            speech_noise_thresh_low: float = -0.1,
+            speech_noise_thresh_high: float = 0.3,
+            output_frame_probs: bool = False,
+            frame_in_ms: int = 10,
+            frame_length_ms: int = 25,
+    ):
+        self.sample_rate = sample_rate
+        self.detect_mode = detect_mode
+        self.snr_mode = snr_mode
+        self.max_end_silence_time = max_end_silence_time
+        self.max_start_silence_time = max_start_silence_time
+        self.do_start_point_detection = do_start_point_detection
+        self.do_end_point_detection = do_end_point_detection
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time_thres = sil_to_speech_time_thres
+        self.speech_to_sil_time_thres = speech_to_sil_time_thres
+        self.speech_2_noise_ratio = speech_2_noise_ratio
+        self.do_extend = do_extend
+        self.lookback_time_start_point = lookback_time_start_point
+        self.lookahead_time_end_point = lookahead_time_end_point
+        self.max_single_segment_time = max_single_segment_time
+        self.nn_eval_block_size = nn_eval_block_size
+        self.dcd_block_size = dcd_block_size
+        self.snr_thres = snr_thres
+        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+        self.decibel_thres = decibel_thres
+        self.speech_noise_thres = speech_noise_thres
+        self.fe_prior_thres = fe_prior_thres
+        self.silence_pdf_num = silence_pdf_num
+        self.sil_pdf_ids = sil_pdf_ids
+        self.speech_noise_thresh_low = speech_noise_thresh_low
+        self.speech_noise_thresh_high = speech_noise_thresh_high
+        self.output_frame_probs = output_frame_probs
+        self.frame_in_ms = frame_in_ms
+        self.frame_length_ms = frame_length_ms
+
+
+class E2EVadSpeechBufWithDoa(object):
+    def __init__(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
+
+    def Reset(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
+
+
+class E2EVadFrameProb(object):
+    def __init__(self):
+        self.noise_prob = 0.0
+        self.speech_prob = 0.0
+        self.score = 0.0
+        self.frame_id = 0
+        self.frm_state = 0
+
+
+class WindowDetector(object):
+    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
+                 speech_to_sil_time: int, frame_size_ms: int):
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time = sil_to_speech_time
+        self.speech_to_sil_time = speech_to_sil_time
+        self.frame_size_ms = frame_size_ms
+
+        self.win_size_frame = int(window_size_ms / frame_size_ms)
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame  # 鍒濆鍖栫獥
+
+        self.cur_win_pos = 0
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def Reset(self) -> None:
+        self.cur_win_pos = 0
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def GetWinSize(self) -> int:
+        return int(self.win_size_frame)
+
+    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
+        cur_frame_state = FrameState.kFrameStateSil
+        if frameState == FrameState.kFrameStateSpeech:
+            cur_frame_state = 1
+        elif frameState == FrameState.kFrameStateSil:
+            cur_frame_state = 0
+        else:
+            return AudioChangeState.kChangeStateInvalid
+        self.win_sum -= self.win_state[self.cur_win_pos]
+        self.win_sum += cur_frame_state
+        self.win_state[self.cur_win_pos] = cur_frame_state
+        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
+
+        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
+            self.pre_frame_state = FrameState.kFrameStateSpeech
+            return AudioChangeState.kChangeStateSil2Speech
+
+        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
+            self.pre_frame_state = FrameState.kFrameStateSil
+            return AudioChangeState.kChangeStateSpeech2Sil
+
+        if self.pre_frame_state == FrameState.kFrameStateSil:
+            return AudioChangeState.kChangeStateSil2Sil
+        if self.pre_frame_state == FrameState.kFrameStateSpeech:
+            return AudioChangeState.kChangeStateSpeech2Speech
+        return AudioChangeState.kChangeStateInvalid
+
+    def FrameSizeMs(self) -> int:
+        return int(self.frame_size_ms)
+
+
+class E2EVadModel():
+    def __init__(self, vad_post_args: Dict[str, Any]):
+        super(E2EVadModel, self).__init__()
+        self.vad_opts = VADXOptions(**vad_post_args)
+        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
+                                               self.vad_opts.sil_to_speech_time_thres,
+                                               self.vad_opts.speech_to_sil_time_thres,
+                                               self.vad_opts.frame_in_ms)
+        # self.encoder = encoder
+        # init variables
+        self.is_final = False
+        self.data_buf_start_frame = 0
+        self.frm_cnt = 0
+        self.latest_confirmed_speech_frame = 0
+        self.lastest_confirmed_silence_frame = -1
+        self.continous_silence_frame_count = 0
+        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+        self.confirmed_start_frame = -1
+        self.confirmed_end_frame = -1
+        self.number_end_time_detected = 0
+        self.sil_frame = 0
+        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
+        self.noise_average_decibel = -100.0
+        self.pre_end_silence_detected = False
+        self.next_seg = True
+
+        self.output_data_buf = []
+        self.output_data_buf_offset = 0
+        self.frame_probs = []
+        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
+        self.speech_noise_thres = self.vad_opts.speech_noise_thres
+        self.scores = None
+        self.max_time_out = False
+        self.decibel = []
+        self.data_buf = None
+        self.data_buf_all = None
+        self.waveform = None
+        self.ResetDetection()
+
+    def AllResetDetection(self):
+        self.is_final = False
+        self.data_buf_start_frame = 0
+        self.frm_cnt = 0
+        self.latest_confirmed_speech_frame = 0
+        self.lastest_confirmed_silence_frame = -1
+        self.continous_silence_frame_count = 0
+        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+        self.confirmed_start_frame = -1
+        self.confirmed_end_frame = -1
+        self.number_end_time_detected = 0
+        self.sil_frame = 0
+        self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
+        self.noise_average_decibel = -100.0
+        self.pre_end_silence_detected = False
+        self.next_seg = True
+
+        self.output_data_buf = []
+        self.output_data_buf_offset = 0
+        self.frame_probs = []
+        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
+        self.speech_noise_thres = self.vad_opts.speech_noise_thres
+        self.scores = None
+        self.max_time_out = False
+        self.decibel = []
+        self.data_buf = None
+        self.data_buf_all = None
+        self.waveform = None
+        self.ResetDetection()
+
+    def ResetDetection(self):
+        self.continous_silence_frame_count = 0
+        self.latest_confirmed_speech_frame = 0
+        self.lastest_confirmed_silence_frame = -1
+        self.confirmed_start_frame = -1
+        self.confirmed_end_frame = -1
+        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
+        self.windows_detector.Reset()
+        self.sil_frame = 0
+        self.frame_probs = []
+
+    def ComputeDecibel(self) -> None:
+        frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
+        frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+        if self.data_buf_all is None:
+            self.data_buf_all = self.waveform[0]  # self.data_buf is pointed to self.waveform[0]
+            self.data_buf = self.data_buf_all
+        else:
+            self.data_buf_all = np.concatenate((self.data_buf_all, self.waveform[0]))
+        for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
+            self.decibel.append(
+                10 * math.log10(np.square((self.waveform[0][offset: offset + frame_sample_length])).sum() + \
+                                0.000001))
+
+    def ComputeScores(self, scores: np.ndarray) -> None:
+        # scores = self.encoder(feats, in_cache)  # return B * T * D
+        self.vad_opts.nn_eval_block_size = scores.shape[1]
+        self.frm_cnt += scores.shape[1]  # count total frames
+        if self.scores is None:
+            self.scores = scores  # the first calculation
+        else:
+            self.scores = np.concatenate((self.scores, scores), axis=1)
+
+    def PopDataBufTillFrame(self, frame_idx: int) -> None:  # need check again
+        while self.data_buf_start_frame < frame_idx:
+            if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
+                self.data_buf_start_frame += 1
+                self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
+                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+
+    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
+                           last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
+        self.PopDataBufTillFrame(start_frm)
+        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
+        if last_frm_is_end_point:
+            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
+                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
+            expected_sample_number += int(extra_sample)
+        if end_point_is_sent_end:
+            expected_sample_number = max(expected_sample_number, len(self.data_buf))
+        if len(self.data_buf) < expected_sample_number:
+            print('error in calling pop data_buf\n')
+
+        if len(self.output_data_buf) == 0 or first_frm_is_start_point:
+            self.output_data_buf.append(E2EVadSpeechBufWithDoa())
+            self.output_data_buf[-1].Reset()
+            self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
+            self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
+            self.output_data_buf[-1].doa = 0
+        cur_seg = self.output_data_buf[-1]
+        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+            print('warning\n')
+        out_pos = len(cur_seg.buffer)  # cur_seg.buff鐜板湪娌″仛浠讳綍鎿嶄綔
+        data_to_pop = 0
+        if end_point_is_sent_end:
+            data_to_pop = expected_sample_number
+        else:
+            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+        if data_to_pop > len(self.data_buf):
+            print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
+            data_to_pop = len(self.data_buf)
+            expected_sample_number = len(self.data_buf)
+
+        cur_seg.doa = 0
+        for sample_cpy_out in range(0, data_to_pop):
+            # cur_seg.buffer[out_pos ++] = data_buf_.back();
+            out_pos += 1
+        for sample_cpy_out in range(data_to_pop, expected_sample_number):
+            # cur_seg.buffer[out_pos++] = data_buf_.back()
+            out_pos += 1
+        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+            print('Something wrong with the VAD algorithm\n')
+        self.data_buf_start_frame += frm_cnt
+        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
+        if first_frm_is_start_point:
+            cur_seg.contain_seg_start_point = True
+        if last_frm_is_end_point:
+            cur_seg.contain_seg_end_point = True
+
+    def OnSilenceDetected(self, valid_frame: int):
+        self.lastest_confirmed_silence_frame = valid_frame
+        if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+            self.PopDataBufTillFrame(valid_frame)
+        # silence_detected_callback_
+        # pass
+
+    def OnVoiceDetected(self, valid_frame: int) -> None:
+        self.latest_confirmed_speech_frame = valid_frame
+        self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
+
+    def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
+        if self.vad_opts.do_start_point_detection:
+            pass
+        if self.confirmed_start_frame != -1:
+            print('not reset vad properly\n')
+        else:
+            self.confirmed_start_frame = start_frame
+
+        if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+            self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
+
+    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
+        for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
+            self.OnVoiceDetected(t)
+        if self.vad_opts.do_end_point_detection:
+            pass
+        if self.confirmed_end_frame != -1:
+            print('not reset vad properly\n')
+        else:
+            self.confirmed_end_frame = end_frame
+        if not fake_result:
+            self.sil_frame = 0
+            self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
+        self.number_end_time_detected += 1
+
+    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
+        if is_final_frame:
+            self.OnVoiceEnd(cur_frm_idx, False, True)
+            self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+
+    def GetLatency(self) -> int:
+        return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
+
+    def LatencyFrmNumAtStartPoint(self) -> int:
+        vad_latency = self.windows_detector.GetWinSize()
+        if self.vad_opts.do_extend:
+            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
+        return vad_latency
+
+    def GetFrameState(self, t: int) -> FrameState:
+        frame_state = FrameState.kFrameStateInvalid
+        cur_decibel = self.decibel[t]
+        cur_snr = cur_decibel - self.noise_average_decibel
+        # for each frame, calc log posterior probability of each state
+        if cur_decibel < self.vad_opts.decibel_thres:
+            frame_state = FrameState.kFrameStateSil
+            self.DetectOneFrame(frame_state, t, False)
+            return frame_state
+
+        sum_score = 0.0
+        noise_prob = 0.0
+        assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
+        if len(self.sil_pdf_ids) > 0:
+            assert len(self.scores) == 1  # 鍙敮鎸乥atch_size = 1鐨勬祴璇�
+            sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
+            sum_score = sum(sil_pdf_scores)
+            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
+            total_score = 1.0
+            sum_score = total_score - sum_score
+        speech_prob = math.log(sum_score)
+        if self.vad_opts.output_frame_probs:
+            frame_prob = E2EVadFrameProb()
+            frame_prob.noise_prob = noise_prob
+            frame_prob.speech_prob = speech_prob
+            frame_prob.score = sum_score
+            frame_prob.frame_id = t
+            self.frame_probs.append(frame_prob)
+        if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
+            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
+                frame_state = FrameState.kFrameStateSpeech
+            else:
+                frame_state = FrameState.kFrameStateSil
+        else:
+            frame_state = FrameState.kFrameStateSil
+            if self.noise_average_decibel < -99.9:
+                self.noise_average_decibel = cur_decibel
+            else:
+                self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
+                        self.vad_opts.noise_frame_num_used_for_snr
+                        - 1)) / self.vad_opts.noise_frame_num_used_for_snr
+
+        return frame_state
+     
+
+    def __call__(self, score: np.ndarray, waveform: np.ndarray,
+                is_final: bool = False, max_end_sil: int = 800
+                ):
+        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
+        self.waveform = waveform  # compute decibel for each frame
+        self.ComputeDecibel()
+        self.ComputeScores(score)
+        if not is_final:
+            self.DetectCommonFrames()
+        else:
+            self.DetectLastFrames()
+        segments = []
+        for batch_num in range(0, score.shape[0]):  # only support batch_size = 1 now
+            segment_batch = []
+            if len(self.output_data_buf) > 0:
+                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
+                    if not self.output_data_buf[i].contain_seg_start_point:
+                        continue
+                    if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
+                        continue
+                    start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
+                    if self.output_data_buf[i].contain_seg_end_point:
+                        end_ms = self.output_data_buf[i].end_ms
+                        self.next_seg = True
+                        self.output_data_buf_offset += 1
+                    else:
+                        end_ms = -1
+                        self.next_seg = False
+                    segment = [start_ms, end_ms]
+                    segment_batch.append(segment)
+            if segment_batch:
+                segments.append(segment_batch)
+        if is_final:
+            # reset class variables and clear the dict for the next query
+            self.AllResetDetection()
+        return segments
+
+    def DetectCommonFrames(self) -> int:
+        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+            return 0
+        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+            frame_state = FrameState.kFrameStateInvalid
+            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+            self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+
+        return 0
+
+    def DetectLastFrames(self) -> int:
+        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+            return 0
+        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+            frame_state = FrameState.kFrameStateInvalid
+            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+            if i != 0:
+                self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+            else:
+                self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
+
+        return 0
+
+    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
+        tmp_cur_frm_state = FrameState.kFrameStateInvalid
+        if cur_frm_state == FrameState.kFrameStateSpeech:
+            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
+                tmp_cur_frm_state = FrameState.kFrameStateSpeech
+            else:
+                tmp_cur_frm_state = FrameState.kFrameStateSil
+        elif cur_frm_state == FrameState.kFrameStateSil:
+            tmp_cur_frm_state = FrameState.kFrameStateSil
+        state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
+        frm_shift_in_ms = self.vad_opts.frame_in_ms
+        if AudioChangeState.kChangeStateSil2Speech == state_change:
+            silence_frame_count = self.continous_silence_frame_count
+            self.continous_silence_frame_count = 0
+            self.pre_end_silence_detected = False
+            start_frame = 0
+            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
+                self.OnVoiceStart(start_frame)
+                self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
+                for t in range(start_frame + 1, cur_frm_idx + 1):
+                    self.OnVoiceDetected(t)
+            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
+                    self.OnVoiceDetected(t)
+                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
+            self.continous_silence_frame_count = 0
+            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                pass
+            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
+            self.continous_silence_frame_count = 0
+            if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.max_time_out = True
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSil2Sil == state_change:
+            self.continous_silence_frame_count += 1
+            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                # silence timeout, return zero length decision
+                if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
+                        self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+                        or (is_final_frame and self.number_end_time_detected == 0):
+                    for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
+                        self.OnSilenceDetected(t)
+                    self.OnVoiceStart(0, True)
+                    self.OnVoiceEnd(0, True, False);
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                else:
+                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
+                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
+            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
+                    lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
+                    if self.vad_opts.do_extend:
+                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
+                        lookback_frame -= 1
+                        lookback_frame = max(0, lookback_frame)
+                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif self.vad_opts.do_extend and not is_final_frame:
+                    if self.continous_silence_frame_count <= int(
+                            self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
+                        self.OnVoiceDetected(cur_frm_idx)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+            else:
+                pass
+
+        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
+                self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
+            self.ResetDetection()
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index 2edde11..0df954e 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -188,7 +188,7 @@
                  input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
         input_dict = dict(zip(self.get_input_names(), input_content))
         try:
-            return self.session.run(None, input_dict)
+            return self.session.run(self.get_output_names(), input_dict)
         except Exception as e:
             raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e
 
@@ -215,6 +215,38 @@
         if not model_path.is_file():
             raise FileExistsError(f'{model_path} is not a file.')
 
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+    assert word_limit > 1
+    if len(words) <= word_limit:
+        return [words]
+    sentences = []
+    length = len(words)
+    sentence_len = length // word_limit
+    for i in range(sentence_len):
+        sentences.append(words[i * word_limit:(i + 1) * word_limit])
+    if length % word_limit > 0:
+        sentences.append(words[sentence_len * word_limit:])
+    return sentences
+
+def code_mix_split_words(text: str):
+    words = []
+    segs = text.split()
+    for seg in segs:
+        # There is no space in seg.
+        current_word = ""
+        for c in seg:
+            if len(c.encode()) == 1:
+                # This is an ASCII char.
+                current_word += c
+            else:
+                # This is a Chinese char.
+                if len(current_word) > 0:
+                    words.append(current_word)
+                    current_word = ""
+                words.append(c)
+        if len(current_word) > 0:
+            words.append(current_word)
+    return words
 
 def read_yaml(yaml_path: Union[str, Path]) -> Dict:
     if not Path(yaml_path).exists():
@@ -226,7 +258,7 @@
 
 
 @functools.lru_cache()
-def get_logger(name='rapdi_paraformer'):
+def get_logger(name='funasr_onnx'):
     """Initialize and get a logger by name.
     If the logger has not been initialized, this method will initialize the
     logger by adding one or two handlers, otherwise the initialized logger will
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
new file mode 100644
index 0000000..221867d
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -0,0 +1,143 @@
+# -*- encoding: utf-8 -*-
+
+import os.path
+from pathlib import Path
+from typing import List, Union, Tuple
+
+import copy
+import librosa
+import numpy as np
+
+from .utils.utils import (ONNXRuntimeError,
+                          OrtInferSession, get_logger,
+                          read_yaml)
+from .utils.frontend import WavFrontend
+from .utils.e2e_vad import E2EVadModel
+
+logging = get_logger()
+
+
+class Fsmn_vad():
+	def __init__(self, model_dir: Union[str, Path] = None,
+	             batch_size: int = 1,
+	             device_id: Union[str, int] = "-1",
+	             quantize: bool = False,
+	             intra_op_num_threads: int = 4,
+	             max_end_sil: int = None,
+	             ):
+		
+		if not Path(model_dir).exists():
+			raise FileNotFoundError(f'{model_dir} does not exist.')
+		
+		model_file = os.path.join(model_dir, 'model.onnx')
+		if quantize:
+			model_file = os.path.join(model_dir, 'model_quant.onnx')
+		config_file = os.path.join(model_dir, 'vad.yaml')
+		cmvn_file = os.path.join(model_dir, 'vad.mvn')
+		config = read_yaml(config_file)
+		
+		self.frontend = WavFrontend(
+			cmvn_file=cmvn_file,
+			**config['frontend_conf']
+		)
+		self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
+		self.batch_size = batch_size
+		self.vad_scorer = E2EVadModel(config["vad_post_conf"])
+		self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
+		self.encoder_conf = config["encoder_conf"]
+	
+	def prepare_cache(self, in_cache: list = []):
+		if len(in_cache) > 0:
+			return in_cache
+		fsmn_layers = self.encoder_conf["fsmn_layers"]
+		proj_dim = self.encoder_conf["proj_dim"]
+		lorder = self.encoder_conf["lorder"]
+		for i in range(fsmn_layers):
+			cache = np.zeros((1, proj_dim, lorder-1, 1)).astype(np.float32)
+			in_cache.append(cache)
+		return in_cache
+		
+	
+	def __call__(self, audio_in: Union[str, np.ndarray, List[str]], **kwargs) -> List:
+		# waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
+		
+		param_dict = kwargs.get('param_dict', dict())
+		is_final = param_dict.get('is_final', False)
+		audio_in_cache = param_dict.get('audio_in_cache', None)
+		audio_in_cum = audio_in
+		if audio_in_cache is not None:
+			audio_in_cum = np.concatenate((audio_in_cache, audio_in_cum))
+		param_dict['audio_in_cache'] = audio_in_cum
+		feats, feats_len = self.extract_feat([audio_in_cum])
+		
+		in_cache = param_dict.get('in_cache', list())
+		in_cache = self.prepare_cache(in_cache)
+		beg_idx = param_dict.get('beg_idx',0)
+		feats = feats[:, beg_idx:beg_idx+8, :]
+		param_dict['beg_idx'] = beg_idx + feats.shape[1]
+		try:
+			inputs = [feats]
+			inputs.extend(in_cache)
+			scores, out_caches = self.infer(inputs)
+			param_dict['in_cache'] = out_caches
+			segments = self.vad_scorer(scores, audio_in[None, :], is_final=is_final, max_end_sil=self.max_end_sil)
+			# print(segments)
+			if len(segments) == 1 and segments[0][0][1] != -1:
+				self.frontend.reset_status()
+			
+			
+		except ONNXRuntimeError:
+			logging.warning(traceback.format_exc())
+			logging.warning("input wav is silence or noise")
+			segments = []
+	
+		return segments
+
+	def load_data(self,
+	              wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+		def load_wav(path: str) -> np.ndarray:
+			waveform, _ = librosa.load(path, sr=fs)
+			return waveform
+		
+		if isinstance(wav_content, np.ndarray):
+			return [wav_content]
+		
+		if isinstance(wav_content, str):
+			return [load_wav(wav_content)]
+		
+		if isinstance(wav_content, list):
+			return [load_wav(path) for path in wav_content]
+		
+		raise TypeError(
+			f'The type of {wav_content} is not in [str, np.ndarray, list]')
+	
+	def extract_feat(self,
+	                 waveform_list: List[np.ndarray]
+	                 ) -> Tuple[np.ndarray, np.ndarray]:
+		feats, feats_len = [], []
+		for waveform in waveform_list:
+			speech, _ = self.frontend.fbank(waveform)
+			feat, feat_len = self.frontend.lfr_cmvn(speech)
+			feats.append(feat)
+			feats_len.append(feat_len)
+		
+		feats = self.pad_feats(feats, np.max(feats_len))
+		feats_len = np.array(feats_len).astype(np.int32)
+		return feats, feats_len
+	
+	@staticmethod
+	def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
+		def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
+			pad_width = ((0, max_feat_len - cur_len), (0, 0))
+			return np.pad(feat, pad_width, 'constant', constant_values=0)
+		
+		feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
+		feats = np.array(feat_res).astype(np.float32)
+		return feats
+	
+	def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
+		
+		outputs = self.ort_infer(feats)
+		scores, out_caches = outputs[0], outputs[1:]
+		return scores, out_caches
+	
\ No newline at end of file
diff --git a/funasr/runtime/python/onnxruntime/setup.py b/funasr/runtime/python/onnxruntime/setup.py
index 3b9ed3b..1a8ed7b 100644
--- a/funasr/runtime/python/onnxruntime/setup.py
+++ b/funasr/runtime/python/onnxruntime/setup.py
@@ -13,7 +13,7 @@
 
 
 MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.0.2'
+VERSION_NUM = '0.0.3'
 
 setuptools.setup(
     name=MODULE_NAME,
diff --git a/funasr/tasks/lm.py b/funasr/tasks/lm.py
index 608c1d3..80d66d5 100644
--- a/funasr/tasks/lm.py
+++ b/funasr/tasks/lm.py
@@ -15,7 +15,7 @@
 from funasr.datasets.collate_fn import CommonCollateFn
 from funasr.datasets.preprocessor import CommonPreprocessor
 from funasr.lm.abs_model import AbsLM
-from funasr.lm.espnet_model import ESPnetLanguageModel
+from funasr.lm.abs_model import LanguageModel
 from funasr.lm.seq_rnn_lm import SequentialRNNLM
 from funasr.lm.transformer_lm import TransformerLM
 from funasr.tasks.abs_task import AbsTask
@@ -83,7 +83,7 @@
         group.add_argument(
             "--model_conf",
             action=NestedDictAction,
-            default=get_default_kwargs(ESPnetLanguageModel),
+            default=get_default_kwargs(LanguageModel),
             help="The keyword arguments for model class.",
         )
 
@@ -178,7 +178,7 @@
         return retval
 
     @classmethod
-    def build_model(cls, args: argparse.Namespace) -> ESPnetLanguageModel:
+    def build_model(cls, args: argparse.Namespace) -> LanguageModel:
         assert check_argument_types()
         if isinstance(args.token_list, str):
             with open(args.token_list, encoding="utf-8") as f:
@@ -201,7 +201,7 @@
 
         # 2. Build ESPnetModel
         # Assume the last-id is sos_and_eos
-        model = ESPnetLanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
+        model = LanguageModel(lm=lm, vocab_size=vocab_size, **args.model_conf)
 
         # 3. Initialize
         if args.init is not None:
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index ea1e102..0170f28 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -14,10 +14,10 @@
 
 from funasr.datasets.collate_fn import CommonCollateFn
 from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
-from funasr.punctuation.abs_model import AbsPunctuation
-from funasr.punctuation.espnet_model import ESPnetPunctuationModel
-from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
-from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
+from funasr.train.abs_model import AbsPunctuation
+from funasr.train.abs_model import PunctuationModel
+from funasr.models.target_delay_transformer import TargetDelayTransformer
+from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
 from funasr.tasks.abs_task import AbsTask
 from funasr.text.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
@@ -79,7 +79,7 @@
         group.add_argument(
             "--model_conf",
             action=NestedDictAction,
-            default=get_default_kwargs(ESPnetPunctuationModel),
+            default=get_default_kwargs(PunctuationModel),
             help="The keyword arguments for model class.",
         )
 
@@ -183,7 +183,7 @@
         return retval
 
     @classmethod
-    def build_model(cls, args: argparse.Namespace) -> ESPnetPunctuationModel:
+    def build_model(cls, args: argparse.Namespace) -> PunctuationModel:
         assert check_argument_types()
         if isinstance(args.token_list, str):
             with open(args.token_list, encoding="utf-8") as f:
@@ -218,7 +218,7 @@
         # Assume the last-id is sos_and_eos
         if "punc_weight" in args.model_conf:
             args.model_conf.pop("punc_weight")
-        model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
+        model = PunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
 
         # FIXME(kamo): Should be done in model?
         # 3. Initialize
diff --git a/funasr/tasks/vad.py b/funasr/tasks/vad.py
index 22a5cb3..1ac77d3 100644
--- a/funasr/tasks/vad.py
+++ b/funasr/tasks/vad.py
@@ -40,7 +40,7 @@
 from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.frontend.default import DefaultFrontend
 from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
 from funasr.models.frontend.s3prl import S3prlFrontend
 from funasr.models.frontend.windowing import SlidingWindow
 from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
@@ -81,6 +81,7 @@
         s3prl=S3prlFrontend,
         fused=FusedFrontends,
         wav_frontend=WavFrontend,
+        wav_frontend_online=WavFrontendOnline,
     ),
     type_check=AbsFrontend,
     default="default",
@@ -291,7 +292,24 @@
             model_class = model_choices.get_class(args.model)
         except AttributeError:
             model_class = model_choices.get_class("e2evad")
-        model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf)
+        
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            if args.frontend == 'wav_frontend':
+                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+            else:
+                frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+        
+        model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
 
         return model
 
@@ -301,6 +319,7 @@
             cls,
             config_file: Union[Path, str] = None,
             model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
             device: str = "cpu",
     ):
         """Build model from the files.
@@ -325,6 +344,8 @@
 
         with config_file.open("r", encoding="utf-8") as f:
             args = yaml.safe_load(f)
+        if cmvn_file is not None:
+            args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
         model.to(device)
diff --git a/funasr/punctuation/espnet_model.py b/funasr/train/abs_model.py
similarity index 86%
rename from funasr/punctuation/espnet_model.py
rename to funasr/train/abs_model.py
index 7266b38..1c7ff3d 100644
--- a/funasr/punctuation/espnet_model.py
+++ b/funasr/train/abs_model.py
@@ -1,3 +1,7 @@
+from abc import ABC
+from abc import abstractmethod
+
+
 from typing import Dict
 from typing import Optional
 from typing import Tuple
@@ -7,13 +11,34 @@
 from typeguard import check_argument_types
 
 from funasr.modules.nets_utils import make_pad_mask
-from funasr.punctuation.abs_model import AbsPunctuation
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.train.abs_espnet_model import AbsESPnetModel
 
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
 
-class ESPnetPunctuationModel(AbsESPnetModel):
 
+class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
+    """The abstract class
+
+    To share the loss calculation way among different models,
+    We uses delegate pattern here:
+    The instance of this class should be passed to "LanguageModel"
+
+    This "model" is one of mediator objects for "Task" class.
+
+    """
+
+    @abstractmethod
+    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def with_vad(self) -> bool:
+        raise NotImplementedError
+
+
+class PunctuationModel(AbsESPnetModel):
+    
     def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
         assert check_argument_types()
         super().__init__()
@@ -21,12 +46,12 @@
         self.punc_weight = torch.Tensor(punc_weight)
         self.sos = 1
         self.eos = 2
-
+        
         # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
         self.ignore_id = ignore_id
-        #if self.punc_model.with_vad():
+        # if self.punc_model.with_vad():
         #    print("This is a vad puncuation model.")
-
+    
     def nll(
         self,
         text: torch.Tensor,
@@ -54,7 +79,7 @@
         else:
             text = text[:, :max_length]
             punc = punc[:, :max_length]
-       
+        
         if self.punc_model.with_vad():
             # Should be VadRealtimeTransformer
             assert vad_indexes is not None
@@ -62,7 +87,7 @@
         else:
             # Should be TargetDelayTransformer,
             y, _ = self.punc_model(text, text_lengths)
-
+        
         # Calc negative log likelihood
         # nll: (BxL,)
         if self.training == False:
@@ -75,7 +100,8 @@
             return nll, text_lengths
         else:
             self.punc_weight = self.punc_weight.to(punc.device)
-            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
+            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
+                                  ignore_index=self.ignore_id)
         # nll: (BxL,) -> (BxL,)
         if max_length is None:
             nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -87,7 +113,7 @@
         # nll: (BxL,) -> (B, L)
         nll = nll.view(batch_size, -1)
         return nll, text_lengths
-
+    
     def batchify_nll(self,
                      text: torch.Tensor,
                      punc: torch.Tensor,
@@ -113,7 +139,7 @@
             nlls = []
             x_lengths = []
             max_length = text_lengths.max()
-
+            
             start_idx = 0
             while True:
                 end_idx = min(start_idx + batch_size, total_num)
@@ -132,7 +158,7 @@
         assert nll.size(0) == total_num
         assert x_lengths.size(0) == total_num
         return nll, x_lengths
-
+    
     def forward(
         self,
         text: torch.Tensor,
@@ -146,15 +172,15 @@
         ntokens = y_lengths.sum()
         loss = nll.sum() / ntokens
         stats = dict(loss=loss.detach())
-
+        
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
         return loss, stats, weight
-
+    
     def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
                       text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
         return {}
-
+    
     def inference(self,
                   text: torch.Tensor,
                   text_lengths: torch.Tensor,

--
Gitblit v1.9.1