From ee06cb9c6870d9e1579015aabfe1a84a61a5c087 Mon Sep 17 00:00:00 2001
From: 九耳 <mengzhe.cmz@alibaba-inc.com>
Date: 星期二, 28 二月 2023 18:11:12 +0800
Subject: [PATCH] punctuation:add training code, support largedataset

---
 funasr/modules/mask.py                             |   17 
 funasr/punctuation/abs_model.py                    |    4 
 funasr/datasets/large_datasets/build_dataloader.py |    8 
 funasr/datasets/large_datasets/utils/padding.py    |    5 
 funasr/tasks/abs_task.py                           |    2 
 funasr/punctuation/espnet_model.py                 |   55 +-
 funasr/bin/punc_train.py                           |   43 ++
 funasr/datasets/large_datasets/dataset.py          |   16 
 funasr/tasks/punctuation.py                        |   25 
 funasr/modules/attention.py                        |   12 
 funasr/datasets/preprocessor.py                    |  100 +++++
 funasr/punctuation/sanm_encoder.py                 |  590 +++++++++++++++++++++++++++++++
 funasr/bin/punc_train_vadrealtime.py               |   44 ++
 funasr/punctuation/target_delay_transformer.py     |    5 
 funasr/datasets/large_datasets/utils/tokenize.py   |   29 +
 funasr/punctuation/vad_realtime_transformer.py     |  132 ++++++
 16 files changed, 1,042 insertions(+), 45 deletions(-)

diff --git a/funasr/bin/punc_train.py b/funasr/bin/punc_train.py
new file mode 100644
index 0000000..61b63ec
--- /dev/null
+++ b/funasr/bin/punc_train.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+import os
+from funasr.tasks.punctuation import PunctuationTask
+
+
+def parse_args():
+    parser = PunctuationTask.get_parser()
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+    parser.add_argument(
+        "--punc_list",
+        type=str,
+        default=None,
+        help="Punctuation list",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args=None, cmd=None):
+    """
+    punc training.
+    """
+    PunctuationTask.main(args=args, cmd=cmd)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    # setup local gpu_id
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+
+    main(args=args)
diff --git a/funasr/bin/punc_train_vadrealtime.py b/funasr/bin/punc_train_vadrealtime.py
new file mode 100644
index 0000000..c5afaad
--- /dev/null
+++ b/funasr/bin/punc_train_vadrealtime.py
@@ -0,0 +1,44 @@
+#!/usr/bin/env python3
+import os
+from funasr.tasks.punctuation import PunctuationTask
+
+
+def parse_args():
+    parser = PunctuationTask.get_parser()
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+    parser.add_argument(
+        "--punc_list",
+        type=str,
+        default=None,
+        help="Punctuation list",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args=None, cmd=None):
+    """
+    punc training.
+    """
+    PunctuationTask.main(args=args, cmd=cmd)
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    # setup local gpu_id
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+    assert args.num_worker_count == 1
+
+    main(args=args)
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 8f7fd0b..093ad60 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -34,16 +34,20 @@
     return seg_dict
 
 class ArkDataLoader(AbsIterFactory):
-    def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, mode="train"):
+    def __init__(self, data_list, dict_file, dataset_conf, seg_dict_file=None, punc_dict_file=None, mode="train"):
         symbol_table = read_symbol_table(dict_file) if dict_file is not None else None
         if seg_dict_file is not None:
             seg_dict = load_seg_dict(seg_dict_file)
         else:
             seg_dict = None
+        if punc_dict_file is not None:
+            punc_dict = read_symbol_table(punc_dict_file)
+        else:
+            punc_dict = None
         self.dataset_conf = dataset_conf
         logging.info("dataloader config: {}".format(self.dataset_conf))
         batch_mode = self.dataset_conf.get("batch_mode", "padding")
