From bdfd27b9e96bd55c449953bb577e1d4deeaf11c9 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期六, 13 一月 2024 23:43:17 +0800
Subject: [PATCH] funasr1.0

---
 funasr/models/ct_transformer_streaming/__init__.py                 |    0 
 funasr/models/ct_transformer_streaming/attention.py                | 1091 +++++++++++++++++++++++++++++
 examples/industrial_data_pretraining/paraformer-zh-spk/demo.py     |    2 
 examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh    |    2 
 funasr/models/ct_transformer_streaming/vad_realtime_transformer.py |  135 +++
 funasr/download/download_from_hub.py                               |   34 
 examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py    |    2 
 funasr/models/ct_transformer_streaming/template.yaml               |   53 +
 examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh   |    2 
 funasr/models/ct_transformer_streaming/utils.py                    |  111 +++
 funasr/models/ct_transformer_streaming/model.py                    |  340 +++++++++
 examples/industrial_data_pretraining/emotion2vec/demo.py           |    2 
 examples/industrial_data_pretraining/bicif_paraformer/demo.py      |    4 
 examples/industrial_data_pretraining/seaco_paraformer/infer.sh     |    2 
 examples/industrial_data_pretraining/seaco_paraformer/demo.py      |    2 
 funasr/models/ct_transformer_streaming/encoder.py                  |  383 ++++++++++
 16 files changed, 2,147 insertions(+), 18 deletions(-)

diff --git a/examples/industrial_data_pretraining/bicif_paraformer/demo.py b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
index 4d921ea..57edb68 100644
--- a/examples/industrial_data_pretraining/bicif_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/bicif_paraformer/demo.py
@@ -8,7 +8,7 @@
 model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                     model_revision="v2.0.0",
                     vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                    vad_model_revision="v2.0.1",
+                    vad_model_revision="v2.0.2",
                     punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                     punc_model_revision="v2.0.1",
                     spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
@@ -21,7 +21,7 @@
 model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                     model_revision="v2.0.0",
                     vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                    vad_model_revision="v2.0.1",
+                    vad_model_revision="v2.0.2",
                     punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                     punc_model_revision="v2.0.1",
                     spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
diff --git a/examples/industrial_data_pretraining/emotion2vec/demo.py b/examples/industrial_data_pretraining/emotion2vec/demo.py
index 3653313..abaa9f4 100644
--- a/examples/industrial_data_pretraining/emotion2vec/demo.py
+++ b/examples/industrial_data_pretraining/emotion2vec/demo.py
@@ -5,7 +5,7 @@
 
 from funasr import AutoModel
 
-model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.0")
+model = AutoModel(model="damo/emotion2vec_base", model_revision="v2.0.1")
 
 res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", output_dir="./outputs")
 print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index 4e3cb70..459dfff 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -7,7 +7,7 @@
 wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
 
 chunk_size = 60000 # ms
-model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.1")
+model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.2")
 
 res = model(input=wav_file, chunk_size=chunk_size, )
 print(res)
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh b/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh
index 08ef8bd..815c52a 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh
@@ -1,7 +1,7 @@
 
 
 model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-model_revision="v2.0.1"
+model_revision="v2.0.2"
 
 python funasr/bin/inference.py \
 +model=${model} \
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
index fcf5f60..fc3a635 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py
@@ -8,7 +8,7 @@
 model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.0",
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  vad_model_revision="v2.0.1",
+                  vad_model_revision="v2.0.2",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   punc_model_revision="v2.0.1",
                   spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
index 63347b6..f3fa90d 100644
--- a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
+++ b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh
@@ -2,7 +2,7 @@
 model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-vad_model_revision="v2.0.1"
+vad_model_revision="v2.0.2"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 punc_model_revision="v2.0.1"
 spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index 3b5963a..7f1fdb5 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -8,7 +8,7 @@
 model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.0",
                   vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  vad_model_revision="v2.0.1",
+                  vad_model_revision="v2.0.2",
                   punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   punc_model_revision="v2.0.1",
                   )
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
index c46449f..ac5c190 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
+++ b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh
@@ -2,7 +2,7 @@
 model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.0"
 vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
-vad_model_revision="v2.0.1"
+vad_model_revision="v2.0.2"
 punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 punc_model_revision="v2.0.1"
 
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 27bd79d..9779050 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -22,8 +22,8 @@
 	
 	config = os.path.join(model_or_path, "config.yaml")
 	if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
-		cfg = OmegaConf.load(config)
-		kwargs = OmegaConf.merge(cfg, kwargs)
+		config = OmegaConf.load(config)
+		kwargs = OmegaConf.merge(config, kwargs)
 		init_param = os.path.join(model_or_path, "model.pb")
 		kwargs["init_param"] = init_param
 		if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
@@ -34,7 +34,7 @@
 			kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
 		if os.path.exists(os.path.join(model_or_path, "bpe.model")):
 			kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
-		kwargs["model"] = cfg["model"]
+		kwargs["model"] = config["model"]
 		if os.path.exists(os.path.join(model_or_path, "am.mvn")):
 			kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
 		if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
@@ -43,14 +43,30 @@
 		assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
 		with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
 			conf_json = json.load(f)
-			config = os.path.join(model_or_path, conf_json["model_config"])
-			cfg = OmegaConf.load(config)
-			kwargs = OmegaConf.merge(cfg, kwargs)
-			init_param = os.path.join(model_or_path, conf_json["model_file"])
-			kwargs["init_param"] = init_param
-		kwargs["model"] = cfg["model"]
+			cfg = {}
+			add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+			cfg.update(kwargs)
+			config = OmegaConf.load(cfg["config"])
+			kwargs = OmegaConf.merge(config, cfg)
+		kwargs["model"] = config["model"]
 	return OmegaConf.to_container(kwargs, resolve=True)
 