-        self.dataset = Dataset(data_list, symbol_table, seg_dict,
+        self.dataset = Dataset(data_list, symbol_table, seg_dict, punc_dict,
                                self.dataset_conf, mode=mode, batch_mode=batch_mode)
 
     def build_iter(self, epoch, shuffle=True):
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 2123737..61231d2 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -127,14 +127,17 @@
                             sample_dict["key"] = key
                     else:
                         text = item
-                        sample_dict[data_name] = text.strip().split()[1:]
+                        segs = text.strip().split()
+                        sample_dict[data_name] = segs[1:]
+                        if "key" not in sample_dict:
+                            sample_dict["key"] = segs[0]
                 yield sample_dict
 
             self.close_reader(reader_list)
 
 
 def len_fn_example(data):
-    return len(data)
+    return 1
 
 
 def len_fn_token(data):
@@ -148,6 +151,7 @@
 def Dataset(data_list_file,
             dict,
             seg_dict,
+            punc_dict,
             conf,
             mode="train",
             batch_mode="padding"):
@@ -162,7 +166,7 @@
     dataset = FilterIterDataPipe(dataset, fn=filter_fn)
 
     if "text" in data_names:
-        vocab = {'vocab': dict, 'seg_dict': seg_dict}
+        vocab = {'vocab': dict, 'seg_dict': seg_dict, 'punc_dict': punc_dict}
         tokenize_fn = partial(tokenize, **vocab)
         dataset = MapperIterDataPipe(dataset, fn=tokenize_fn)
 
@@ -191,6 +195,10 @@
                                              sort_size=sort_size,
                                              batch_mode=batch_mode)
 
-    dataset = MapperIterDataPipe(dataset, fn=padding if batch_mode == "padding" else clipping)
+    int_pad_value = conf.get("int_pad_value", -1)
+    float_pad_value = conf.get("float_pad_value", 0.0)
+    padding_conf = {"int_pad_value": int_pad_value, "float_pad_value": float_pad_value}
+    padding_fn = partial(padding, **padding_conf)
+    dataset = MapperIterDataPipe(dataset, fn=padding_fn if batch_mode == "padding" else clipping)
 
     return dataset
diff --git a/funasr/datasets/large_datasets/utils/padding.py b/funasr/datasets/large_datasets/utils/padding.py
index e814b1c..e0feac6 100644
--- a/funasr/datasets/large_datasets/utils/padding.py
+++ b/funasr/datasets/large_datasets/utils/padding.py
@@ -6,9 +6,8 @@
 def padding(data, float_pad_value=0.0, int_pad_value=-1):
     assert isinstance(data, list)
     assert "key" in data[0]
-    assert "speech" in data[0]
-    assert "text" in data[0]
-
+    assert "speech" in data[0] or "text" in data[0]
+    
     keys = [x["key"] for x in data]
 
     batch = {}
diff --git a/funasr/datasets/large_datasets/utils/tokenize.py b/funasr/datasets/large_datasets/utils/tokenize.py
index 0c01885..caeb426 100644
--- a/funasr/datasets/large_datasets/utils/tokenize.py
+++ b/funasr/datasets/large_datasets/utils/tokenize.py
@@ -31,22 +31,43 @@
 
 def tokenize(data,
              vocab=None,
-             seg_dict=None):
+             seg_dict=None,
+             punc_dict=None):
     assert "text" in data
     assert isinstance(vocab, dict)
     text = data["text"]
     token = []
+    vad = -2
 
     if seg_dict is not None:
         assert isinstance(seg_dict, dict)
         txt = forward_segment("".join(text).lower(), seg_dict)
         text = seg_tokenize(txt, seg_dict)
-    
-    for x in text:
-        if x in vocab:
+
+    length = len(text)
+    for i in range(length):
+        x = text[i]
+        if i == length-1 and "punc" in data and text[i].startswith("vad:"):
+            vad = x[-1][4:]
+            if len(vad) == 0:
+                vad = -1
+            else:
+                vad = int(vad)
+        elif x in vocab:
             token.append(vocab[x])
         else:
             token.append(vocab['<unk>'])
 