+def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = {}):
+	
+	if isinstance(file_path_metas, dict):
+		for k, v in file_path_metas.items():
+			if isinstance(v, str):
+				p = os.path.join(model_or_path, v)
+				if os.path.exists(p):
+					cfg[k] = p
+			elif isinstance(v, dict):
+				if k not in cfg:
+					cfg[k] = {}
+				return add_file_root_path(model_or_path, v, cfg[k])
+	
+	return cfg
+
+
 def get_or_download_model_dir(
 		model,
 		model_revision=None,
diff --git a/funasr/models/ct_transformer_streaming/__init__.py b/funasr/models/ct_transformer_streaming/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/ct_transformer_streaming/__init__.py
diff --git a/funasr/models/ct_transformer_streaming/attention.py b/funasr/models/ct_transformer_streaming/attention.py
new file mode 100644
index 0000000..a35ddee
--- /dev/null
+++ b/funasr/models/ct_transformer_streaming/attention.py
@@ -0,0 +1,1091 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention layer definition."""
+
+import math
+
+import numpy
+import torch
+from torch import nn
+from typing import Optional, Tuple
+
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+import funasr.models.lora.layers as lora
+
+class MultiHeadedAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadedAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, query, key, value):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        n_batch = query.size(0)
+        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q, k, v
+
+    def forward_attention(self, value, scores, mask):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, query, key, value, mask):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask)
+
+
+class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with relative position encoding (old version).
+
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+    Paper: https://arxiv.org/abs/1901.02860
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
+        """Construct an RelPositionMultiHeadedAttention object."""
+        super().__init__(n_head, n_feat, dropout_rate)
+        self.zero_triu = zero_triu
+        # linear transformation for positional encoding
+        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        torch.nn.init.xavier_uniform_(self.pos_bias_u)
+        torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+    def rel_shift(self, x):
+        """Compute relative positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, head, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor.
+
+        """
+        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+
+        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+        x = x_padded[:, :, 1:].view_as(x)
+
+        if self.zero_triu:
+            ones = torch.ones((x.size(2), x.size(3)))
+            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+        return x
+
+    def forward(self, query, key, value, pos_emb, mask):
+        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = p.transpose(1, 2)  # (batch, head, time1, d_k)
+
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        # compute matrix b and matrix d
+        # (batch, head, time1, time1)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        matrix_bd = self.rel_shift(matrix_bd)
+
+        scores = (matrix_ac + matrix_bd) / math.sqrt(
+            self.d_k
+        )  # (batch, head, time1, time2)
+
+        return self.forward_attention(v, scores, mask)
+
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+    """Multi-Head Attention layer with relative position encoding (new implementation).
+
+    Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+    Paper: https://arxiv.org/abs/1901.02860
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
+
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
+        """Construct an RelPositionMultiHeadedAttention object."""
+        super().__init__(n_head, n_feat, dropout_rate)
+        self.zero_triu = zero_triu
+        # linear transformation for positional encoding
+        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
+        # these two learnable bias are used in matrix c and matrix d
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
+        torch.nn.init.xavier_uniform_(self.pos_bias_u)
+        torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+    def rel_shift(self, x):
+        """Compute relative positional encoding.
+
+        Args:
+            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+            time1 means the length of query vector.
+
+        Returns:
+            torch.Tensor: Output tensor.
+
+        """
+        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
+        x_padded = torch.cat([zero_pad, x], dim=-1)
+
+        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
+        x = x_padded[:, :, 1:].view_as(x)[
+            :, :, :, : x.size(-1) // 2 + 1
+            ]  # only keep the positions from 0 to time2
+
+        if self.zero_triu:
+            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
+            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
+
+        return x
+
+    def forward(self, query, key, value, pos_emb, mask):
+        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            pos_emb (torch.Tensor): Positional embedding tensor
+                (#batch, 2*time1-1, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
+
+        n_batch_pos = pos_emb.size(0)
+        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
+
+        # (batch, head, time1, d_k)
+        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
+        # (batch, head, time1, d_k)
+        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
+
+        # compute attention score
+        # first compute matrix a and matrix c
+        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+        # (batch, head, time1, time2)
+        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+        # compute matrix b and matrix d
+        # (batch, head, time1, 2*time1-1)
+        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+        matrix_bd = self.rel_shift(matrix_bd)
+
+        scores = (matrix_ac + matrix_bd) / math.sqrt(
+            self.d_k
+        )  # (batch, head, time1, time2)
+
+        return self.forward_attention(v, scores, mask)
+
+
+class MultiHeadedAttentionSANM(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadedAttentionSANM, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        # self.linear_q = nn.Linear(n_feat, n_feat)
+        # self.linear_k = nn.Linear(n_feat, n_feat)
+        # self.linear_v = nn.Linear(n_feat, n_feat)
+        if lora_list is not None:
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
+            if lora_qkv_list == [False, False, False]:
+                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+            else:
+                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
+        else:
+            self.linear_out = nn.Linear(n_feat, n_feat)
+            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+        # padding
+        left_padding = (kernel_size - 1) // 2
+        if sanm_shfit > 0:
+            left_padding = left_padding + sanm_shfit
+        right_padding = kernel_size - 1 - left_padding
+        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+
+    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
+        b, t, d = inputs.size()
+        if mask is not None:
+            mask = torch.reshape(mask, (b, -1, 1))
+            if mask_shfit_chunk is not None:
+                mask = mask * mask_shfit_chunk
+            inputs = inputs * mask
+
+        x = inputs.transpose(1, 2)
+        x = self.pad_fn(x)
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        x += inputs
+        x = self.dropout(x)
+        if mask is not None:
+            x = x * mask
+        return x
+
+    def forward_qkv(self, x):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        b, t, d = x.size()
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q_h, k_h, v_h, v
+
+    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            if mask_att_chunk_encoder is not None:
+                mask = mask * mask_att_chunk_encoder
+
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        return att_outs + fsmn_memory
+
+    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        if chunk_size is not None and look_back > 0 or look_back == -1:
+            if cache is not None:
+                k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
+                v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
+                k_h = torch.cat((cache["k"], k_h), dim=2)
+                v_h = torch.cat((cache["v"], v_h), dim=2)
+
+                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
+                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
+                if look_back != -1:
+                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
+                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
+            else:
+                cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
+                             "v": v_h[:, :, :-(chunk_size[2]), :]}
+                cache = cache_tmp
+        fsmn_memory = self.forward_fsmn(v, None)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, None)
+        return att_outs + fsmn_memory, cache
+
+
+class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
+        return att_outs + fsmn_memory
+
+class MultiHeadedAttentionSANMDecoder(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadedAttentionSANMDecoder, self).__init__()
+
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+        self.fsmn_block = nn.Conv1d(n_feat, n_feat,
+                                    kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
+        # padding
+        # padding
+        left_padding = (kernel_size - 1) // 2
+        if sanm_shfit > 0:
+            left_padding = left_padding + sanm_shfit
+        right_padding = kernel_size - 1 - left_padding
+        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
+        self.kernel_size = kernel_size
+
+    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
+        '''
+        :param x: (#batch, time1, size).
+        :param mask: Mask tensor (#batch, 1, time)
+        :return:
+        '''
+        # print("in fsmn, inputs", inputs.size())
+        b, t, d = inputs.size()
+        # logging.info(
+        #     "mask: {}".format(mask.size()))
+        if mask is not None:
+            mask = torch.reshape(mask, (b ,-1, 1))
+            # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
+            if mask_shfit_chunk is not None:
+                # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
+                mask = mask * mask_shfit_chunk
+            # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
+            # print("in fsmn, mask", mask.size())
+            # print("in fsmn, inputs", inputs.size())
+            inputs = inputs * mask
+
+        x = inputs.transpose(1, 2)
+        b, d, t = x.size()
+        if cache is None:
+            # print("in fsmn, cache is None, x", x.size())
+
+            x = self.pad_fn(x)
+            if not self.training:
+                cache = x
+        else:
+            # print("in fsmn, cache is not None, x", x.size())
+            # x = torch.cat((x, cache), dim=2)[:, :, :-1]
+            # if t < self.kernel_size:
+            #     x = self.pad_fn(x)
+            x = torch.cat((cache[:, :, 1:], x), dim=2)
+            x = x[:, :, -(self.kernel_size+t-1):]
+            # print("in fsmn, cache is not None, x_cat", x.size())
+            cache = x
+        x = self.fsmn_block(x)
+        x = x.transpose(1, 2)
+        # print("in fsmn, fsmn_out", x.size())
+        if x.size(1) != inputs.size(1):
+            inputs = inputs[:, -1, :]
+
+        x = x + inputs
+        x = self.dropout(x)
+        if mask is not None:
+            x = x * mask
+        return x, cache
+
+class MultiHeadedAttentionCrossAtt(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadedAttentionCrossAtt, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        if lora_list is not None:
+            if "q" in lora_list:
+                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_q = nn.Linear(n_feat, n_feat)
+            lora_kv_list = ["k" in lora_list, "v" in lora_list]
+            if lora_kv_list == [False, False]:
+                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            else:
+                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, 
+                                      r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
+            if "o" in lora_list:
+                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
+            else:
+                self.linear_out = nn.Linear(n_feat, n_feat)
+        else:
+            self.linear_q = nn.Linear(n_feat, n_feat)
+            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
+            self.linear_out = nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, x, memory):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+
+        # print("in forward_qkv, x", x.size())
+        b = x.size(0)
+        q = self.linear_q(x)
+        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time1, d_k)
+
+        k_v = self.linear_k_v(memory)
+        k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
+        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
+
+
+        return q_h, k_h, v_h
+
+    def forward_attention(self, value, scores, mask):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            # logging.info(
+            #     "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, memory, memory_mask):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h = self.forward_qkv(x, memory)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        return self.forward_attention(v_h, scores, memory_mask)
+
+    def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h = self.forward_qkv(x, memory)
+        if chunk_size is not None and look_back > 0:
+            if cache is not None:
+                k_h = torch.cat((cache["k"], k_h), dim=2)
+                v_h = torch.cat((cache["v"], v_h), dim=2)
+                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
+                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
+            else:
+                cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
+                             "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
+                cache = cache_tmp
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        return self.forward_attention(v_h, scores, None), cache
+
+
+class MultiHeadSelfAttention(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
+        """Construct an MultiHeadedAttention object."""
+        super(MultiHeadSelfAttention, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, x):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        b, t, d = x.size()
+        q_k_v = self.linear_q_k_v(x)
+        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
+        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q_h, k_h, v_h, v
+
+    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            if mask_att_chunk_encoder is not None:
+                mask = mask * mask_att_chunk_encoder
+
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+            )
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x)  # (batch, time1, d_model)
+
+    def forward(self, x, mask, mask_att_chunk_encoder=None):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q_h, k_h, v_h, v = self.forward_qkv(x)
+        q_h = q_h * self.d_k ** (-0.5)
+        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+        return att_outs
+
+class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
+    """RelPositionMultiHeadedAttention definition.
+    Args:
+        num_heads: Number of attention heads.
+        embed_size: Embedding size.
+        dropout_rate: Dropout rate.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        embed_size: int,
+        dropout_rate: float = 0.0,
+        simplified_attention_score: bool = False,
+    ) -> None:
+        """Construct an MultiHeadedAttention object."""
+        super().__init__()
+
+        self.d_k = embed_size // num_heads
+        self.num_heads = num_heads
+
+        assert self.d_k * num_heads == embed_size, (
+            "embed_size (%d) must be divisible by num_heads (%d)",
+            (embed_size, num_heads),
+        )
+
+        self.linear_q = torch.nn.Linear(embed_size, embed_size)
+        self.linear_k = torch.nn.Linear(embed_size, embed_size)
+        self.linear_v = torch.nn.Linear(embed_size, embed_size)
+
+        self.linear_out = torch.nn.Linear(embed_size, embed_size)
+
+        if simplified_attention_score:
+            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
+
+            self.compute_att_score = self.compute_simplified_attention_score
+        else:
+            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
+
+            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
+            torch.nn.init.xavier_uniform_(self.pos_bias_u)
+            torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+            self.compute_att_score = self.compute_attention_score
+
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.attn = None
+
+    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
+        """Compute relative positional encoding.
+        Args:
+            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
+            left_context: Number of frames in left context.
+        Returns:
+            x: Output sequence. (B, H, T_1, T_2)
+        """
+        batch_size, n_heads, time1, n = x.shape
+        time2 = time1 + left_context
+
+        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
+
+        return x.as_strided(
+            (batch_size, n_heads, time1, time2),
+            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
+            storage_offset=(n_stride * (time1 - 1)),
+        )
+
+    def compute_simplified_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Simplified attention score computation.
+        Reference: https://github.com/k2-fsa/icefall/pull/458
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+        """
+        pos_enc = self.linear_pos(pos_enc)
+
+        matrix_ac = torch.matmul(query, key.transpose(2, 3))
+
+        matrix_bd = self.rel_shift(
+            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
+            left_context=left_context,
+        )
+
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def compute_attention_score(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        pos_enc: torch.Tensor,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Attention score computation.
+        Args:
+            query: Transformed query tensor. (B, H, T_1, d_k)
+            key: Transformed key tensor. (B, H, T_2, d_k)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            left_context: Number of frames in left context.
+        Returns:
+            : Attention score. (B, H, T_1, T_2)
+        """
+        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
+
+        query = query.transpose(1, 2)
+        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
+        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
+
+        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
+
+        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
+        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
+
+        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
+
+    def forward_qkv(
+        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Transform query, key and value.
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            v: Value tensor. (B, T_2, size)
+        Returns:
+            q: Transformed query tensor. (B, H, T_1, d_k)
+            k: Transformed key tensor. (B, H, T_2, d_k)
+            v: Transformed value tensor. (B, H, T_2, d_k)
+        """
+        n_batch = query.size(0)
+
+        q = (
+            self.linear_q(query)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        k = (
+            self.linear_k(key)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+        v = (
+            self.linear_v(value)
+            .view(n_batch, -1, self.num_heads, self.d_k)
+            .transpose(1, 2)
+        )
+
+        return q, k, v
+
+    def forward_attention(
+        self,
+        value: torch.Tensor,
+        scores: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Compute attention context vector.
+        Args:
+            value: Transformed value. (B, H, T_2, d_k)
+            scores: Attention score. (B, H, T_1, T_2)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+        Returns:
+           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
+        """
+        batch_size = scores.size(0)
+        mask = mask.unsqueeze(1).unsqueeze(2)
+        if chunk_mask is not None:
+            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
+        scores = scores.masked_fill(mask, float("-inf"))
+        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
+
+        attn_output = self.dropout(self.attn)
+        attn_output = torch.matmul(attn_output, value)
+
+        attn_output = self.linear_out(
+            attn_output.transpose(1, 2)
+            .contiguous()
+            .view(batch_size, -1, self.num_heads * self.d_k)
+        )
+
+        return attn_output
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        pos_enc: torch.Tensor,
+        mask: torch.Tensor,
+        chunk_mask: Optional[torch.Tensor] = None,
+        left_context: int = 0,
+    ) -> torch.Tensor:
+        """Compute scaled dot product attention with rel. positional encoding.
+        Args:
+            query: Query tensor. (B, T_1, size)
+            key: Key tensor. (B, T_2, size)
+            value: Value tensor. (B, T_2, size)
+            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
+            mask: Source mask. (B, T_2)
+            chunk_mask: Chunk mask. (T_1, T_1)
+            left_context: Number of frames in left context.
+        Returns:
+            : Output tensor. (B, T_1, H * d_k)
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
+        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
+
+
+class CosineDistanceAttention(nn.Module):
+    """ Compute Cosine Distance between spk decoder output and speaker profile 
+    Args:
+        profile_path: speaker profile file path (.npy file)
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, spk_decoder_out, profile, profile_lens=None):
+        """
+        Args:
+            spk_decoder_out(torch.Tensor):(B, L, D)
+            spk_profiles(torch.Tensor):(B, N, D)
+        """
+        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
+        if profile_lens is not None:
+            
+            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
+            min_value = float(
+                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
+            )
+            weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
+            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
+        else:
+            x = x[:, -1:, :, :]
+            weights_not_softmax=F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
+            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
+        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
+
+        return spk_embedding, weights
diff --git a/funasr/models/ct_transformer_streaming/encoder.py b/funasr/models/ct_transformer_streaming/encoder.py
new file mode 100644
index 0000000..784baf3
--- /dev/null
+++ b/funasr/models/ct_transformer_streaming/encoder.py
@@ -0,0 +1,383 @@
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.scama.chunk_utilis import overlap_chunk
+import numpy as np
+from funasr.train_utils.device_funcs import to_device
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.sanm.attention import MultiHeadedAttention
+from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask
+from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.layer_norm import LayerNorm
+from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
+from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
+from funasr.models.transformer.positionwise_feed_forward import (
+    PositionwiseFeedForward,  # noqa: H301
+)
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
+from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import TooShortUttError
+from funasr.models.transformer.utils.subsampling import check_short_utt
+from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
+
+from funasr.models.ctc.ctc import CTC
+
+from funasr.register import tables
+
+class EncoderLayerSANM(nn.Module):
+    def __init__(
+        self,
+        in_size,
+        size,
+        self_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        stochastic_depth_rate=0.0,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayerSANM, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(in_size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.in_size = in_size
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+        self.stochastic_depth_rate = stochastic_depth_rate
+        self.dropout_rate = dropout_rate
+
+    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+        """Compute encoded features.
+
+        Args:
+            x_input (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+        skip_layer = False
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        stoch_layer_coeff = 1.0
+        if self.training and self.stochastic_depth_rate > 0:
+            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+        if skip_layer:
+            if cache is not None:
+                x = torch.cat([cache, x], dim=1)
+            return x, mask
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.concat_after:
+            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+            else:
+                x = stoch_layer_coeff * self.concat_linear(x_concat)
+        else:
+            if self.in_size == self.size:
+                x = residual + stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+            else:
+                x = stoch_layer_coeff * self.dropout(
+                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                )
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
+    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+        """Compute encoded features.
+
+        Args:
+            x_input (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if self.in_size == self.size:
+            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+            x = residual + attn
+        else:
+            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + self.feed_forward(x)
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        return x, cache
+
+
+@tables.register("encoder_classes", "SANMVadEncoder")
+class SANMVadEncoder(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
+        kernel_size : int = 11,
+        sanm_shfit : int = 0,
+        selfattention_layer_type: str = "sanm",
+    ):
+        super().__init__()
+        self._output_size = output_size
+
+        if input_layer == "linear":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Linear(input_size, output_size),
+                torch.nn.LayerNorm(output_size),
+                torch.nn.Dropout(dropout_rate),
+                torch.nn.ReLU(),
+                pos_enc_class(output_size, positional_dropout_rate),
+            )
+        elif input_layer == "conv2d":
+            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d2":
+            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d6":
+            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
+        elif input_layer == "conv2d8":
+            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
+        elif input_layer == "embed":
+            self.embed = torch.nn.Sequential(
+                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
+                SinusoidalPositionEncoder(),
+            )
+        elif input_layer is None:
+            if input_size == output_size:
+                self.embed = None
+            else:
+                self.embed = torch.nn.Linear(input_size, output_size)
+        elif input_layer == "pe":
+            self.embed = SinusoidalPositionEncoder()
+        else:
+            raise ValueError("unknown input_layer: " + input_layer)
+        self.normalize_before = normalize_before
+        if positionwise_layer_type == "linear":
+            positionwise_layer = PositionwiseFeedForward
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d":
+            positionwise_layer = MultiLayeredConv1d
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        elif positionwise_layer_type == "conv1d-linear":
+            positionwise_layer = Conv1dLinear
+            positionwise_layer_args = (
+                output_size,
+                linear_units,
+                positionwise_conv_kernel_size,
+                dropout_rate,
+            )
+        else:
+            raise NotImplementedError("Support only linear or conv1d.")
+
+        if selfattention_layer_type == "selfattn":
+            encoder_selfattn_layer = MultiHeadedAttention
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                attention_dropout_rate,
+            )
+
+        elif selfattention_layer_type == "sanm":
+            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
+            encoder_selfattn_layer_args0 = (
+                attention_heads,
+                input_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+            encoder_selfattn_layer_args = (
+                attention_heads,
+                output_size,
+                output_size,
+                attention_dropout_rate,
+                kernel_size,
+                sanm_shfit,
+            )
+
+        self.encoders0 = repeat(
+            1,
+            lambda lnum: EncoderLayerSANM(
+                input_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+
+        self.encoders = repeat(
+            num_blocks-1,
+            lambda lnum: EncoderLayerSANM(
+                output_size,
+                output_size,
+                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+        self.interctc_layer_idx = interctc_layer_idx
+        if len(interctc_layer_idx) > 0:
+            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
+        self.interctc_use_conditioning = interctc_use_conditioning
+        self.conditioning_layer = None
+        self.dropout = nn.Dropout(dropout_rate)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        vad_indexes: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+            prev_states: Not to be used now.
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
+        no_future_masks = masks & sub_masks
+        xs_pad *= self.output_size()**0.5
+        if self.embed is None:
+            xs_pad = xs_pad
+        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
+              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
+            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
+            if short_status:
+                raise TooShortUttError(
+                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
+                    f"(it needs more than {limit_size} frames), return empty results",
+                    xs_pad.size(1),
+                    limit_size,
+                )
+            xs_pad, masks = self.embed(xs_pad, masks)
+        else:
+            xs_pad = self.embed(xs_pad)
+
+        # xs_pad = self.dropout(xs_pad)
+        mask_tup0 = [masks, no_future_masks]
+        encoder_outs = self.encoders0(xs_pad, mask_tup0)
+        xs_pad, _ = encoder_outs[0], encoder_outs[1]
+        intermediate_outs = []
+
+
+        for layer_idx, encoder_layer in enumerate(self.encoders):
+                if layer_idx + 1 == len(self.encoders):
+                    # This is last layer.
+                    coner_mask = torch.ones(masks.size(0),
+                                            masks.size(-1),
+                                            masks.size(-1),
+                                            device=xs_pad.device,
+                                            dtype=torch.bool)
+                    for word_index, length in enumerate(ilens):
+                        coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
+                                                                vad_indexes[word_index],
+                                                                device=xs_pad.device)
+                    layer_mask = masks & coner_mask
+                else:
+                    layer_mask = no_future_masks
+                mask_tup1 = [masks, layer_mask]
+                encoder_outs = encoder_layer(xs_pad, mask_tup1)
+                xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        if len(intermediate_outs) > 0:
+            return (xs_pad, intermediate_outs), olens, None
+        return xs_pad, olens, None
diff --git a/funasr/models/ct_transformer_streaming/model.py b/funasr/models/ct_transformer_streaming/model.py
new file mode 100644
index 0000000..4c84261
--- /dev/null
+++ b/funasr/models/ct_transformer_streaming/model.py
@@ -0,0 +1,340 @@
+from typing import Any
+from typing import List
+from typing import Tuple
+from typing import Optional
+import numpy as np
+import torch.nn.functional as F
+
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.train_utils.device_funcs import to_device
+import torch
+import torch.nn as nn
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+from funasr.utils.load_utils import load_audio_text_image_video
+
+from funasr.register import tables
+
+@tables.register("model_classes", "CTTransformerStreaming")
+class CTTransformerStreaming(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+    https://arxiv.org/pdf/2003.01309.pdf
+    """
+    def __init__(
+        self,
+        encoder: str = None,
+        encoder_conf: dict = None,
+        vocab_size: int = -1,
+        punc_list: list = None,
+        punc_weight: list = None,
+        embed_unit: int = 128,
+        att_unit: int = 256,
+        dropout_rate: float = 0.5,
+        ignore_id: int = -1,
+        sos: int = 1,
+        eos: int = 2,
+        sentence_end_id: int = 3,
+        **kwargs,
+    ):
+        super().__init__()
+
+        punc_size = len(punc_list)
+        if punc_weight is None:
+            punc_weight = [1] * punc_size
+        
+        
+        self.embed = nn.Embedding(vocab_size, embed_unit)
+        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder = encoder_class(**encoder_conf)
+
+        self.decoder = nn.Linear(att_unit, punc_size)
+        self.encoder = encoder
+        self.punc_list = punc_list
+        self.punc_weight = punc_weight
+        self.ignore_id = ignore_id
+        self.sos = sos
+        self.eos = eos
+        self.sentence_end_id = sentence_end_id
+        
+        
+
+    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        """Compute loss value from buffer sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            hidden (torch.Tensor): Target ids. (batch, len)
+
+        """
+        x = self.embed(text)
+        # mask = self._target_mask(input)
+        h, _, _ = self.encoder(x, text_lengths)
+        y = self.decoder(h)
+        return y, None
+
+    def with_vad(self):
+        return False
+
+    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
+        """Score new token.
+
+        Args:
+            y (torch.Tensor): 1D torch.int64 prefix tokens.
+            state: Scorer state for prefix tokens
+            x (torch.Tensor): encoder feature that generates ys.
+
+        Returns:
+            tuple[torch.Tensor, Any]: Tuple of
+                torch.float32 scores for next token (vocab_size)
+                and next state for ys
+
+        """
+        y = y.unsqueeze(0)
+        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(dim=-1).squeeze(0)
+        return logp, cache
+
+    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
+        """Score new token batch.
+
+        Args:
+            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (torch.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[torch.Tensor, List[Any]]: Tuple of
+                batchfied scores for next token with shape of `(n_batch, vocab_size)`
+                and next state list for ys.
+
+        """
+        # merge states
+        n_batch = len(ys)
+        n_layers = len(self.encoder.encoders)
+        if states[0] is None:
+            batch_state = None
+        else:
+            # transpose state of [batch, layer] into [layer, batch]
+            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
+
+        # batch decoding
+        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(dim=-1)
+
+        # transpose state of [layer, batch] into [batch, layer]
+        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+        return logp, state_list
+
+    def nll(
+        self,
+        text: torch.Tensor,
+        punc: torch.Tensor,
+        text_lengths: torch.Tensor,
+        punc_lengths: torch.Tensor,
+        max_length: Optional[int] = None,
+        vad_indexes: Optional[torch.Tensor] = None,
+        vad_indexes_lengths: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute negative log likelihood(nll)
+
+        Normally, this function is called in batchify_nll.
+        Args:
+            text: (Batch, Length)
+            punc: (Batch, Length)
+            text_lengths: (Batch,)
+            max_lengths: int
+        """
+        batch_size = text.size(0)
+        # For data parallel
+        if max_length is None:
+            text = text[:, :text_lengths.max()]
+            punc = punc[:, :text_lengths.max()]
+        else:
+            text = text[:, :max_length]
+            punc = punc[:, :max_length]
+    
+        if self.with_vad():
+            # Should be VadRealtimeTransformer
+            assert vad_indexes is not None
+            y, _ = self.punc_forward(text, text_lengths, vad_indexes)
+        else:
+            # Should be TargetDelayTransformer,
+            y, _ = self.punc_forward(text, text_lengths)
+    
+        # Calc negative log likelihood
+        # nll: (BxL,)
+        if self.training == False:
+            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+            from sklearn.metrics import f1_score
+            f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
+                                indices.squeeze(-1).detach().cpu().numpy(),
+                                average='micro')
+            nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
+            return nll, text_lengths
+        else:
+            self.punc_weight = self.punc_weight.to(punc.device)
+            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
+                                  ignore_index=self.ignore_id)
+        # nll: (BxL,) -> (BxL,)
+        if max_length is None:
+            nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
+        else:
+            nll.masked_fill_(
+                make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
+                0.0,
+            )
+        # nll: (BxL,) -> (B, L)
+        nll = nll.view(batch_size, -1)
+        return nll, text_lengths
+
+
+    def forward(
+        self,
+        text: torch.Tensor,
+        punc: torch.Tensor,
+        text_lengths: torch.Tensor,
+        punc_lengths: torch.Tensor,
+        vad_indexes: Optional[torch.Tensor] = None,
+        vad_indexes_lengths: Optional[torch.Tensor] = None,
+    ):
+        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
+        ntokens = y_lengths.sum()
+        loss = nll.sum() / ntokens
+        stats = dict(loss=loss.detach())
+    
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
+        return loss, stats, weight
+    
+    def generate(self,
+                 data_in,
+                 data_lengths=None,
+                 key: list = None,
+                 tokenizer=None,
+                 frontend=None,
+                 **kwargs,
+                 ):
+        assert len(data_in) == 1
+        text = load_audio_text_image_video(data_in, data_type=kwargs.get("kwargs", "text"))[0]
+        vad_indexes = kwargs.get("vad_indexes", None)
+        # text = data_in[0]
+        # text_lengths = data_lengths[0] if data_lengths is not None else None
+        split_size = kwargs.get("split_size", 20)
+
+        jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
+        if jieba_usr_dict and isinstance(jieba_usr_dict, str):
+            import jieba
+            jieba.load_userdict(jieba_usr_dict)
+            jieba_usr_dict = jieba
+            kwargs["jieba_usr_dict"] = "jieba_usr_dict"
+        tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
+        tokens_int = tokenizer.encode(tokens)
+
+        mini_sentences = split_to_mini_sentence(tokens, split_size)
+        mini_sentences_id = split_to_mini_sentence(tokens_int, split_size)
+        assert len(mini_sentences) == len(mini_sentences_id)
+        cache_sent = []
+        cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
+        new_mini_sentence = ""
+        new_mini_sentence_punc = []
+        cache_pop_trigger_limit = 200
+        results = []
+        meta_data = {}
+        punc_array = None
+        for mini_sentence_i in range(len(mini_sentences)):
+            mini_sentence = mini_sentences[mini_sentence_i]
+            mini_sentence_id = mini_sentences_id[mini_sentence_i]
+            mini_sentence = cache_sent + mini_sentence
+            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            data = {
+                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
+                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
+            }
+            data = to_device(data, kwargs["device"])
+            # y, _ = self.wrapped_model(**data)
+            y, _ = self.punc_forward(**data)
+            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
+            punctuations = indices
+            if indices.size()[0] != 1:
+                punctuations = torch.squeeze(indices)
+            assert punctuations.size()[0] == len(mini_sentence)
+
+            # Search for the last Period/QuestionMark as cache
+            if mini_sentence_i < len(mini_sentences) - 1:
+                sentenceEnd = -1
+                last_comma_index = -1
+                for i in range(len(punctuations) - 2, 1, -1):
+                    if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+                        sentenceEnd = i
+                        break
+                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
+                        last_comma_index = i
+
+                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence it too long, cut off at a comma.
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = self.sentence_end_id
+                cache_sent = mini_sentence[sentenceEnd + 1:]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+                mini_sentence = mini_sentence[0:sentenceEnd + 1]
+                punctuations = punctuations[0:sentenceEnd + 1]
+
+            # if len(punctuations) == 0:
+            #    continue
+
+            punctuations_np = punctuations.cpu().numpy()
+            new_mini_sentence_punc += [int(x) for x in punctuations_np]
+            words_with_punc = []
+            for i in range(len(mini_sentence)):
+                if (i==0 or self.punc_list[punctuations[i-1]] == "銆�" or self.punc_list[punctuations[i-1]] == "锛�") and len(mini_sentence[i][0].encode()) == 1:
+                    mini_sentence[i] = mini_sentence[i].capitalize()
+                if i == 0:
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
+                if i > 0:
+                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
+                words_with_punc.append(mini_sentence[i])
+                if self.punc_list[punctuations[i]] != "_":
+                    punc_res = self.punc_list[punctuations[i]]
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        if punc_res == "锛�":
+                            punc_res = ","
+                        elif punc_res == "銆�":
+                            punc_res = "."
+                        elif punc_res == "锛�":
+                            punc_res = "?"
+                    words_with_punc.append(punc_res)
+            new_mini_sentence += "".join(words_with_punc)
+            # Add Period for the end of the sentence
+            new_mini_sentence_out = new_mini_sentence
+            new_mini_sentence_punc_out = new_mini_sentence_punc
+            if mini_sentence_i == len(mini_sentences) - 1:
+                if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+                elif new_mini_sentence[-1] == ",":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())==0:
+                    new_mini_sentence_out = new_mini_sentence + "銆�"
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
+                    new_mini_sentence_out = new_mini_sentence + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
+            # keep a punctuations array for punc segment
+            if punc_array is None:
+                punc_array = punctuations
+            else:
+                punc_array = torch.cat([punc_array, punctuations], dim=0)
+        
+        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
+        results.append(result_i)
+    
+        return results, meta_data
+
diff --git a/funasr/models/ct_transformer_streaming/template.yaml b/funasr/models/ct_transformer_streaming/template.yaml
new file mode 100644
index 0000000..c20a098
--- /dev/null
+++ b/funasr/models/ct_transformer_streaming/template.yaml
@@ -0,0 +1,53 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+model: CTTransformerStreaming
+model_conf:
+    ignore_id: 0
+    embed_unit: 256
+    att_unit: 256
+    dropout_rate: 0.1
+    punc_list:
+        - <unk>
+        - _
+        - 锛�
+        - 銆�
+        - 锛�
+        - 銆�
+    punc_weight:
+        - 1.0
+        - 1.0
+        - 1.0
+        - 1.0
+        - 1.0
+        - 1.0
+    sentence_end_id: 3
+
+encoder: SANMEncoder
+encoder_conf:
+    input_size: 256
+    output_size: 256
+    attention_heads: 8
+    linear_units: 1024
+    num_blocks: 4
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: pe
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+    padding_idx: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+
+
+
diff --git a/funasr/models/ct_transformer_streaming/utils.py b/funasr/models/ct_transformer_streaming/utils.py
new file mode 100644
index 0000000..917f2e0
--- /dev/null
+++ b/funasr/models/ct_transformer_streaming/utils.py
@@ -0,0 +1,111 @@
+import re
+
+def split_to_mini_sentence(words: list, word_limit: int = 20):
+    assert word_limit > 1
+    if len(words) <= word_limit:
+        return [words]
+    sentences = []
+    length = len(words)
+    sentence_len = length // word_limit
+    for i in range(sentence_len):
+        sentences.append(words[i * word_limit:(i + 1) * word_limit])
+    if length % word_limit > 0:
+        sentences.append(words[sentence_len * word_limit:])
+    return sentences
+
+
+# def split_words(text: str, **kwargs):
+#     words = []
+#     segs = text.split()
+#     for seg in segs:
+#         # There is no space in seg.
+#         current_word = ""
+#         for c in seg:
+#             if len(c.encode()) == 1:
+#                 # This is an ASCII char.
+#                 current_word += c
+#             else:
+#                 # This is a Chinese char.
+#                 if len(current_word) > 0:
+#                     words.append(current_word)
+#                     current_word = ""
+#                 words.append(c)
+#         if len(current_word) > 0:
+#             words.append(current_word)
+#
+#     return words
+
+def split_words(text: str, jieba_usr_dict=None, **kwargs):
+    if jieba_usr_dict:
+        input_list = text.split()
+        token_list_all = []
+        langauge_list = []
+        token_list_tmp = []
+        language_flag = None
+        for token in input_list:
+            if isEnglish(token) and language_flag == 'Chinese':
+                token_list_all.append(token_list_tmp)
+                langauge_list.append('Chinese')
+                token_list_tmp = []
+            elif not isEnglish(token) and language_flag == 'English':
+                token_list_all.append(token_list_tmp)
+                langauge_list.append('English')
+                token_list_tmp = []
+
+            token_list_tmp.append(token)
+
+            if isEnglish(token):
+                language_flag = 'English'
+            else:
+                language_flag = 'Chinese'
+
+        if token_list_tmp:
+            token_list_all.append(token_list_tmp)
+            langauge_list.append(language_flag)
+
+        result_list = []
+        for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
+            if language_flag == 'English':
+                result_list.extend(token_list_tmp)
+            else:
+                seg_list = jieba_usr_dict.cut(join_chinese_and_english(token_list_tmp), HMM=False)
+                result_list.extend(seg_list)
+
+        return result_list
+
+    else:
+        words = []
+        segs = text.split()
+        for seg in segs:
+            # There is no space in seg.
+            current_word = ""
+            for c in seg:
+                if len(c.encode()) == 1:
+                    # This is an ASCII char.
+                    current_word += c
+                else:
+                    # This is a Chinese char.
+                    if len(current_word) > 0:
+                        words.append(current_word)
+                        current_word = ""
+                    words.append(c)
+            if len(current_word) > 0:
+                words.append(current_word)
+        return words
+
+def isEnglish(text:str):
+    if re.search('^[a-zA-Z\']+$', text):
+        return True
+    else:
+        return False
+
+def join_chinese_and_english(input_list):
+    line = ''
+    for token in input_list:
+        if isEnglish(token):
+            line = line + ' ' + token
+        else:
+            line = line + token
+
+    line = line.strip()
+    return line
diff --git a/funasr/models/ct_transformer_streaming/vad_realtime_transformer.py b/funasr/models/ct_transformer_streaming/vad_realtime_transformer.py
new file mode 100644
index 0000000..155057c
--- /dev/null
+++ b/funasr/models/ct_transformer_streaming/vad_realtime_transformer.py
@@ -0,0 +1,135 @@
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from funasr.models.transformer.embedding import SinusoidalPositionEncoder
+from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder as Encoder
+
+
+class VadRealtimeTransformer(torch.nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
+    https://arxiv.org/pdf/2003.01309.pdf
+    """
+    def __init__(
+        self,
+        vocab_size: int,
+        punc_size: int,
+        pos_enc: str = None,
+        embed_unit: int = 128,
+        att_unit: int = 256,
+        head: int = 2,
+        unit: int = 1024,
+        layer: int = 4,
+        dropout_rate: float = 0.5,
+        kernel_size: int = 11,
+        sanm_shfit: int = 0,
+    ):
+        super().__init__()
+        if pos_enc == "sinusoidal":
+            #            pos_enc_class = PositionalEncoding
+            pos_enc_class = SinusoidalPositionEncoder
+        elif pos_enc is None:
+
+            def pos_enc_class(*args, **kwargs):
+                return nn.Sequential()  # indentity
+
+        else:
+            raise ValueError(f"unknown pos-enc option: {pos_enc}")
+
+        self.embed = nn.Embedding(vocab_size, embed_unit)
+        self.encoder = Encoder(
+            input_size=embed_unit,
+            output_size=att_unit,
+            attention_heads=head,
+            linear_units=unit,
+            num_blocks=layer,
+            dropout_rate=dropout_rate,
+            input_layer="pe",
+            # pos_enc_class=pos_enc_class,
+            padding_idx=0,
+            kernel_size=kernel_size,
+            sanm_shfit=sanm_shfit,
+        )
+        self.decoder = nn.Linear(att_unit, punc_size)
+
+
+#    def _target_mask(self, ys_in_pad):
+#        ys_mask = ys_in_pad != 0
+#        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
+#        return ys_mask.unsqueeze(-2) & m
+
+    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
+                vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        """Compute loss value from buffer sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            hidden (torch.Tensor): Target ids. (batch, len)
+
+        """
+        x = self.embed(input)
+        # mask = self._target_mask(input)
+        h, _, _ = self.encoder(x, text_lengths, vad_indexes)
+        y = self.decoder(h)
+        return y, None
+
+    def with_vad(self):
+        return True
+
+    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
+        """Score new token.
+
+        Args:
+            y (torch.Tensor): 1D torch.int64 prefix tokens.
+            state: Scorer state for prefix tokens
+            x (torch.Tensor): encoder feature that generates ys.
+
+        Returns:
+            tuple[torch.Tensor, Any]: Tuple of
+                torch.float32 scores for next token (vocab_size)
+                and next state for ys
+
+        """
+        y = y.unsqueeze(0)
+        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(dim=-1).squeeze(0)
+        return logp, cache
+
+    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
+        """Score new token batch.
+
+        Args:
+            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (torch.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[torch.Tensor, List[Any]]: Tuple of
+                batchfied scores for next token with shape of `(n_batch, vocab_size)`
+                and next state list for ys.
+
+        """
+        # merge states
+        n_batch = len(ys)
+        n_layers = len(self.encoder.encoders)
+        if states[0] is None:
+            batch_state = None
+        else:
+            # transpose state of [batch, layer] into [layer, batch]
+            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
+
+        # batch decoding
+        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(dim=-1)
+
+        # transpose state of [layer, batch] into [batch, layer]
+        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+        return logp, state_list

--
Gitblit v1.9.1