+    if "punc" in data and punc_dict is not None:
+        punc_token = []
+        for punc in data["punc"]:
+            if punc in punc_dict:
+                punc_token.append(punc_dict[punc])
+            else:
+                punc_token.append(punc_dict["_"])
+        data["punc"] =  np.array(punc_token)
+
     data["text"] = np.array(token)
+    if vad is not -2:
+        data["vad_indexes"]=np.array([vad], dtype=np.int64)
     return data
diff --git a/funasr/datasets/preprocessor.py b/funasr/datasets/preprocessor.py
index 8e86794..20a3791 100644
--- a/funasr/datasets/preprocessor.py
+++ b/funasr/datasets/preprocessor.py
@@ -704,3 +704,103 @@
         del data[self.split_text_name]
         return result
 
+class PuncTrainTokenizerCommonPreprocessor(CommonPreprocessor):
+    def __init__(
+            self,
+            train: bool,
+            token_type: List[str] = [None],
+            token_list: List[Union[Path, str, Iterable[str]]] = [None],
+            bpemodel: List[Union[Path, str, Iterable[str]]] = [None],
+            text_cleaner: Collection[str] = None,
+            g2p_type: str = None,
+            unk_symbol: str = "<unk>",
+            space_symbol: str = "<space>",
+            non_linguistic_symbols: Union[Path, str, Iterable[str]] = None,
+            delimiter: str = None,
+            rir_scp: str = None,
+            rir_apply_prob: float = 1.0,
+            noise_scp: str = None,
+            noise_apply_prob: float = 1.0,
+            noise_db_range: str = "3_10",
+            speech_volume_normalize: float = None,
+            speech_name: str = "speech",
+            text_name: List[str] = ["text"],
+            vad_name: str = "vad_indexes",
+    ):
+        # TODO(jiatong): sync with Kamo and Jing on interface for preprocessor
+        super().__init__(
+            train=train,
+            token_type=token_type[0],
+            token_list=token_list[0],
+            bpemodel=bpemodel[0],
+            text_cleaner=text_cleaner,
+            g2p_type=g2p_type,
+            unk_symbol=unk_symbol,
+            space_symbol=space_symbol,
+            non_linguistic_symbols=non_linguistic_symbols,
+            delimiter=delimiter,
+            speech_name=speech_name,
+            text_name=text_name[0],
+            rir_scp=rir_scp,
+            rir_apply_prob=rir_apply_prob,
+            noise_scp=noise_scp,
+            noise_apply_prob=noise_apply_prob,
+            noise_db_range=noise_db_range,
+            speech_volume_normalize=speech_volume_normalize,
+        )
+
+        assert (
+                len(token_type) == len(token_list) == len(bpemodel) == len(text_name)
+        ), "token_type, token_list, bpemodel, or processing text_name mismatched"
+        self.num_tokenizer = len(token_type)
+        self.tokenizer = []
+        self.token_id_converter = []
+
+        for i in range(self.num_tokenizer):
+            if token_type[i] is not None:
+                if token_list[i] is None:
+                    raise ValueError("token_list is required if token_type is not None")
+
+                self.tokenizer.append(
+                    build_tokenizer(
+                        token_type=token_type[i],
+                        bpemodel=bpemodel[i],
+                        delimiter=delimiter,
+                        space_symbol=space_symbol,
+                        non_linguistic_symbols=non_linguistic_symbols,
+                        g2p_type=g2p_type,
+                    )
+                )
+                self.token_id_converter.append(
+                    TokenIDConverter(
+                        token_list=token_list[i],
+                        unk_symbol=unk_symbol,
+                    )
+                )
+            else:
+                self.tokenizer.append(None)
+                self.token_id_converter.append(None)
+
+        self.text_cleaner = TextCleaner(text_cleaner)
+        self.text_name = text_name  # override the text_name from CommonPreprocessor
+        self.vad_name = vad_name
+
+    def _text_process(
+            self, data: Dict[str, Union[str, np.ndarray]]
+    ) -> Dict[str, np.ndarray]:
+        for i in range(self.num_tokenizer):
+            text_name = self.text_name[i]
+            if text_name in data and self.tokenizer[i] is not None:
+                text = data[text_name]
+                text = self.text_cleaner(text)
+                tokens = self.tokenizer[i].text2tokens(text)
+                if "vad:" in tokens[-1]:
+                    vad = tokens[-1][4:]
+                    tokens = tokens[:-1]
+                    if len(vad) == 0:
+                        vad = -1
+                    else:
+                        vad = int(vad)
+                    data[self.vad_name] = np.array([vad], dtype=np.int64)
+                text_ints = self.token_id_converter[i].tokens2ids(tokens)
+                data[text_name] = np.array(text_ints, dtype=np.int64)
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index c47d96d..6277005 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -439,6 +439,18 @@
         att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
         return att_outs + fsmn_memory
 
+class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
+        return att_outs + fsmn_memory
+
 class MultiHeadedAttentionSANMDecoder(nn.Module):
     """Multi-Head Attention layer.
 
diff --git a/funasr/modules/mask.py b/funasr/modules/mask.py
index 8f068e1..a8c168b 100644
--- a/funasr/modules/mask.py
+++ b/funasr/modules/mask.py
@@ -33,3 +33,20 @@
     ys_mask = ys_in_pad != ignore_id
     m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
     return ys_mask.unsqueeze(-2) & m
+
+def vad_mask(size, vad_pos, device="cpu", dtype=torch.bool):
+    """Create mask for decoder self-attention.
+
+    :param int size: size of mask
+    :param int vad_pos: index of vad index
+    :param str device: "cpu" or "cuda" or torch.Tensor.device
+    :param torch.dtype dtype: result dtype
+    :rtype: torch.Tensor (B, Lmax, Lmax)
+    """
+    ret = torch.ones(size, size, device=device, dtype=dtype)
+    if vad_pos <= 0 or vad_pos >= size:
+        return ret
+    sub_corner = torch.zeros(
+        vad_pos - 1, size - vad_pos, device=device, dtype=dtype)
+    ret[0:vad_pos - 1, vad_pos:] = sub_corner
+    return ret
diff --git a/funasr/punctuation/abs_model.py b/funasr/punctuation/abs_model.py
index 5f6afb7..404d5e8 100644
--- a/funasr/punctuation/abs_model.py
+++ b/funasr/punctuation/abs_model.py
@@ -25,3 +25,7 @@
     @abstractmethod
     def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
+
+    @abstractmethod
+    def with_vad(self) -> bool:
+        raise NotImplementedError
diff --git a/funasr/punctuation/espnet_model.py b/funasr/punctuation/espnet_model.py
index 65edaad..c513779 100644
--- a/funasr/punctuation/espnet_model.py
+++ b/funasr/punctuation/espnet_model.py
@@ -14,15 +14,18 @@
 
 class ESPnetPunctuationModel(AbsESPnetModel):
 
-    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0):
+    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
         assert check_argument_types()
         super().__init__()
         self.punc_model = punc_model
+        self.punc_weight = torch.Tensor(punc_weight)
         self.sos = 1
         self.eos = 2
 
         # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
         self.ignore_id = ignore_id
+        if self.punc_model.with_vad():
+            print("This is a vad puncuation model.")
 
     def nll(
         self,
@@ -31,6 +34,8 @@
         text_lengths: torch.Tensor,
         punc_lengths: torch.Tensor,
         max_length: Optional[int] = None,
+        vad_indexes: Optional[torch.Tensor] = None,
+        vad_indexes_lengths: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute negative log likelihood(nll)
 
@@ -49,19 +54,16 @@
         else:
             text = text[:, :max_length]
             punc = punc[:, :max_length]
-        # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
-        # text: (Batch, Length) -> x, y: (Batch, Length + 1)
-        #x = F.pad(text, [1, 0], "constant", self.eos)
-        #t = F.pad(text, [0, 1], "constant", self.ignore_id)
-        #for i, l in enumerate(text_lengths):
-        #    t[i, l] = self.sos
-        #x_lengths = text_lengths + 1
+       
+        if self.punc_model.with_vad():
+            # Should be VadRealtimeTransformer
+            assert vad_indexes is not None
+            y, _ = self.punc_model(text, text_lengths, vad_indexes)
+        else:
+            # Should be TargetDelayTransformer,
+            y, _ = self.punc_model(text, text_lengths)
 
-        # 2. Forward Language model
-        # x: (Batch, Length) -> y: (Batch, Length, NVocab)
-        y, _ = self.punc_model(text, text_lengths)
-
-        # 3. Calc negative log likelihood
+        # Calc negative log likelihood
         # nll: (BxL,)
         if self.training == False:
             _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
@@ -72,7 +74,8 @@
             nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
             return nll, text_lengths
         else:
-            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), reduction="none", ignore_index=self.ignore_id)
+            self.punc_weight = self.punc_weight.to(punc.device)
+            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
         # nll: (BxL,) -> (BxL,)
         if max_length is None:
             nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -130,9 +133,16 @@
         assert x_lengths.size(0) == total_num
         return nll, x_lengths
 
-    def forward(self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor,
-                punc_lengths: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths)
+    def forward(
+        self,
+        text: torch.Tensor,
+        punc: torch.Tensor,
+        text_lengths: torch.Tensor,
+        punc_lengths: torch.Tensor,
+        vad_indexes: Optional[torch.Tensor] = None,
+        vad_indexes_lengths: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
         ntokens = y_lengths.sum()
         loss = nll.sum() / ntokens
         stats = dict(loss=loss.detach())
@@ -145,5 +155,12 @@
                       text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
         return {}
 
-    def inference(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
-        return self.punc_model(text, text_lengths)
+    def inference(self,
+                  text: torch.Tensor,
+                  text_lengths: torch.Tensor,
+                  vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
+        if self.punc_model.with_vad():
+            assert vad_indexes is not None
+            return self.punc_model(text, text_lengths, vad_indexes)
+        else:
+            return self.punc_model(text, text_lengths)
diff --git a/funasr/punctuation/sanm_encoder.py b/funasr/punctuation/sanm_encoder.py
new file mode 100644
index 0000000..8962093
--- /dev/null
+++ b/funasr/punctuation/sanm_encoder.py
@@ -0,0 +1,590 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
+from typeguard import check_argument_types
+import numpy as np
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
+from funasr.modules.embedding import SinusoidalPositionEncoder
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.multi_layer_conv import Conv1dLinear
+from funasr.modules.multi_layer_conv import MultiLayeredConv1d
+from funasr.modules.positionwise_feed_forward import (
+    PositionwiseFeedForward,  # noqa: H301
+)
+from funasr.modules.repeat import repeat
+from funasr.modules.subsampling import Conv2dSubsampling
+from funasr.modules.subsampling import Conv2dSubsampling2
+from funasr.modules.subsampling import Conv2dSubsampling6
+from funasr.modules.subsampling import Conv2dSubsampling8
+from funasr.modules.subsampling import TooShortUttError
+from funasr.modules.subsampling import check_short_utt
+from funasr.models.ctc import CTC
+from funasr.models.encoder.abs_encoder import AbsEncoder
+
+from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.mask import subsequent_mask, vad_mask
+
+class EncoderLayerSANM(nn.Module):
+    def __init__(
+        self,
+        in_size,
+        size,
+        self_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        stochastic_depth_rate=0.0,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayerSANM, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(in_size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.in_size = in_size
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+        self.stochastic_depth_rate = stochastic_depth_rate
+        self.dropout_rate = dropout_rate
+
+    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        """Compute encoded features.
+
+        Args:
+            x_input (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+        skip_layer = False
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        stoch_layer_coeff = 1.0
+        if self.training and self.stochastic_depth_rate > 0:
+            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+        if skip_layer:
+            if cache is not None:
+                x = torch.cat([cache, x], dim=1)
+            return x, mask
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+            else:
+                x = stoch_layer_coeff * self.concat_linear(x_concat)
+        else:
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+            else:
+                x = stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+
+        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
+class SANMEncoder(AbsEncoder):
+    """
+    author: Speech Lab, Alibaba Group, China
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
+        kernel_size : int = 11,
+        sanm_shfit : int = 0,
+        selfattention_layer_type: str = "sanm",
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self._output_size = output_size
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                SinusoidalPositionEncoder(),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+
+        elif selfattention_layer_type == "sanm":
+            self.encoder_selfattn_layer = MultiHeadedAttentionSANM
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.encoders = repeat(
+            num_blocks-1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        xs_pad *= self.output_size()**0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling2)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+        ):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
+                    + f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        # xs_pad = self.dropout(xs_pad)
+        encoder_outs = self.encoders0(xs_pad, masks)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+        if len(self.interctc_layer_idx) == 0:
+            encoder_outs = self.encoders(xs_pad, masks)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                encoder_outs = encoder_layer(xs_pad, masks)
+                xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None
+
+class SANMVadEncoder(AbsEncoder):
+    """
+    author: Speech Lab, Alibaba Group, China
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
+        kernel_size : int = 11,
+        sanm_shfit : int = 0,
+        selfattention_layer_type: str = "sanm",
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self._output_size = output_size
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                SinusoidalPositionEncoder(),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+
+        elif selfattention_layer_type == "sanm":
+            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.encoders = repeat(
+            num_blocks-1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        vad_indexes: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
+        no_future_masks = masks & sub_masks
+        xs_pad *= self.output_size()**0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
+              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
+                    f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        # xs_pad = self.dropout(xs_pad)
+        mask_tup0 = [masks, no_future_masks]
+        encoder_outs = self.encoders0(xs_pad, mask_tup0)
+        xs_pad, _ = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+        #if len(self.interctc_layer_idx) == 0:
+        if False:
+            # Here, we should not use the repeat operation to do it for all layers.
+            encoder_outs = self.encoders(xs_pad, masks)
+            xs_pad, masks = encoder_outs[0], encoder_outs[1]
+        else:
+            for layer_idx, encoder_layer in enumerate(self.encoders):
+                if layer_idx + 1 == len(self.encoders):
+                    # This is last layer.
+                    coner_mask = torch.ones(masks.size(0),
+                                            masks.size(-1),
+                                            masks.size(-1),
+                                            device=xs_pad.device,
+                                            dtype=torch.bool)
+                    for word_index, length in enumerate(ilens):
+                        coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
+                                                                vad_indexes[word_index],
+                                                                device=xs_pad.device)
+                    layer_mask = masks & coner_mask
+                else:
+                    layer_mask = no_future_masks
+                mask_tup1 = [masks, layer_mask]
+                encoder_outs = encoder_layer(xs_pad, mask_tup1)
+                xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
+
+                if layer_idx + 1 in self.interctc_layer_idx:
+                    encoder_out = xs_pad
+
+                    # intermediate outputs are also normalized
+                    if self.normalize_before:
+                        encoder_out = self.after_norm(encoder_out)
+
+                    intermediate_outs.append((layer_idx + 1, encoder_out))
+
+                    if self.interctc_use_conditioning:
+                        ctc_out = ctc.softmax(encoder_out)
+                        xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None
+
diff --git a/funasr/punctuation/target_delay_transformer.py b/funasr/punctuation/target_delay_transformer.py
index 10cc5a8..219af26 100644
--- a/funasr/punctuation/target_delay_transformer.py
+++ b/funasr/punctuation/target_delay_transformer.py
@@ -8,7 +8,7 @@
 from funasr.modules.embedding import PositionalEncoding
 from funasr.modules.embedding import SinusoidalPositionEncoder
 #from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
-from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
+from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
 #from funasr.modules.mask import subsequent_n_mask
 from funasr.punctuation.abs_model import AbsPunctuation
 
@@ -73,6 +73,9 @@
         y = self.decoder(h)
         return y, None
 
+    def with_vad(self):
+        return False
+
     def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
         """Score new token.
 
diff --git a/funasr/punctuation/vad_realtime_transformer.py b/funasr/punctuation/vad_realtime_transformer.py
new file mode 100644
index 0000000..35224f9
--- /dev/null
+++ b/funasr/punctuation/vad_realtime_transformer.py
@@ -0,0 +1,132 @@
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from funasr.modules.embedding import SinusoidalPositionEncoder
+from funasr.punctuation.sanm_encoder import SANMVadEncoder as Encoder
+from funasr.punctuation.abs_model import AbsPunctuation
+
+
+class VadRealtimeTransformer(AbsPunctuation):
+
+    def __init__(
+        self,
+        vocab_size: int,
+        punc_size: int,
+        pos_enc: str = None,
+        embed_unit: int = 128,
+        att_unit: int = 256,
+        head: int = 2,
+        unit: int = 1024,
+        layer: int = 4,
+        dropout_rate: float = 0.5,
+        kernel_size: int = 11,
+        sanm_shfit: int = 0,
+    ):
+        super().__init__()
+        if pos_enc == "sinusoidal":
+            #            pos_enc_class = PositionalEncoding
+            pos_enc_class = SinusoidalPositionEncoder
+        elif pos_enc is None:
+
+            def pos_enc_class(*args, **kwargs):
+                return nn.Sequential()  # indentity
+
+        else:
+            raise ValueError(f"unknown pos-enc option: {pos_enc}")
+
+        self.embed = nn.Embedding(vocab_size, embed_unit)
+        self.encoder = Encoder(
+            input_size=embed_unit,
+            output_size=att_unit,
+            attention_heads=head,
+            linear_units=unit,
+            num_blocks=layer,
+            dropout_rate=dropout_rate,
+            input_layer="pe",
+            # pos_enc_class=pos_enc_class,
+            padding_idx=0,
+            kernel_size=kernel_size,
+            sanm_shfit=sanm_shfit,
+        )
+        self.decoder = nn.Linear(att_unit, punc_size)
+
+
+#    def _target_mask(self, ys_in_pad):
+#        ys_mask = ys_in_pad != 0
+#        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
+#        return ys_mask.unsqueeze(-2) & m
+
+    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
+                vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        """Compute loss value from buffer sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            hidden (torch.Tensor): Target ids. (batch, len)
+
+        """
+        x = self.embed(input)
+        # mask = self._target_mask(input)
+        h, _, _ = self.encoder(x, text_lengths, vad_indexes)
+        y = self.decoder(h)
+        return y, None
+
+    def with_vad(self):
+        return True
+
+    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
+        """Score new token.
+
+        Args:
+            y (torch.Tensor): 1D torch.int64 prefix tokens.
+            state: Scorer state for prefix tokens
+            x (torch.Tensor): encoder feature that generates ys.
+
+        Returns:
+            tuple[torch.Tensor, Any]: Tuple of
+                torch.float32 scores for next token (vocab_size)
+                and next state for ys
+
+        """
+        y = y.unsqueeze(0)
+        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(dim=-1).squeeze(0)
+        return logp, cache
+
+    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
+        """Score new token batch.
+
+        Args:
+            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (torch.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[torch.Tensor, List[Any]]: Tuple of
+                batchfied scores for next token with shape of `(n_batch, vocab_size)`
+                and next state list for ys.
+
+        """
+        # merge states
+        n_batch = len(ys)
+        n_layers = len(self.encoder.encoders)
+        if states[0] is None:
+            batch_state = None
+        else:
+            # transpose state of [batch, layer] into [layer, batch]
+            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
+
+        # batch decoding
+        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(dim=-1)
+
+        # transpose state of [layer, batch] into [batch, layer]
+        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+        return logp, state_list
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 5be9089..d2a00b2 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -1350,10 +1350,12 @@
                 train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
                                                    seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                                "seg_dict_file") else None,
+                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
                                                    mode="train")
                 valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
                                                    seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                                "seg_dict_file") else None,
+                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
                                                    mode="eval")
             elif args.dataset_type == "small":
                 train_iter_factory = cls.build_iter_factory(
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index 1837b2a..ea1e102 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -13,10 +13,11 @@
 from typeguard import check_return_type
 
 from funasr.datasets.collate_fn import CommonCollateFn
-from funasr.datasets.preprocessor import MutliTokenizerCommonPreprocessor
+from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
 from funasr.punctuation.abs_model import AbsPunctuation
 from funasr.punctuation.espnet_model import ESPnetPunctuationModel
 from funasr.punctuation.target_delay_transformer import TargetDelayTransformer
+from funasr.punctuation.vad_realtime_transformer import VadRealtimeTransformer
 from funasr.tasks.abs_task import AbsTask
 from funasr.text.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
@@ -29,11 +30,9 @@
 
 punc_choices = ClassChoices(
     "punctuation",
-    classes=dict(
-        target_delay=TargetDelayTransformer,
-    ),
+    classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
     type_check=AbsPunctuation,
-    default="TargetDelayTransformer",
+    default="target_delay",
 )
 
 
@@ -56,8 +55,6 @@
         # NOTE(kamo): add_arguments(..., required=True) can't be used
         # to provide --print_config mode. Instead of it, do as
         required = parser.get_default("required")
-        #import pdb;pdb.set_trace()
-        #required += ["token_list"]
 
         group.add_argument(
             "--token_list",
@@ -154,7 +151,7 @@
         bpemodels = [args.bpemodel, args.bpemodel]
         text_names = ["text", "punc"]
         if args.use_preprocessor:
-            retval = MutliTokenizerCommonPreprocessor(
+            retval = PuncTrainTokenizerCommonPreprocessor(
                 train=train,
                 token_type=token_types,
                 token_list=token_lists,
@@ -182,7 +179,7 @@
     def optional_data_names(
             cls, train: bool = True, inference: bool = False
     ) -> Tuple[str, ...]:
-        retval = ()
+        retval = ("vad",)
         return retval
 
     @classmethod
@@ -197,11 +194,13 @@
             args.token_list = token_list.copy()
         if isinstance(args.punc_list, str):
             with open(args.punc_list, encoding="utf-8") as f2:
-                punc_list = [line.rstrip() for line in f2]
+                pairs = [line.rstrip().split(":") for line in f2]
+            punc_list = [pair[0] for pair in pairs]
+            punc_weight_list = [float(pair[1]) for pair in pairs]
             args.punc_list = punc_list.copy()
         elif isinstance(args.punc_list, list):
-            # This is in the inference code path.
             punc_list = args.punc_list.copy()
+            punc_weight_list = [1] * len(punc_list)
         if isinstance(args.token_list, (tuple, list)):
             token_list = args.token_list.copy()
         else:
@@ -217,7 +216,9 @@
 
         # 2. Build ESPnetModel
         # Assume the last-id is sos_and_eos
-        model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, **args.model_conf)
+        if "punc_weight" in args.model_conf:
+            args.model_conf.pop("punc_weight")
+        model = ESPnetPunctuationModel(punc_model=punc, vocab_size=vocab_size, punc_weight=punc_weight_list, **args.model_conf)
 
         # FIXME(kamo): Should be done in model?
         # 3. Initialize

--
Gitblit v1.9.1