From 99730b35f47579eb99b5e4ba0e6ca99901c23955 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Sun, 14 Jan 2024 23:21:08 +0800
Subject: [PATCH] funasr1.0 ct-transformer streaming

---
 funasr/models/ct_transformer/utils.py                                  |   20 
 /dev/null                                                              |  135 -----
 examples/industrial_data_pretraining/ct_transformer_streaming/demo.py  |   19 
 examples/industrial_data_pretraining/ct_transformer_streaming/infer.sh |   10 
 funasr/models/ct_transformer_streaming/attention.py                    | 1061 ----------------------------------------
 funasr/models/ct_transformer_streaming/model.py                        |  277 ++--------
 funasr/models/ct_transformer/model.py                                  |    2 
 funasr/models/ct_transformer_streaming/template.yaml                   |   11 
 funasr/models/ct_transformer_streaming/encoder.py                      |    2 
 funasr/train_utils/load_pretrained_model.py                            |    1 
 10 files changed, 91 insertions(+), 1447 deletions(-)

diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py b/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py
new file mode 100644
index 0000000..5ef8381
--- /dev/null
+++ b/examples/industrial_data_pretraining/ct_transformer_streaming/demo.py
@@ -0,0 +1,19 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727", model_revision="v2.0.1")
+
+inputs = "璺ㄥ娌虫祦鏄吇鑲叉部宀竱浜烘皯鐨勭敓鍛戒箣婧愰暱鏈熶互鏉ヤ负甯姪涓嬫父鍦板尯闃茬伨鍑忕伨涓柟鎶�鏈汉鍛榺鍦ㄤ笂娓稿湴鍖烘瀬涓烘伓鍔g殑鑷劧鏉′欢涓嬪厠鏈嶅法澶у洶闅剧敋鑷冲啋鐫�鐢熷懡鍗遍櫓|鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦闂涓婄殑鍏冲垏|鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒秥鍑℃槸|涓柟鑳藉仛鐨勬垜浠瑋閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑|浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛瑙勫垝鍜岃璇佸吋椤句笂涓嬫父鐨勫埄鐩�"
+vads = inputs.split("|")
+rec_result_all = "outputs: "
+cache = {}
+for vad in vads:
+    rec_result = model(input=vad, cache=cache)
+    print(rec_result)
+    rec_result_all += rec_result[0]['text']
+
+print(rec_result_all)
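
For reference, the `cache` dict above carries streaming punctuation state across calls, so punctuation can be revised as later VAD segments arrive; a fresh dict starts a new session. A minimal sketch, assuming only the `AutoModel` API used in demo.py (the segment strings here are hypothetical):

```python
# Hedged sketch: `cache` accumulates streaming state; a new dict = a new session.
session_cache = {}
for segment in ["跨境河流是养育沿岸", "人民的生命之源"]:  # hypothetical VAD segments
    result = model(input=segment, cache=session_cache)
    print(result[0]["text"])  # output may be revised as context grows

session_cache = {}  # discard history before punctuating an unrelated stream
```
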
diff --git a/examples/industrial_data_pretraining/ct_transformer_streaming/infer.sh b/examples/industrial_data_pretraining/ct_transformer_streaming/infer.sh
new file mode 100644
index 0000000..fa92a6e
--- /dev/null
+++ b/examples/industrial_data_pretraining/ct_transformer_streaming/infer.sh
@@ -0,0 +1,10 @@
+
+model="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+model_revision="v2.0.1"
+
+python funasr/bin/inference.py \
++model=${model} \
++model_revision=${model_revision} \
++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt" \
++output_dir="./outputs/debug" \
++device="cpu"
diff --git a/funasr/models/ct_transformer/attention.py b/funasr/models/ct_transformer/attention.py
deleted file mode 100644
index a35ddee..0000000
--- a/funasr/models/ct_transformer/attention.py
+++ /dev/null
@@ -1,1091 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-# Copyright 2019 Shigeki Karita
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Multi-Head Attention layer definition."""
-
-import math
-
-import numpy
-import torch
-from torch import nn
-from typing import Optional, Tuple
-
-import torch.nn.functional as F
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-import funasr.models.lora.layers as lora
-
-class MultiHeadedAttention(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, n_feat, dropout_rate):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttention, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
-        self.linear_k = nn.Linear(n_feat, n_feat)
-        self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-    def forward_qkv(self, query, key, value):
-        """Transform query, key and value.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
-        """
-        n_batch = query.size(0)
-        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
-        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
-        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
-        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
-
-        return q, k, v
-
-    def forward_attention(self, value, scores, mask):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x)  # (batch, time1, d_model)
-
-    def forward(self, query, key, value, mask):
-        """Compute scaled dot product attention.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, mask)
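
For shape intuition, a minimal sketch (hypothetical sizes) of how the module removed above was driven; a boolean all-ones mask leaves every key position visible:

```python
# Toy shape check for the scaled dot-product attention interface (hypothetical sizes).
import torch

mha = MultiHeadedAttention(n_head=4, n_feat=256, dropout_rate=0.0)
q = torch.randn(2, 10, 256)                    # (batch, time1, size)
kv = torch.randn(2, 20, 256)                   # (batch, time2, size)
mask = torch.ones(2, 1, 20, dtype=torch.bool)  # (batch, 1, time2), all valid
out = mha(q, kv, kv, mask)
assert out.shape == (2, 10, 256)               # (batch, time1, d_model)
```
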
-
-
-class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
-    """Multi-Head Attention layer with relative position encoding (old version).
-
-    Details can be found in https://github.com/espnet/espnet/pull/2816.
-
-    Paper: https://arxiv.org/abs/1901.02860
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
-
-    """
-
-    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
-        """Construct an RelPositionMultiHeadedAttention object."""
-        super().__init__(n_head, n_feat, dropout_rate)
-        self.zero_triu = zero_triu
-        # linear transformation for positional encoding
-        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
-        # these two learnable biases are used in matrix c and matrix d
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
-        torch.nn.init.xavier_uniform_(self.pos_bias_u)
-        torch.nn.init.xavier_uniform_(self.pos_bias_v)
-
-    def rel_shift(self, x):
-        """Compute relative positional encoding.
-
-        Args:
-            x (torch.Tensor): Input tensor (batch, head, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor.
-
-        """
-        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
-        x_padded = torch.cat([zero_pad, x], dim=-1)
-
-        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
-        x = x_padded[:, :, 1:].view_as(x)
-
-        if self.zero_triu:
-            ones = torch.ones((x.size(2), x.size(3)))
-            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
-
-        return x
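
A tiny worked example (hypothetical values) makes the zero-pad shift above concrete; note how each successive row is shifted one position less:

```python
# Worked example of the zero-pad relative shift (hypothetical values).
import torch

x = torch.tensor([[[[1., 2., 3.],
                    [4., 5., 6.]]]])                 # (batch, head, time1=2, time2=3)
zero_pad = torch.zeros((*x.size()[:3], 1))
x_padded = torch.cat([zero_pad, x], dim=-1)          # prefix each row with a zero
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
print(x_padded[:, :, 1:].view_as(x))
# tensor([[[[2., 3., 0.],
#           [4., 5., 6.]]]])
```
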
-
-    def forward(self, query, key, value, pos_emb, mask):
-        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
-
-        n_batch_pos = pos_emb.size(0)
-        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
-        p = p.transpose(1, 2)  # (batch, head, time1, d_k)
-
-        # (batch, head, time1, d_k)
-        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
-        # (batch, head, time1, d_k)
-        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
-
-        # compute attention score
-        # first compute matrix a and matrix c
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        # (batch, head, time1, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
-
-        # compute matrix b and matrix d
-        # (batch, head, time1, time1)
-        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
-        matrix_bd = self.rel_shift(matrix_bd)
-
-        scores = (matrix_ac + matrix_bd) / math.sqrt(
-            self.d_k
-        )  # (batch, head, time1, time2)
-
-        return self.forward_attention(v, scores, mask)
-
-
-class RelPositionMultiHeadedAttention(MultiHeadedAttention):
-    """Multi-Head Attention layer with relative position encoding (new implementation).
-
-    Details can be found in https://github.com/espnet/espnet/pull/2816.
-
-    Paper: https://arxiv.org/abs/1901.02860
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
-
-    """
-
-    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
-        """Construct an RelPositionMultiHeadedAttention object."""
-        super().__init__(n_head, n_feat, dropout_rate)
-        self.zero_triu = zero_triu
-        # linear transformation for positional encoding
-        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
-        # these two learnable biases are used in matrix c and matrix d
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
-        torch.nn.init.xavier_uniform_(self.pos_bias_u)
-        torch.nn.init.xavier_uniform_(self.pos_bias_v)
-
-    def rel_shift(self, x):
-        """Compute relative positional encoding.
-
-        Args:
-            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
-            time1 means the length of query vector.
-
-        Returns:
-            torch.Tensor: Output tensor.
-
-        """
-        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
-        x_padded = torch.cat([zero_pad, x], dim=-1)
-
-        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
-        x = x_padded[:, :, 1:].view_as(x)[
-            :, :, :, : x.size(-1) // 2 + 1
-            ]  # only keep the positions from 0 to time2
-
-        if self.zero_triu:
-            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
-            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
-
-        return x
-
-    def forward(self, query, key, value, pos_emb, mask):
-        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            pos_emb (torch.Tensor): Positional embedding tensor
-                (#batch, 2*time1-1, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
-
-        n_batch_pos = pos_emb.size(0)
-        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
-        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
-
-        # (batch, head, time1, d_k)
-        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
-        # (batch, head, time1, d_k)
-        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
-
-        # compute attention score
-        # first compute matrix a and matrix c
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        # (batch, head, time1, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
-
-        # compute matrix b and matrix d
-        # (batch, head, time1, 2*time1-1)
-        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
-        matrix_bd = self.rel_shift(matrix_bd)
-
-        scores = (matrix_ac + matrix_bd) / math.sqrt(
-            self.d_k
-        )  # (batch, head, time1, time2)
-
-        return self.forward_attention(v, scores, mask)
-
-
-class MultiHeadedAttentionSANM(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionSANM, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        # self.linear_q = nn.Linear(n_feat, n_feat)
-        # self.linear_k = nn.Linear(n_feat, n_feat)
-        # self.linear_v = nn.Linear(n_feat, n_feat)
-        if lora_list is not None:
-            if "o" in lora_list:
-                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
-            else:
-                self.linear_out = nn.Linear(n_feat, n_feat)
-            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
-            if lora_qkv_list == [False, False, False]:
-                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
-            else:
-                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
-        else:
-            self.linear_out = nn.Linear(n_feat, n_feat)
-            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
-        # padding
-        left_padding = (kernel_size - 1) // 2
-        if sanm_shfit > 0:
-            left_padding = left_padding + sanm_shfit
-        right_padding = kernel_size - 1 - left_padding
-        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
-
-    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
-        b, t, d = inputs.size()
-        if mask is not None:
-            mask = torch.reshape(mask, (b, -1, 1))
-            if mask_shfit_chunk is not None:
-                mask = mask * mask_shfit_chunk
-            inputs = inputs * mask
-
-        x = inputs.transpose(1, 2)
-        x = self.pad_fn(x)
-        x = self.fsmn_block(x)
-        x = x.transpose(1, 2)
-        x += inputs
-        x = self.dropout(x)
-        if mask is not None:
-            x = x * mask
-        return x
-
-    def forward_qkv(self, x):
-        """Transform query, key and value.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time1, size).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Untransposed value tensor (#batch, time2, n_feat) for the FSMN branch.
-
-        """
-        b, t, d = x.size()
-        q_k_v = self.linear_q_k_v(x)
-        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
-        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-
-        return q_h, k_h, v_h, v
-
-    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            if mask_att_chunk_encoder is not None:
-                mask = mask * mask_att_chunk_encoder
-
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x)  # (batch, time1, d_model)
-
-    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
-        """Compute scaled dot product attention.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time1, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
-        return att_outs + fsmn_memory
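
A SANM block sums the self-attention output with the FSMN memory computed over the untransposed values, so both branches must produce (batch, time, n_feat). A minimal shape check (hypothetical sizes) for the module removed above:

```python
# Toy shape check for MultiHeadedAttentionSANM (hypothetical sizes).
import torch

sanm = MultiHeadedAttentionSANM(n_head=4, in_feat=256, n_feat=256,
                                dropout_rate=0.0, kernel_size=11)
x = torch.randn(2, 30, 256)   # (batch, time, in_feat)
mask = torch.ones(2, 1, 30)   # (batch, 1, time), all valid
out = sanm(x, mask)
assert out.shape == (2, 30, 256)   # attention output + FSMN memory
```
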
-
-    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
-        """Compute scaled dot product attention.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time1, size).
-            cache (dict): Cached keys/values from previous chunks (or None).
-            chunk_size: Chunk configuration; chunk_size[1] is the chunk length
-                and chunk_size[2] the lookahead frames dropped before caching.
-            look_back (int): Number of previous chunks kept in the cache (-1 keeps all).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        if chunk_size is not None and (look_back > 0 or look_back == -1):
-            if cache is not None:
-                k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
-                v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
-                k_h = torch.cat((cache["k"], k_h), dim=2)
-                v_h = torch.cat((cache["v"], v_h), dim=2)
-
-                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
-                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
-                if look_back != -1:
-                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
-                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
-            else:
-                cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
-                             "v": v_h[:, :, :-(chunk_size[2]), :]}
-                cache = cache_tmp
-        fsmn_memory = self.forward_fsmn(v, None)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, None)
-        return att_outs + fsmn_memory, cache
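
forward_chunk above keeps per-layer key/value caches so each chunk can attend over up to `look_back` previous chunks. A simplified sketch of the rolling-cache pattern (hypothetical shapes; `chunk_size[1]` is read as the chunk length and `chunk_size[2]` as the lookahead frames dropped before caching):

```python
# Simplified k/v cache rolling, as in forward_chunk above (hypothetical sizes).
import torch

look_back, chunk_len = 1, 10
cache = {"k": torch.zeros(1, 4, 0, 64), "v": torch.zeros(1, 4, 0, 64)}
for _ in range(3):
    k_new = torch.randn(1, 4, chunk_len, 64)                # current chunk's keys
    k_all = torch.cat((cache["k"], k_new), dim=2)           # attend over cache + chunk
    cache["k"] = k_all[:, :, -(look_back * chunk_len):, :]  # keep look_back chunks
    # cache["v"] is maintained identically
```
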
-
-
-class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
-        return att_outs + fsmn_memory
-
-class MultiHeadedAttentionSANMDecoder(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionSANMDecoder, self).__init__()
-
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-        self.fsmn_block = nn.Conv1d(n_feat, n_feat,
-                                    kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
-        # padding
-        left_padding = (kernel_size - 1) // 2
-        if sanm_shfit > 0:
-            left_padding = left_padding + sanm_shfit
-        right_padding = kernel_size - 1 - left_padding
-        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
-        self.kernel_size = kernel_size
-
-    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
-        '''
-        :param inputs: Input tensor (#batch, time1, size).
-        :param mask: Mask tensor (#batch, 1, time)
-        :param cache: FSMN cache from the previous streaming step
-        :return: output tensor and updated cache
-        '''
-        # print("in fsmn, inputs", inputs.size())
-        b, t, d = inputs.size()
-        # logging.info(
-        #     "mask: {}".format(mask.size()))
-        if mask is not None:
-            mask = torch.reshape(mask, (b, -1, 1))
-            # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
-            if mask_shfit_chunk is not None:
-                # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
-                mask = mask * mask_shfit_chunk
-            # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
-            # print("in fsmn, mask", mask.size())
-            # print("in fsmn, inputs", inputs.size())
-            inputs = inputs * mask
-
-        x = inputs.transpose(1, 2)
-        b, d, t = x.size()
-        if cache is None:
-            # print("in fsmn, cache is None, x", x.size())
-
-            x = self.pad_fn(x)
-            if not self.training:
-                cache = x
-        else:
-            # print("in fsmn, cache is not None, x", x.size())
-            # x = torch.cat((x, cache), dim=2)[:, :, :-1]
-            # if t < self.kernel_size:
-            #     x = self.pad_fn(x)
-            x = torch.cat((cache[:, :, 1:], x), dim=2)
-            x = x[:, :, -(self.kernel_size+t-1):]
-            # print("in fsmn, cache is not None, x_cat", x.size())
-            cache = x
-        x = self.fsmn_block(x)
-        x = x.transpose(1, 2)
-        # print("in fsmn, fsmn_out", x.size())
-        if x.size(1) != inputs.size(1):
-            inputs = inputs[:, -1, :]
-
-        x = x + inputs
-        x = self.dropout(x)
-        if mask is not None:
-            x = x * mask
-        return x, cache
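
At inference the depthwise FSMN convolution above only ever needs the most recent `kernel_size` frames, which is what the cache concatenation maintains. A minimal sketch with one new frame per step (hypothetical sizes):

```python
# Rolling FSMN conv cache, one new frame per step (hypothetical sizes).
import torch

kernel_size, d = 11, 256
cache = torch.zeros(1, d, kernel_size)                   # e.g. pad_fn output at t = 1
new_frame = torch.randn(1, d, 1)
cache = torch.cat((cache[:, :, 1:], new_frame), dim=2)   # slide the window forward
cache = cache[:, :, -kernel_size:]                       # keep kernel_size frames (t = 1)
```
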
-
-class MultiHeadedAttentionCrossAtt(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionCrossAtt, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        if lora_list is not None:
-            if "q" in lora_list:
-                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
-            else:
-                self.linear_q = nn.Linear(n_feat, n_feat)
-            lora_kv_list = ["k" in lora_list, "v" in lora_list]
-            if lora_kv_list == [False, False]:
-                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
-            else:
-                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, 
-                                      r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
-            if "o" in lora_list:
-                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
-            else:
-                self.linear_out = nn.Linear(n_feat, n_feat)
-        else:
-            self.linear_q = nn.Linear(n_feat, n_feat)
-            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
-            self.linear_out = nn.Linear(n_feat, n_feat)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-    def forward_qkv(self, x, memory):
-        """Transform query, key and value.
-
-        Args:
-            x (torch.Tensor): Query-side input tensor (#batch, time1, size).
-            memory (torch.Tensor): Encoder memory used for keys and values (#batch, time2, size).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
-        """
-
-        # print("in forward_qkv, x", x.size())
-        b = x.size(0)
-        q = self.linear_q(x)
-        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time1, d_k)
-
-        k_v = self.linear_k_v(memory)
-        k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
-        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
-
-
-        return q_h, k_h, v_h
-
-    def forward_attention(self, value, scores, mask):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
-            # logging.info(
-            #     "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x)  # (batch, time1, d_model)
-
-    def forward(self, x, memory, memory_mask):
-        """Compute scaled dot product attention.
-
-        Args:
-            x (torch.Tensor): Query-side input tensor (#batch, time1, size).
-            memory (torch.Tensor): Encoder memory (#batch, time2, size).
-            memory_mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q_h, k_h, v_h = self.forward_qkv(x, memory)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        return self.forward_attention(v_h, scores, memory_mask)
-
-    def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
-        """Compute scaled dot product attention.
-
-        Args:
-            x (torch.Tensor): Query-side input tensor (#batch, time1, size).
-            memory (torch.Tensor): Encoder memory (#batch, time2, size).
-            cache (dict): Cached keys/values from previous chunks (or None).
-            chunk_size / look_back: Chunk configuration, as in
-                MultiHeadedAttentionSANM.forward_chunk.
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q_h, k_h, v_h = self.forward_qkv(x, memory)
-        if chunk_size is not None and look_back > 0:
-            if cache is not None:
-                k_h = torch.cat((cache["k"], k_h), dim=2)
-                v_h = torch.cat((cache["v"], v_h), dim=2)
-                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
-                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
-            else:
-                cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
-                             "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
-                cache = cache_tmp
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        return self.forward_attention(v_h, scores, None), cache
-
-
-class MultiHeadSelfAttention(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadSelfAttention, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-    def forward_qkv(self, x):
-        """Transform query, key and value.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time1, size).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Untransposed value tensor (#batch, time2, n_feat).
-
-        """
-        b, t, d = x.size()
-        q_k_v = self.linear_q_k_v(x)
-        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
-        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-
-        return q_h, k_h, v_h, v
-
-    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            if mask_att_chunk_encoder is not None:
-                mask = mask * mask_att_chunk_encoder
-
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x)  # (batch, time1, d_model)
-
-    def forward(self, x, mask, mask_att_chunk_encoder=None):
-        """Compute scaled dot product attention.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time1, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
-        return att_outs
-
-class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
-    """RelPositionMultiHeadedAttention definition.
-    Args:
-        num_heads: Number of attention heads.
-        embed_size: Embedding size.
-        dropout_rate: Dropout rate.
-    """
-
-    def __init__(
-        self,
-        num_heads: int,
-        embed_size: int,
-        dropout_rate: float = 0.0,
-        simplified_attention_score: bool = False,
-    ) -> None:
-        """Construct an MultiHeadedAttention object."""
-        super().__init__()
-
-        self.d_k = embed_size // num_heads
-        self.num_heads = num_heads
-
-        assert self.d_k * num_heads == embed_size, (
-            "embed_size (%d) must be divisible by num_heads (%d)",
-            (embed_size, num_heads),
-        )
-
-        self.linear_q = torch.nn.Linear(embed_size, embed_size)
-        self.linear_k = torch.nn.Linear(embed_size, embed_size)
-        self.linear_v = torch.nn.Linear(embed_size, embed_size)
-
-        self.linear_out = torch.nn.Linear(embed_size, embed_size)
-
-        if simplified_attention_score:
-            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
-
-            self.compute_att_score = self.compute_simplified_attention_score
-        else:
-            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
-
-            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
-            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
-            torch.nn.init.xavier_uniform_(self.pos_bias_u)
-            torch.nn.init.xavier_uniform_(self.pos_bias_v)
-
-            self.compute_att_score = self.compute_attention_score
-
-        self.dropout = torch.nn.Dropout(p=dropout_rate)
-        self.attn = None
-
-    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
-        """Compute relative positional encoding.
-        Args:
-            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
-            left_context: Number of frames in left context.
-        Returns:
-            x: Output sequence. (B, H, T_1, T_2)
-        """
-        batch_size, n_heads, time1, n = x.shape
-        time2 = time1 + left_context
-
-        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
-
-        return x.as_strided(
-            (batch_size, n_heads, time1, time2),
-            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
-            storage_offset=(n_stride * (time1 - 1)),
-        )
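
The `as_strided` call above produces the same relative shift as the pad-and-reshape variant, but as a zero-copy view over the input's storage. A tiny check (hypothetical values, `left_context=0`):

```python
# Zero-copy relative shift via as_strided (hypothetical values).
import torch

x = torch.arange(6.).view(1, 1, 2, 3)   # (B=1, H=1, T1=2, 2*T1-1=3)
b, h, t1, n = x.shape
bs, hs, t1s, ns = x.stride()
out = x.as_strided((b, h, t1, t1), (bs, hs, t1s - ns, ns),
                   storage_offset=ns * (t1 - 1))
print(out)   # tensor([[[[1., 2.],
             #           [3., 4.]]]])
```
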
-
-    def compute_simplified_attention_score(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        pos_enc: torch.Tensor,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Simplified attention score computation.
-        Reference: https://github.com/k2-fsa/icefall/pull/458
-        Args:
-            query: Transformed query tensor. (B, H, T_1, d_k)
-            key: Transformed key tensor. (B, H, T_2, d_k)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            left_context: Number of frames in left context.
-        Returns:
-            : Attention score. (B, H, T_1, T_2)
-        """
-        pos_enc = self.linear_pos(pos_enc)
-
-        matrix_ac = torch.matmul(query, key.transpose(2, 3))
-
-        matrix_bd = self.rel_shift(
-            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
-            left_context=left_context,
-        )
-
-        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
-    def compute_attention_score(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        pos_enc: torch.Tensor,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Attention score computation.
-        Args:
-            query: Transformed query tensor. (B, H, T_1, d_k)
-            key: Transformed key tensor. (B, H, T_2, d_k)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            left_context: Number of frames in left context.
-        Returns:
-            : Attention score. (B, H, T_1, T_2)
-        """
-        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
-
-        query = query.transpose(1, 2)
-        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
-        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
-
-        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
-
-        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
-        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
-
-        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
-    def forward_qkv(
-        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Transform query, key and value.
-        Args:
-            query: Query tensor. (B, T_1, size)
-            key: Key tensor. (B, T_2, size)
-            value: Value tensor. (B, T_2, size)
-        Returns:
-            q: Transformed query tensor. (B, H, T_1, d_k)
-            k: Transformed key tensor. (B, H, T_2, d_k)
-            v: Transformed value tensor. (B, H, T_2, d_k)
-        """
-        n_batch = query.size(0)
-
-        q = (
-            self.linear_q(query)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-        k = (
-            self.linear_k(key)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-        v = (
-            self.linear_v(value)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-
-        return q, k, v
-
-    def forward_attention(
-        self,
-        value: torch.Tensor,
-        scores: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Compute attention context vector.
-        Args:
-            value: Transformed value. (B, H, T_2, d_k)
-            scores: Attention score. (B, H, T_1, T_2)
-            mask: Source mask. (B, T_2)
-            chunk_mask: Chunk mask. (T_1, T_1)
-        Returns:
-           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
-        """
-        batch_size = scores.size(0)
-        mask = mask.unsqueeze(1).unsqueeze(2)
-        if chunk_mask is not None:
-            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
-        scores = scores.masked_fill(mask, float("-inf"))
-        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
-
-        attn_output = self.dropout(self.attn)
-        attn_output = torch.matmul(attn_output, value)
-
-        attn_output = self.linear_out(
-            attn_output.transpose(1, 2)
-            .contiguous()
-            .view(batch_size, -1, self.num_heads * self.d_k)
-        )
-
-        return attn_output
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Compute scaled dot product attention with rel. positional encoding.
-        Args:
-            query: Query tensor. (B, T_1, size)
-            key: Key tensor. (B, T_2, size)
-            value: Value tensor. (B, T_2, size)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            mask: Source mask. (B, T_2)
-            chunk_mask: Chunk mask. (T_1, T_1)
-            left_context: Number of frames in left context.
-        Returns:
-            : Output tensor. (B, T_1, H * d_k)
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
-        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
-
-
-class CosineDistanceAttention(nn.Module):
-    """ Compute Cosine Distance between spk decoder output and speaker profile 
-    Args:
-        profile_path: speaker profile file path (.npy file)
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.softmax = nn.Softmax(dim=-1)
-
-    def forward(self, spk_decoder_out, profile, profile_lens=None):
-        """
-        Args:
-            spk_decoder_out (torch.Tensor): (B, L, D)
-            profile (torch.Tensor): (B, N, D)
-            profile_lens (torch.Tensor, optional): (B,) number of valid profiles
-        """
-        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
-        if profile_lens is not None:
-
-            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
-            )
-            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
-            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
-        else:
-            x = x[:, -1:, :, :]
-            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
-            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
-        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
-
-        return spk_embedding, weights
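
A usage sketch for the profile attention above (hypothetical sizes; with `profile_lens` given, padded profiles are masked out of the softmax; assumes the imports at the top of this file):

```python
# Usage sketch for CosineDistanceAttention (hypothetical sizes).
import torch

att = CosineDistanceAttention()
dec_out = torch.randn(2, 5, 192)    # (B, L, D) speaker decoder output
profiles = torch.randn(2, 3, 192)   # (B, N, D) enrolled speaker profiles
spk_emb, weights = att(dec_out, profiles, profile_lens=torch.tensor([3, 2]))
assert spk_emb.shape == (2, 5, 192)   # weighted profile embedding per step
assert weights.shape == (2, 5, 3)     # softmax over the N profiles
```
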
diff --git a/funasr/models/ct_transformer/encoder.py b/funasr/models/ct_transformer/encoder.py
deleted file mode 100644
index 784baf3..0000000
--- a/funasr/models/ct_transformer/encoder.py
+++ /dev/null
@@ -1,383 +0,0 @@
-from typing import List
-from typing import Optional
-from typing import Sequence
-from typing import Tuple
-from typing import Union
-import logging
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from funasr.models.scama.chunk_utilis import overlap_chunk
-import numpy as np
-from funasr.train_utils.device_funcs import to_device
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.sanm.attention import MultiHeadedAttention
-from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
-from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
-from funasr.models.transformer.positionwise_feed_forward import (
-    PositionwiseFeedForward,  # noqa: H301
-)
-from funasr.models.transformer.utils.repeat import repeat
-from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
-from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
-from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
-from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
-from funasr.models.transformer.utils.subsampling import TooShortUttError
-from funasr.models.transformer.utils.subsampling import check_short_utt
-from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
-
-from funasr.models.ctc.ctc import CTC
-
-from funasr.register import tables
-
-class EncoderLayerSANM(nn.Module):
-    def __init__(
-        self,
-        in_size,
-        size,
-        self_attn,
-        feed_forward,
-        dropout_rate,
-        normalize_before=True,
-        concat_after=False,
-        stochastic_depth_rate=0.0,
-    ):
-        """Construct an EncoderLayer object."""
-        super(EncoderLayerSANM, self).__init__()
-        self.self_attn = self_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(in_size)
-        self.norm2 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.in_size = in_size
-        self.size = size
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        if self.concat_after:
-            self.concat_linear = nn.Linear(size + size, size)
-        self.stochastic_depth_rate = stochastic_depth_rate
-        self.dropout_rate = dropout_rate
-
-    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
-        """Compute encoded features.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, size).
-            mask (torch.Tensor): Mask tensor for the input (#batch, time).
-            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, size).
-            torch.Tensor: Mask tensor (#batch, time).
-
-        """
-        skip_layer = False
-        # with stochastic depth, residual connection `x + f(x)` becomes
-        # `x <- x + 1 / (1 - p) * f(x)` at training time.
-        stoch_layer_coeff = 1.0
-        if self.training and self.stochastic_depth_rate > 0:
-            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
-            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
-        if skip_layer:
-            if cache is not None:
-                x = torch.cat([cache, x], dim=1)
-            return x, mask
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm1(x)
-
-        if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
-            if self.in_size == self.size:
-                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
-            else:
-                x = stoch_layer_coeff * self.concat_linear(x_concat)
-        else:
-            if self.in_size == self.size:
-                x = residual + stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
-                )
-            else:
-                x = stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
-                )
-        if not self.normalize_before:
-            x = self.norm1(x)
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm2(x)
-        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm2(x)
-
-        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
-
-    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
-        """Compute encoded features.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, size).
-            cache (dict): Attention cache from the previous chunk (or None).
-            chunk_size / look_back: Streaming chunk configuration.
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, size).
-            torch.Tensor: Mask tensor (#batch, time).
-
-        """
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm1(x)
-
-        if self.in_size == self.size:
-            attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
-            x = residual + attn
-        else:
-            x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
-
-        if not self.normalize_before:
-            x = self.norm1(x)
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm2(x)
-        x = residual + self.feed_forward(x)
-        if not self.normalize_before:
-            x = self.norm2(x)
-
-        return x, cache
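
Streaming inference threads one attention cache per layer through forward_chunk. A hedged sketch (hypothetical: `layers` stands in for the encoders0 + encoders stack of a built SANMVadEncoder, and the chunk_size tuple layout is an assumption):

```python
# Hedged sketch: stream chunks through stacked EncoderLayerSANM blocks.
import torch

def stream(layers, chunks, chunk_size=(0, 10, 5), look_back=1):
    caches = [None] * len(layers)           # one attention cache per layer
    outs = []
    for x in chunks:                        # x: (batch, chunk_len, size)
        for i, layer in enumerate(layers):
            x, caches[i] = layer.forward_chunk(
                x, cache=caches[i], chunk_size=chunk_size, look_back=look_back)
        outs.append(x)
    return torch.cat(outs, dim=1)
```
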
-
-
-@tables.register("encoder_classes", "SANMVadEncoder")
-class SANMVadEncoder(nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-
-    """
-
-    def __init__(
-        self,
-        input_size: int,
-        output_size: int = 256,
-        attention_heads: int = 4,
-        linear_units: int = 2048,
-        num_blocks: int = 6,
-        dropout_rate: float = 0.1,
-        positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0.0,
-        input_layer: Optional[str] = "conv2d",
-        pos_enc_class=SinusoidalPositionEncoder,
-        normalize_before: bool = True,
-        concat_after: bool = False,
-        positionwise_layer_type: str = "linear",
-        positionwise_conv_kernel_size: int = 1,
-        padding_idx: int = -1,
-        interctc_layer_idx: List[int] = [],
-        interctc_use_conditioning: bool = False,
-        kernel_size: int = 11,
-        sanm_shfit: int = 0,
-        selfattention_layer_type: str = "sanm",
-    ):
-        super().__init__()
-        self._output_size = output_size
-
-        if input_layer == "linear":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Linear(input_size, output_size),
-                torch.nn.LayerNorm(output_size),
-                torch.nn.Dropout(dropout_rate),
-                torch.nn.ReLU(),
-                pos_enc_class(output_size, positional_dropout_rate),
-            )
-        elif input_layer == "conv2d":
-            self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d2":
-            self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d6":
-            self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
-        elif input_layer == "conv2d8":
-            self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
-        elif input_layer == "embed":
-            self.embed = torch.nn.Sequential(
-                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
-                SinusoidalPositionEncoder(),
-            )
-        elif input_layer is None:
-            if input_size == output_size:
-                self.embed = None
-            else:
-                self.embed = torch.nn.Linear(input_size, output_size)
-        elif input_layer == "pe":
-            self.embed = SinusoidalPositionEncoder()
-        else:
-            raise ValueError("unknown input_layer: " + input_layer)
-        self.normalize_before = normalize_before
-        if positionwise_layer_type == "linear":
-            positionwise_layer = PositionwiseFeedForward
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d":
-            positionwise_layer = MultiLayeredConv1d
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        elif positionwise_layer_type == "conv1d-linear":
-            positionwise_layer = Conv1dLinear
-            positionwise_layer_args = (
-                output_size,
-                linear_units,
-                positionwise_conv_kernel_size,
-                dropout_rate,
-            )
-        else:
-            raise NotImplementedError("Support only linear or conv1d.")
-
-        if selfattention_layer_type == "selfattn":
-            self.encoder_selfattn_layer = MultiHeadedAttention
-            encoder_selfattn_layer_args0 = encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                attention_dropout_rate,
-            )
-
-        elif selfattention_layer_type == "sanm":
-            self.encoder_selfattn_layer = MultiHeadedAttentionSANMwithMask
-            encoder_selfattn_layer_args0 = (
-                attention_heads,
-                input_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-
-            encoder_selfattn_layer_args = (
-                attention_heads,
-                output_size,
-                output_size,
-                attention_dropout_rate,
-                kernel_size,
-                sanm_shfit,
-            )
-
-        self.encoders0 = repeat(
-            1,
-            lambda lnum: EncoderLayerSANM(
-                input_size,
-                output_size,
-                self.encoder_selfattn_layer(*encoder_selfattn_layer_args0),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-
-        self.encoders = repeat(
-            num_blocks-1,
-            lambda lnum: EncoderLayerSANM(
-                output_size,
-                output_size,
-                self.encoder_selfattn_layer(*encoder_selfattn_layer_args),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-
-        self.interctc_layer_idx = interctc_layer_idx
-        if len(interctc_layer_idx) > 0:
-            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
-        self.interctc_use_conditioning = interctc_use_conditioning
-        self.conditioning_layer = None
-        self.dropout = nn.Dropout(dropout_rate)
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        vad_indexes: torch.Tensor,
-        prev_states: torch.Tensor = None,
-        ctc: CTC = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Embed positions in tensor.
-
-        Args:
-            xs_pad: input tensor (B, L, D)
-            ilens: input length (B)
-            prev_states: Not to be used now.
-        Returns:
-            position embedded tensor and mask
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
-        no_future_masks = masks & sub_masks
-        xs_pad *= self.output_size()**0.5
-        if self.embed is None:
-            xs_pad = xs_pad
-        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
-              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
-            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
-            if short_status:
-                raise TooShortUttError(
-                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
-                    f"(it needs more than {limit_size} frames), return empty results",
-                    xs_pad.size(1),
-                    limit_size,
-                )
-            xs_pad, masks = self.embed(xs_pad, masks)
-        else:
-            xs_pad = self.embed(xs_pad)
-
-        # xs_pad = self.dropout(xs_pad)
-        mask_tup0 = [masks, no_future_masks]
-        encoder_outs = self.encoders0(xs_pad, mask_tup0)
-        xs_pad, _ = encoder_outs[0], encoder_outs[1]
-        intermediate_outs = []
-
-        for layer_idx, encoder_layer in enumerate(self.encoders):
-            if layer_idx + 1 == len(self.encoders):
-                # This is the last layer: apply the per-utterance VAD mask.
-                corner_mask = torch.ones(masks.size(0),
-                                         masks.size(-1),
-                                         masks.size(-1),
-                                         device=xs_pad.device,
-                                         dtype=torch.bool)
-                for word_index, length in enumerate(ilens):
-                    corner_mask[word_index, :, :] = vad_mask(masks.size(-1),
-                                                             vad_indexes[word_index],
-                                                             device=xs_pad.device)
-                layer_mask = masks & corner_mask
-            else:
-                layer_mask = no_future_masks
-            mask_tup1 = [masks, layer_mask]
-            encoder_outs = encoder_layer(xs_pad, mask_tup1)
-            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        olens = masks.squeeze(1).sum(1)
-        if len(intermediate_outs) > 0:
-            return (xs_pad, intermediate_outs), olens, None
-        return xs_pad, olens, None
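
The last encoder layer above swaps the causal mask for a per-utterance VAD mask. A minimal sketch of the kind of (size, size) boolean mask `vad_mask` is expected to return — the real helper lives elsewhere in funasr, so the exact boundary semantics here are an assumption:

import torch

def vad_mask_sketch(size: int, vad_index: int, device=None) -> torch.Tensor:
    # Hypothetical reconstruction: frames before the VAD boundary may not
    # attend to frames at or beyond it; frames from the boundary onward
    # (the current, still-open segment) may attend everywhere.
    mask = torch.ones(size, size, dtype=torch.bool, device=device)
    if 0 < vad_index < size:
        mask[:vad_index, vad_index:] = False
    return mask

print(vad_mask_sketch(5, 3).int())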
diff --git a/funasr/models/ct_transformer/model.py b/funasr/models/ct_transformer/model.py
index d843686..7187f45 100644
--- a/funasr/models/ct_transformer/model.py
+++ b/funasr/models/ct_transformer/model.py
@@ -60,7 +60,7 @@
         
         
 
-    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs):
         """Compute loss value from buffer sequences.
 
         Args:
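
The widened signature matters because shared inference code calls punc_forward uniformly; extra, model-specific arguments must not break the offline model. A toy illustration of the pattern (hypothetical classes, not the funasr ones):

import torch

class OfflinePunc:
    def punc_forward(self, text, text_lengths, **kwargs):  # extras ignored
        return torch.zeros(text.size(0), text.size(1), 4), None

class StreamingPunc(OfflinePunc):
    def punc_forward(self, text, text_lengths, vad_indexes=None, **kwargs):
        return super().punc_forward(text, text_lengths)

for m in (OfflinePunc(), StreamingPunc()):
    y, _ = m.punc_forward(torch.zeros(1, 4, dtype=torch.long),
                          torch.tensor([4]), vad_indexes=torch.tensor([3]))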
diff --git a/funasr/models/ct_transformer/utils.py b/funasr/models/ct_transformer/utils.py
index 917f2e0..c5f85e6 100644
--- a/funasr/models/ct_transformer/utils.py
+++ b/funasr/models/ct_transformer/utils.py
@@ -14,26 +14,6 @@
     return sentences
 
 
-# def split_words(text: str, **kwargs):
-#     words = []
-#     segs = text.split()
-#     for seg in segs:
-#         # There is no space in seg.
-#         current_word = ""
-#         for c in seg:
-#             if len(c.encode()) == 1:
-#                 # This is an ASCII char.
-#                 current_word += c
-#             else:
-#                 # This is a Chinese char.
-#                 if len(current_word) > 0:
-#                     words.append(current_word)
-#                     current_word = ""
-#                 words.append(c)
-#         if len(current_word) > 0:
-#             words.append(current_word)
-#
-#     return words
 
 def split_words(text: str, jieba_usr_dict=None, **kwargs):
     if jieba_usr_dict:
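
For comparison with the new jieba-based split_words, the deleted reference implementation above (kept here as a self-contained sketch) groups runs of single-byte ASCII characters into words and emits each multi-byte (e.g. Chinese) character on its own:

def split_words_ascii_zh(text: str):
    words = []
    for seg in text.split():
        current_word = ""
        for c in seg:
            if len(c.encode()) == 1:  # ASCII char: extend the current word
                current_word += c
            else:                     # multi-byte char: flush, then emit singly
                if current_word:
                    words.append(current_word)
                    current_word = ""
                words.append(c)
        if current_word:
            words.append(current_word)
    return words

assert split_words_ascii_zh("hello 世界 ok") == ["hello", "世", "界", "ok"]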
diff --git a/funasr/models/ct_transformer/vad_realtime_transformer.py b/funasr/models/ct_transformer/vad_realtime_transformer.py
deleted file mode 100644
index 155057c..0000000
--- a/funasr/models/ct_transformer/vad_realtime_transformer.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-
-import torch
-import torch.nn as nn
-
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder
-from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder as Encoder
-
-
-class VadRealtimeTransformer(torch.nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
-    https://arxiv.org/pdf/2003.01309.pdf
-    """
-    def __init__(
-        self,
-        vocab_size: int,
-        punc_size: int,
-        pos_enc: str = None,
-        embed_unit: int = 128,
-        att_unit: int = 256,
-        head: int = 2,
-        unit: int = 1024,
-        layer: int = 4,
-        dropout_rate: float = 0.5,
-        kernel_size: int = 11,
-        sanm_shfit: int = 0,
-    ):
-        super().__init__()
-        if pos_enc == "sinusoidal":
-            pos_enc_class = SinusoidalPositionEncoder
-        elif pos_enc is None:
-
-            def pos_enc_class(*args, **kwargs):
-                return nn.Sequential()  # identity
-
-        else:
-            raise ValueError(f"unknown pos-enc option: {pos_enc}")
-
-        self.embed = nn.Embedding(vocab_size, embed_unit)
-        self.encoder = Encoder(
-            input_size=embed_unit,
-            output_size=att_unit,
-            attention_heads=head,
-            linear_units=unit,
-            num_blocks=layer,
-            dropout_rate=dropout_rate,
-            input_layer="pe",
-            # pos_enc_class=pos_enc_class,
-            padding_idx=0,
-            kernel_size=kernel_size,
-            sanm_shfit=sanm_shfit,
-        )
-        self.decoder = nn.Linear(att_unit, punc_size)
-
-
-#    def _target_mask(self, ys_in_pad):
-#        ys_mask = ys_in_pad != 0
-#        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
-#        return ys_mask.unsqueeze(-2) & m
-
-    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
-                vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
-        """Compute loss value from buffer sequences.
-
-        Args:
-            input (torch.Tensor): Input ids. (batch, len)
-            hidden (torch.Tensor): Target ids. (batch, len)
-
-        """
-        x = self.embed(input)
-        # mask = self._target_mask(input)
-        h, _, _ = self.encoder(x, text_lengths, vad_indexes)
-        y = self.decoder(h)
-        return y, None
-
-    def with_vad(self):
-        return True
-
-    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
-        """Score new token.
-
-        Args:
-            y (torch.Tensor): 1D torch.int64 prefix tokens.
-            state: Scorer state for prefix tokens
-            x (torch.Tensor): encoder feature that generates ys.
-
-        Returns:
-            tuple[torch.Tensor, Any]: Tuple of
-                torch.float32 scores for next token (vocab_size)
-                and next state for ys
-
-        """
-        y = y.unsqueeze(0)
-        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
-        h = self.decoder(h[:, -1])
-        logp = h.log_softmax(dim=-1).squeeze(0)
-        return logp, cache
-
-    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
-        """Score new token batch.
-
-        Args:
-            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
-            states (List[Any]): Scorer states for prefix tokens.
-            xs (torch.Tensor):
-                The encoder feature that generates ys (n_batch, xlen, n_feat).
-
-        Returns:
-            tuple[torch.Tensor, List[Any]]: Tuple of
-                batched scores for the next token with shape `(n_batch, vocab_size)`
-                and next state list for ys.
-
-        """
-        # merge states
-        n_batch = len(ys)
-        n_layers = len(self.encoder.encoders)
-        if states[0] is None:
-            batch_state = None
-        else:
-            # transpose state of [batch, layer] into [layer, batch]
-            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
-
-        # batch decoding
-        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
-        h = self.decoder(h[:, -1])
-        logp = h.log_softmax(dim=-1)
-
-        # transpose state of [layer, batch] into [batch, layer]
-        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
-        return logp, state_list
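
The [batch, layer] <-> [layer, batch] reshuffling in batch_score above is easy to get wrong; a minimal sketch with dummy tensors:

import torch

n_batch, n_layers, d = 2, 4, 8
# states[b][i]: cached tensor of layer i for batch element b
states = [[torch.randn(d) for _ in range(n_layers)] for _ in range(n_batch)]

# merge for batched decoding: one stacked tensor per layer
batch_state = [torch.stack([states[b][i] for b in range(n_batch)])
               for i in range(n_layers)]
assert batch_state[0].shape == (n_batch, d)

# split back into per-utterance state lists
state_list = [[batch_state[i][b] for i in range(n_layers)]
              for b in range(n_batch)]
assert torch.equal(state_list[1][2], states[1][2])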
diff --git a/funasr/models/ct_transformer_streaming/attention.py b/funasr/models/ct_transformer_streaming/attention.py
index a35ddee..382334e 100644
--- a/funasr/models/ct_transformer_streaming/attention.py
+++ b/funasr/models/ct_transformer_streaming/attention.py
@@ -11,487 +11,12 @@
 import numpy
 import torch
 from torch import nn
+import torch.nn.functional as F
 from typing import Optional, Tuple
 
-import torch.nn.functional as F
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-import funasr.models.lora.layers as lora
-
-class MultiHeadedAttention(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, n_feat, dropout_rate):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttention, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
-        self.linear_k = nn.Linear(n_feat, n_feat)
-        self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-    def forward_qkv(self, query, key, value):
-        """Transform query, key and value.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
-        """
-        n_batch = query.size(0)
-        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
-        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
-        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
-        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
-
-        return q, k, v
-
-    def forward_attention(self, value, scores, mask):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x)  # (batch, time1, d_model)
-
-    def forward(self, query, key, value, mask):
-        """Compute scaled dot product attention.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, mask)
-
-
-class LegacyRelPositionMultiHeadedAttention(MultiHeadedAttention):
-    """Multi-Head Attention layer with relative position encoding (old version).
-
-    Details can be found in https://github.com/espnet/espnet/pull/2816.
-
-    Paper: https://arxiv.org/abs/1901.02860
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
-
-    """
-
-    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
-        """Construct an RelPositionMultiHeadedAttention object."""
-        super().__init__(n_head, n_feat, dropout_rate)
-        self.zero_triu = zero_triu
-        # linear transformation for positional encoding
-        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
-        # these two learnable bias are used in matrix c and matrix d
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
-        torch.nn.init.xavier_uniform_(self.pos_bias_u)
-        torch.nn.init.xavier_uniform_(self.pos_bias_v)
-
-    def rel_shift(self, x):
-        """Compute relative positional encoding.
-
-        Args:
-            x (torch.Tensor): Input tensor (batch, head, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor.
-
-        """
-        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
-        x_padded = torch.cat([zero_pad, x], dim=-1)
-
-        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
-        x = x_padded[:, :, 1:].view_as(x)
-
-        if self.zero_triu:
-            ones = torch.ones((x.size(2), x.size(3)))
-            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
-
-        return x
-
-    def forward(self, query, key, value, pos_emb, mask):
-        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            pos_emb (torch.Tensor): Positional embedding tensor (#batch, time1, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
-
-        n_batch_pos = pos_emb.size(0)
-        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
-        p = p.transpose(1, 2)  # (batch, head, time1, d_k)
-
-        # (batch, head, time1, d_k)
-        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
-        # (batch, head, time1, d_k)
-        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
-
-        # compute attention score
-        # first compute matrix a and matrix c
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        # (batch, head, time1, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
-
-        # compute matrix b and matrix d
-        # (batch, head, time1, time1)
-        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
-        matrix_bd = self.rel_shift(matrix_bd)
-
-        scores = (matrix_ac + matrix_bd) / math.sqrt(
-            self.d_k
-        )  # (batch, head, time1, time2)
-
-        return self.forward_attention(v, scores, mask)
-
-
-class RelPositionMultiHeadedAttention(MultiHeadedAttention):
-    """Multi-Head Attention layer with relative position encoding (new implementation).
-
-    Details can be found in https://github.com/espnet/espnet/pull/2816.
-
-    Paper: https://arxiv.org/abs/1901.02860
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
-
-    """
-
-    def __init__(self, n_head, n_feat, dropout_rate, zero_triu=False):
-        """Construct an RelPositionMultiHeadedAttention object."""
-        super().__init__(n_head, n_feat, dropout_rate)
-        self.zero_triu = zero_triu
-        # linear transformation for positional encoding
-        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
-        # these two learnable bias are used in matrix c and matrix d
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
-        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
-        torch.nn.init.xavier_uniform_(self.pos_bias_u)
-        torch.nn.init.xavier_uniform_(self.pos_bias_v)
-
-    def rel_shift(self, x):
-        """Compute relative positional encoding.
-
-        Args:
-            x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
-            time1 means the length of query vector.
-
-        Returns:
-            torch.Tensor: Output tensor.
-
-        """
-        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
-        x_padded = torch.cat([zero_pad, x], dim=-1)
-
-        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
-        x = x_padded[:, :, 1:].view_as(x)[
-            :, :, :, : x.size(-1) // 2 + 1
-            ]  # only keep the positions from 0 to time2
-
-        if self.zero_triu:
-            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
-            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
-
-        return x
-
-    def forward(self, query, key, value, pos_emb, mask):
-        """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            pos_emb (torch.Tensor): Positional embedding tensor
-                (#batch, 2*time1-1, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        q = q.transpose(1, 2)  # (batch, time1, head, d_k)
-
-        n_batch_pos = pos_emb.size(0)
-        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
-        p = p.transpose(1, 2)  # (batch, head, 2*time1-1, d_k)
-
-        # (batch, head, time1, d_k)
-        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
-        # (batch, head, time1, d_k)
-        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
-
-        # compute attention score
-        # first compute matrix a and matrix c
-        # as described in https://arxiv.org/abs/1901.02860 Section 3.3
-        # (batch, head, time1, time2)
-        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
-
-        # compute matrix b and matrix d
-        # (batch, head, time1, 2*time1-1)
-        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
-        matrix_bd = self.rel_shift(matrix_bd)
-
-        scores = (matrix_ac + matrix_bd) / math.sqrt(
-            self.d_k
-        )  # (batch, head, time1, time2)
-
-        return self.forward_attention(v, scores, mask)
-
-
-class MultiHeadedAttentionSANM(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, in_feat, n_feat, dropout_rate, kernel_size, sanm_shfit=0, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionSANM, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        if lora_list is not None:
-            if "o" in lora_list:
-                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
-            else:
-                self.linear_out = nn.Linear(n_feat, n_feat)
-            lora_qkv_list = ["q" in lora_list, "k" in lora_list, "v" in lora_list]
-            if lora_qkv_list == [False, False, False]:
-                self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
-            else:
-                self.linear_q_k_v = lora.MergedLinear(in_feat, n_feat * 3, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_qkv_list)
-        else:
-            self.linear_out = nn.Linear(n_feat, n_feat)
-            self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-        self.fsmn_block = nn.Conv1d(n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
-        # padding
-        left_padding = (kernel_size - 1) // 2
-        if sanm_shfit > 0:
-            left_padding = left_padding + sanm_shfit
-        right_padding = kernel_size - 1 - left_padding
-        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
-
-    def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
-        """Apply the depthwise-conv FSMN memory block over the value sequence."""
-        b, t, d = inputs.size()
-        if mask is not None:
-            mask = torch.reshape(mask, (b, -1, 1))
-            if mask_shfit_chunk is not None:
-                mask = mask * mask_shfit_chunk
-            inputs = inputs * mask
-
-        x = inputs.transpose(1, 2)
-        x = self.pad_fn(x)
-        x = self.fsmn_block(x)
-        x = x.transpose(1, 2)
-        x += inputs
-        x = self.dropout(x)
-        if mask is not None:
-            x = x * mask
-        return x
-
-    def forward_qkv(self, x):
-        """Project the fused input into query, key and value.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, size).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
-            torch.Tensor: Untransposed value tensor (#batch, time, n_feat).
-
-        """
-        b, t, d = x.size()
-        q_k_v = self.linear_q_k_v(x)
-        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
-        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-
-        return q_h, k_h, v_h, v
-
-    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            if mask_att_chunk_encoder is not None:
-                mask = mask * mask_att_chunk_encoder
-
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x)  # (batch, time1, d_model)
-
-    def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
-        """Compute scaled dot product attention plus the FSMN memory.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
-                (#batch, time, time).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, d_model).
-
-        """
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
-        return att_outs + fsmn_memory
-
-    def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
-        """Compute chunked attention with a key/value cache for streaming.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, size).
-            cache (dict): Cached keys/values from previous chunks, or None.
-            chunk_size (list): Chunk layout; the code uses chunk_size[1] as
-                the chunk stride and chunk_size[2] as the lookahead frames
-                excluded from the cache.
-            look_back (int): Number of chunks of left context to keep
-                (-1 keeps everything).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, d_model).
-            dict: Updated key/value cache.
+from funasr.models.sanm.attention import MultiHeadedAttentionSANM
 
-        """
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        if chunk_size is not None and (look_back > 0 or look_back == -1):
-            if cache is not None:
-                k_h_stride = k_h[:, :, :-(chunk_size[2]), :]
-                v_h_stride = v_h[:, :, :-(chunk_size[2]), :]
-                k_h = torch.cat((cache["k"], k_h), dim=2)
-                v_h = torch.cat((cache["v"], v_h), dim=2)
 
-                cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
-                cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
-                if look_back != -1:
-                    cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
-                    cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
-            else:
-                cache_tmp = {"k": k_h[:, :, :-(chunk_size[2]), :],
-                             "v": v_h[:, :, :-(chunk_size[2]), :]}
-                cache = cache_tmp
-        fsmn_memory = self.forward_fsmn(v, None)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, None)
-        return att_outs + fsmn_memory, cache
 
 
 class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
@@ -506,586 +31,4 @@
         att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
         return att_outs + fsmn_memory
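
The `withMask` variant above takes `mask` as a pair: mask[0] gates the FSMN memory, mask[1] constrains the attention scores. A usage sketch, assuming the constructor argument order of `encoder_selfattn_layer_args` earlier in this patch and a forward signature of (x, mask):

import torch
from funasr.models.ct_transformer_streaming.attention import MultiHeadedAttentionSANMwithMask

att = MultiHeadedAttentionSANMwithMask(4, 256, 256, 0.0, 11, 0)
x = torch.randn(1, 20, 256)
pad_mask = torch.ones(1, 1, 20)                          # (batch, 1, time)
no_future = torch.tril(torch.ones(20, 20)).unsqueeze(0)  # causal (1, time, time)
y = att(x, [pad_mask, no_future])                        # mask[0] -> FSMN, mask[1] -> attention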
 
-class MultiHeadedAttentionSANMDecoder(nn.Module):
-    """Decoder-side FSMN memory block (no attention heads).
-
-    Args:
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-        kernel_size (int): FSMN kernel size.
-        sanm_shfit (int): Extra left shift applied to the FSMN padding.
-
-    """
-
-    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionSANMDecoder, self).__init__()
-
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-        self.fsmn_block = nn.Conv1d(n_feat, n_feat,
-                                    kernel_size, stride=1, padding=0, groups=n_feat, bias=False)
-        # padding
-        left_padding = (kernel_size - 1) // 2
-        if sanm_shfit > 0:
-            left_padding = left_padding + sanm_shfit
-        right_padding = kernel_size - 1 - left_padding
-        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
-        self.kernel_size = kernel_size
-
-    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
-        """Apply the FSMN memory block.
-
-        Args:
-            inputs (torch.Tensor): Input tensor (#batch, time1, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time).
-            cache (torch.Tensor): Cached padded window from the previous call.
-        """
-        # print("in fsmn, inputs", inputs.size())
-        b, t, d = inputs.size()
-        # logging.info(
-        #     "mask: {}".format(mask.size()))
-        if mask is not None:
-            mask = torch.reshape(mask, (b ,-1, 1))
-            # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
-            if mask_shfit_chunk is not None:
-                # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
-                mask = mask * mask_shfit_chunk
-            # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
-            # print("in fsmn, mask", mask.size())
-            # print("in fsmn, inputs", inputs.size())
-            inputs = inputs * mask
-
-        x = inputs.transpose(1, 2)
-        b, d, t = x.size()
-        if cache is None:
-            x = self.pad_fn(x)
-            if not self.training:
-                cache = x
-        else:
-            # slide the cached window forward; keep kernel_size + t - 1 frames
-            x = torch.cat((cache[:, :, 1:], x), dim=2)
-            x = x[:, :, -(self.kernel_size + t - 1):]
-            cache = x
-        x = self.fsmn_block(x)
-        x = x.transpose(1, 2)
-        if x.size(1) != inputs.size(1):
-            inputs = inputs[:, -1, :]
-
-        x = x + inputs
-        x = self.dropout(x)
-        if mask is not None:
-            x = x * mask
-        return x, cache
-
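
A shape-level sketch of the decoder-FSMN cache logic above: the cache holds the padded input window, and each call appends the new frames and keeps only the last kernel_size + t - 1 columns, so the depthwise conv emits exactly t new frames:

import torch

k, d, b = 11, 4, 1
pad = torch.nn.ConstantPad1d(((k - 1) // 2, k - 1 - (k - 1) // 2), 0.0)

chunk = torch.randn(b, d, 6)                 # first call, t = 6, cache is None
window = pad(chunk)                          # becomes the cache: (b, d, 6 + k - 1)

new = torch.randn(b, d, 1)                   # next call, t = 1
window = torch.cat((window[:, :, 1:], new), dim=2)[:, :, -(k + new.size(2) - 1):]
assert window.size(2) == k                   # conv over k columns -> 1 output frame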
-class MultiHeadedAttentionCrossAtt(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, n_feat, dropout_rate, lora_list=None, lora_rank=8, lora_alpha=16, lora_dropout=0.1, encoder_output_size=None):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadedAttentionCrossAtt, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        if lora_list is not None:
-            if "q" in lora_list:
-                self.linear_q = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
-            else:
-                self.linear_q = nn.Linear(n_feat, n_feat)
-            lora_kv_list = ["k" in lora_list, "v" in lora_list]
-            if lora_kv_list == [False, False]:
-                self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
-            else:
-                self.linear_k_v = lora.MergedLinear(n_feat if encoder_output_size is None else encoder_output_size, n_feat * 2, 
-                                      r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout, enable_lora=lora_kv_list)
-            if "o" in lora_list:
-                self.linear_out = lora.Linear(n_feat, n_feat, r=lora_rank, lora_alpha=lora_alpha, lora_dropout=lora_dropout)
-            else:
-                self.linear_out = nn.Linear(n_feat, n_feat)
-        else:
-            self.linear_q = nn.Linear(n_feat, n_feat)
-            self.linear_k_v = nn.Linear(n_feat if encoder_output_size is None else encoder_output_size, n_feat*2)
-            self.linear_out = nn.Linear(n_feat, n_feat)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-    def forward_qkv(self, x, memory):
-        """Project the query from x and the key/value from memory.
-
-        Args:
-            x (torch.Tensor): Query tensor (#batch, time1, size).
-            memory (torch.Tensor): Key/value tensor (#batch, time2, size).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
-        """
-
-        # print("in forward_qkv, x", x.size())
-        b = x.size(0)
-        q = self.linear_q(x)
-        q_h = torch.reshape(q, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time1, d_k)
-
-        k_v = self.linear_k_v(memory)
-        k, v = torch.split(k_v, int(self.h*self.d_k), dim=-1)
-        k_h = torch.reshape(k, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, -1, self.h, self.d_k)).transpose(1, 2)    # (batch, head, time2, d_k)
-
-        return q_h, k_h, v_h
-
-    def forward_attention(self, value, scores, mask):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x)  # (batch, time1, d_model)
-
-    def forward(self, x, memory, memory_mask):
-        """Compute cross-attention between x and the encoder memory.
-
-        Args:
-            x (torch.Tensor): Query tensor (#batch, time1, size).
-            memory (torch.Tensor): Encoder memory (#batch, time2, size).
-            memory_mask (torch.Tensor): Mask tensor (#batch, 1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q_h, k_h, v_h = self.forward_qkv(x, memory)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        return self.forward_attention(v_h, scores, memory_mask)
-
-    def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0):
-        """Compute chunked cross-attention with a key/value cache.
-
-        Args:
-            x (torch.Tensor): Query tensor (#batch, time1, size).
-            memory (torch.Tensor): Encoder memory (#batch, time2, size).
-            cache (dict): Cached keys/values from previous chunks, or None.
-            chunk_size (list): Chunk layout; chunk_size[1] is the chunk stride.
-            look_back (int): Number of chunks of left context to keep.
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-            dict: Updated key/value cache.
-
-        """
-        q_h, k_h, v_h = self.forward_qkv(x, memory)
-        if chunk_size is not None and look_back > 0:
-            if cache is not None:
-                k_h = torch.cat((cache["k"], k_h), dim=2)
-                v_h = torch.cat((cache["v"], v_h), dim=2)
-                cache["k"] = k_h[:, :, -(look_back * chunk_size[1]):, :]
-                cache["v"] = v_h[:, :, -(look_back * chunk_size[1]):, :]
-            else:
-                cache_tmp = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :],
-                             "v": v_h[:, :, -(look_back * chunk_size[1]):, :]}
-                cache = cache_tmp
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        return self.forward_attention(v_h, scores, None), cache
-
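
The look-back cache in forward_chunk above keeps only the last look_back * chunk_size[1] key/value frames. A shape sketch (values omitted for brevity; the (pre, middle, post) reading of chunk_size is an assumption based on how indices [1] and [2] are used):

import torch

b, h, d_k = 1, 4, 64
chunk_size, look_back = (5, 10, 5), 2

cache = None
for step in range(3):
    k_h = torch.randn(b, h, sum(chunk_size), d_k)   # this chunk's keys
    if cache is not None:
        k_h = torch.cat((cache["k"], k_h), dim=2)   # prepend cached history
    cache = {"k": k_h[:, :, -(look_back * chunk_size[1]):, :]}
assert cache["k"].size(2) == look_back * chunk_size[1]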
-
-class MultiHeadSelfAttention(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, in_feat, n_feat, dropout_rate):
-        """Construct an MultiHeadedAttention object."""
-        super(MultiHeadSelfAttention, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-    def forward_qkv(self, x):
-        """Project the fused input into query, key and value.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, size).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time, d_k).
-            torch.Tensor: Untransposed value tensor (#batch, time, n_feat).
-
-        """
-        b, t, d = x.size()
-        q_k_v = self.linear_q_k_v(x)
-        q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
-        q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time1, d_k)
-        k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-        v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2)  # (batch, head, time2, d_k)
-
-        return q_h, k_h, v_h, v
-
-    def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            if mask_att_chunk_encoder is not None:
-                mask = mask * mask_att_chunk_encoder
-
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
-            )
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x)  # (batch, time1, d_model)
-
-    def forward(self, x, mask, mask_att_chunk_encoder=None):
-        """Compute scaled dot product self-attention.
-
-        Args:
-            x (torch.Tensor): Input tensor (#batch, time, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time) or
-                (#batch, time, time).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, d_model).
-
-        """
-        q_h, k_h, v_h, v = self.forward_qkv(x)
-        q_h = q_h * self.d_k ** (-0.5)
-        scores = torch.matmul(q_h, k_h.transpose(-2, -1))
-        att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
-        return att_outs
-
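
The fused q/k/v projection used throughout these classes computes all three with a single matmul and then splits; a minimal sketch:

import torch

b, t, in_feat, n_feat, h = 1, 7, 256, 256, 4
d_k = n_feat // h
linear_q_k_v = torch.nn.Linear(in_feat, n_feat * 3)

q, k, v = torch.split(linear_q_k_v(torch.randn(b, t, in_feat)), h * d_k, dim=-1)
q_h = q.reshape(b, t, h, d_k).transpose(1, 2)   # (batch, head, time, d_k)
k_h = k.reshape(b, t, h, d_k).transpose(1, 2)
v_h = v.reshape(b, t, h, d_k).transpose(1, 2)
assert q_h.shape == (b, h, t, d_k)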
-class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
-    """RelPositionMultiHeadedAttention definition.
-    Args:
-        num_heads: Number of attention heads.
-        embed_size: Embedding size.
-        dropout_rate: Dropout rate.
-    """
-
-    def __init__(
-        self,
-        num_heads: int,
-        embed_size: int,
-        dropout_rate: float = 0.0,
-        simplified_attention_score: bool = False,
-    ) -> None:
-        """Construct an MultiHeadedAttention object."""
-        super().__init__()
-
-        self.d_k = embed_size // num_heads
-        self.num_heads = num_heads
-
-        assert self.d_k * num_heads == embed_size, (
-            f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})"
-        )
-
-        self.linear_q = torch.nn.Linear(embed_size, embed_size)
-        self.linear_k = torch.nn.Linear(embed_size, embed_size)
-        self.linear_v = torch.nn.Linear(embed_size, embed_size)
-
-        self.linear_out = torch.nn.Linear(embed_size, embed_size)
-
-        if simplified_attention_score:
-            self.linear_pos = torch.nn.Linear(embed_size, num_heads)
-
-            self.compute_att_score = self.compute_simplified_attention_score
-        else:
-            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)
-
-            self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
-            self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k))
-            torch.nn.init.xavier_uniform_(self.pos_bias_u)
-            torch.nn.init.xavier_uniform_(self.pos_bias_v)
-
-            self.compute_att_score = self.compute_attention_score
-
-        self.dropout = torch.nn.Dropout(p=dropout_rate)
-        self.attn = None
-
-    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
-        """Compute relative positional encoding.
-        Args:
-            x: Input sequence. (B, H, T_1, 2 * T_1 - 1)
-            left_context: Number of frames in left context.
-        Returns:
-            x: Output sequence. (B, H, T_1, T_2)
-        """
-        batch_size, n_heads, time1, n = x.shape
-        time2 = time1 + left_context
-
-        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()
-
-        return x.as_strided(
-            (batch_size, n_heads, time1, time2),
-            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
-            storage_offset=(n_stride * (time1 - 1)),
-        )
-
-    def compute_simplified_attention_score(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        pos_enc: torch.Tensor,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Simplified attention score computation.
-        Reference: https://github.com/k2-fsa/icefall/pull/458
-        Args:
-            query: Transformed query tensor. (B, H, T_1, d_k)
-            key: Transformed key tensor. (B, H, T_2, d_k)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            left_context: Number of frames in left context.
-        Returns:
-            : Attention score. (B, H, T_1, T_2)
-        """
-        pos_enc = self.linear_pos(pos_enc)
-
-        matrix_ac = torch.matmul(query, key.transpose(2, 3))
-
-        matrix_bd = self.rel_shift(
-            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
-            left_context=left_context,
-        )
-
-        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
-    def compute_attention_score(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        pos_enc: torch.Tensor,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Attention score computation.
-        Args:
-            query: Transformed query tensor. (B, H, T_1, d_k)
-            key: Transformed key tensor. (B, H, T_2, d_k)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            left_context: Number of frames in left context.
-        Returns:
-            : Attention score. (B, H, T_1, T_2)
-        """
-        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)
-
-        query = query.transpose(1, 2)
-        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
-        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)
-
-        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))
-
-        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
-        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)
-
-        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)
-
-    def forward_qkv(
-        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Transform query, key and value.
-        Args:
-            query: Query tensor. (B, T_1, size)
-            key: Key tensor. (B, T_2, size)
-            value: Value tensor. (B, T_2, size)
-        Returns:
-            q: Transformed query tensor. (B, H, T_1, d_k)
-            k: Transformed key tensor. (B, H, T_2, d_k)
-            v: Transformed value tensor. (B, H, T_2, d_k)
-        """
-        n_batch = query.size(0)
-
-        q = (
-            self.linear_q(query)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-        k = (
-            self.linear_k(key)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-        v = (
-            self.linear_v(value)
-            .view(n_batch, -1, self.num_heads, self.d_k)
-            .transpose(1, 2)
-        )
-
-        return q, k, v
-
-    def forward_attention(
-        self,
-        value: torch.Tensor,
-        scores: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Compute attention context vector.
-        Args:
-            value: Transformed value. (B, H, T_2, d_k)
-            scores: Attention score. (B, H, T_1, T_2)
-            mask: Source mask. (B, T_2)
-            chunk_mask: Chunk mask. (T_1, T_1)
-        Returns:
-           attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
-        """
-        batch_size = scores.size(0)
-        mask = mask.unsqueeze(1).unsqueeze(2)
-        if chunk_mask is not None:
-            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
-        scores = scores.masked_fill(mask, float("-inf"))
-        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)
-
-        attn_output = self.dropout(self.attn)
-        attn_output = torch.matmul(attn_output, value)
-
-        attn_output = self.linear_out(
-            attn_output.transpose(1, 2)
-            .contiguous()
-            .view(batch_size, -1, self.num_heads * self.d_k)
-        )
-
-        return attn_output
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        pos_enc: torch.Tensor,
-        mask: torch.Tensor,
-        chunk_mask: Optional[torch.Tensor] = None,
-        left_context: int = 0,
-    ) -> torch.Tensor:
-        """Compute scaled dot product attention with rel. positional encoding.
-        Args:
-            query: Query tensor. (B, T_1, size)
-            key: Key tensor. (B, T_2, size)
-            value: Value tensor. (B, T_2, size)
-            pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size)
-            mask: Source mask. (B, T_2)
-            chunk_mask: Chunk mask. (T_1, T_1)
-            left_context: Number of frames in left context.
-        Returns:
-            : Output tensor. (B, T_1, H * d_k)
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
-        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
-
-
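
The as_strided rel_shift above is the subtlest part of the chunked relative-position attention; a numeric check of what the strided view selects (left_context = 0, so time2 == time1):

import torch

B, H, T = 1, 1, 4
# each row holds relative positions 0 .. 2T-2
x = torch.arange(2 * T - 1, dtype=torch.float).repeat(B, H, T, 1)
bs, hs, ts, ns = x.stride()
out = x.as_strided((B, H, T, T), (bs, hs, ts - ns, ns),
                   storage_offset=ns * (T - 1))
# row i starts at relative index T-1-i:
# [[3,4,5,6], [2,3,4,5], [1,2,3,4], [0,1,2,3]]
print(out[0, 0])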
-class CosineDistanceAttention(nn.Module):
-    """ Compute Cosine Distance between spk decoder output and speaker profile 
-    Args:
-        profile_path: speaker profile file path (.npy file)
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.softmax = nn.Softmax(dim=-1)
-
-    def forward(self, spk_decoder_out, profile, profile_lens=None):
-        """
-        Args:
-            spk_decoder_out (torch.Tensor): (B, L, D)
-            profile (torch.Tensor): (B, N, D)
-            profile_lens (torch.Tensor): (B,), optional
-        """
-        x = spk_decoder_out.unsqueeze(2)  # (B, L, 1, D)
-        if profile_lens is not None:
-            mask = (make_pad_mask(profile_lens)[:, None, :]).to(profile.device)
-            min_value = float(
-                numpy.finfo(torch.tensor(0, dtype=x.dtype).numpy().dtype).min
-            )
-            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1), dim=-1).masked_fill(mask, min_value)
-            weights = self.softmax(weights_not_softmax).masked_fill(mask, 0.0)  # (B, L, N)
-        else:
-            x = x[:, -1:, :, :]
-            weights_not_softmax = F.cosine_similarity(x, profile.unsqueeze(1).to(x.device), dim=-1)
-            weights = self.softmax(weights_not_softmax)  # (B, 1, N)
-        spk_embedding = torch.matmul(weights, profile.to(weights.device))  # (B, L, D)
 
-        return spk_embedding, weights
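
The profile pooling in CosineDistanceAttention above reduces to three lines; a minimal sketch without the padding-mask branch:

import torch
import torch.nn.functional as F

B, L, N, D = 1, 5, 3, 16
dec_out = torch.randn(B, L, D)       # speaker decoder output
profile = torch.randn(B, N, D)       # N speaker profiles

sim = F.cosine_similarity(dec_out.unsqueeze(2), profile.unsqueeze(1), dim=-1)
weights = torch.softmax(sim, dim=-1)            # (B, L, N)
spk_embedding = torch.matmul(weights, profile)  # (B, L, D)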
diff --git a/funasr/models/ct_transformer_streaming/encoder.py b/funasr/models/ct_transformer_streaming/encoder.py
index 784baf3..32ee2f2 100644
--- a/funasr/models/ct_transformer_streaming/encoder.py
+++ b/funasr/models/ct_transformer_streaming/encoder.py
@@ -12,7 +12,7 @@
 from funasr.train_utils.device_funcs import to_device
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.sanm.attention import MultiHeadedAttention
-from funasr.models.ct_transformer.attention import MultiHeadedAttentionSANMwithMask
+from funasr.models.ct_transformer_streaming.attention import MultiHeadedAttentionSANMwithMask
 from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
 from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
diff --git a/funasr/models/ct_transformer_streaming/model.py b/funasr/models/ct_transformer_streaming/model.py
index 4c84261..5254d15 100644
--- a/funasr/models/ct_transformer_streaming/model.py
+++ b/funasr/models/ct_transformer_streaming/model.py
@@ -12,11 +12,12 @@
 import torch.nn as nn
 from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
 from funasr.utils.load_utils import load_audio_text_image_video
+from funasr.models.ct_transformer.model import CTTransformer
 
 from funasr.register import tables
 
 @tables.register("model_classes", "CTTransformerStreaming")
-class CTTransformerStreaming(nn.Module):
+class CTTransformerStreaming(CTTransformer):
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
@@ -24,43 +25,13 @@
     """
     def __init__(
         self,
-        encoder: str = None,
-        encoder_conf: dict = None,
-        vocab_size: int = -1,
-        punc_list: list = None,
-        punc_weight: list = None,
-        embed_unit: int = 128,
-        att_unit: int = 256,
-        dropout_rate: float = 0.5,
-        ignore_id: int = -1,
-        sos: int = 1,
-        eos: int = 2,
-        sentence_end_id: int = 3,
+        *args,
         **kwargs,
     ):
-        super().__init__()
+        super().__init__(*args, **kwargs)
 
-        punc_size = len(punc_list)
-        if punc_weight is None:
-            punc_weight = [1] * punc_size
-        
-        
-        self.embed = nn.Embedding(vocab_size, embed_unit)
-        encoder_class = tables.encoder_classes.get(encoder.lower())
-        encoder = encoder_class(**encoder_conf)
 
-        self.decoder = nn.Linear(att_unit, punc_size)
-        self.encoder = encoder
-        self.punc_list = punc_list
-        self.punc_weight = punc_weight
-        self.ignore_id = ignore_id
-        self.sos = sos
-        self.eos = eos
-        self.sentence_end_id = sentence_end_id
-        
-        
-
-    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+    def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, vad_indexes: torch.Tensor, **kwargs):
         """Compute loss value from buffer sequences.
 
         Args:
@@ -70,146 +41,14 @@
         """
         x = self.embed(text)
         # mask = self._target_mask(input)
-        h, _, _ = self.encoder(x, text_lengths)
+        h, _, _ = self.encoder(x, text_lengths, vad_indexes=vad_indexes)
         y = self.decoder(h)
         return y, None
 
     def with_vad(self):
-        return False
-
-    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
-        """Score new token.
-
-        Args:
-            y (torch.Tensor): 1D torch.int64 prefix tokens.
-            state: Scorer state for prefix tokens
-            x (torch.Tensor): encoder feature that generates ys.
-
-        Returns:
-            tuple[torch.Tensor, Any]: Tuple of
-                torch.float32 scores for next token (vocab_size)
-                and next state for ys
-
-        """
-        y = y.unsqueeze(0)
-        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
-        h = self.decoder(h[:, -1])
-        logp = h.log_softmax(dim=-1).squeeze(0)
-        return logp, cache
-
-    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
-        """Score new token batch.
-
-        Args:
-            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
-            states (List[Any]): Scorer states for prefix tokens.
-            xs (torch.Tensor):
-                The encoder feature that generates ys (n_batch, xlen, n_feat).
-
-        Returns:
-            tuple[torch.Tensor, List[Any]]: Tuple of
-                batchified scores for next token with shape of `(n_batch, vocab_size)`
-                and next state list for ys.
-
-        """
-        # merge states
-        n_batch = len(ys)
-        n_layers = len(self.encoder.encoders)
-        if states[0] is None:
-            batch_state = None
-        else:
-            # transpose state of [batch, layer] into [layer, batch]
-            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
-
-        # batch decoding
-        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
-        h = self.decoder(h[:, -1])
-        logp = h.log_softmax(dim=-1)
-
-        # transpose state of [layer, batch] into [batch, layer]
-        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
-        return logp, state_list
-
-    def nll(
-        self,
-        text: torch.Tensor,
-        punc: torch.Tensor,
-        text_lengths: torch.Tensor,
-        punc_lengths: torch.Tensor,
-        max_length: Optional[int] = None,
-        vad_indexes: Optional[torch.Tensor] = None,
-        vad_indexes_lengths: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Compute negative log likelihood(nll)
-
-        Normally, this function is called in batchify_nll.
-        Args:
-            text: (Batch, Length)
-            punc: (Batch, Length)
-            text_lengths: (Batch,)
-            max_lengths: int
-        """
-        batch_size = text.size(0)
-        # For data parallel
-        if max_length is None:
-            text = text[:, :text_lengths.max()]
-            punc = punc[:, :text_lengths.max()]
-        else:
-            text = text[:, :max_length]
-            punc = punc[:, :max_length]
-    
-        if self.with_vad():
-            # Should be VadRealtimeTransformer
-            assert vad_indexes is not None
-            y, _ = self.punc_forward(text, text_lengths, vad_indexes)
-        else:
-            # Should be TargetDelayTransformer,
-            y, _ = self.punc_forward(text, text_lengths)
-    
-        # Calc negative log likelihood
-        # nll: (BxL,)
-        if not self.training:
-            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
-            from sklearn.metrics import f1_score
-            f1_score = f1_score(punc.view(-1).detach().cpu().numpy(),
-                                indices.squeeze(-1).detach().cpu().numpy(),
-                                average='micro')
-            nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
-            return nll, text_lengths
-        else:
-            self.punc_weight = self.punc_weight.to(punc.device)
-            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
-                                  ignore_index=self.ignore_id)
-        # nll: (BxL,) -> (BxL,)
-        if max_length is None:
-            nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
-        else:
-            nll.masked_fill_(
-                make_pad_mask(text_lengths, maxlen=max_length + 1).to(nll.device).view(-1),
-                0.0,
-            )
-        # nll: (BxL,) -> (B, L)
-        nll = nll.view(batch_size, -1)
-        return nll, text_lengths
+        return True
 
 
-    def forward(
-        self,
-        text: torch.Tensor,
-        punc: torch.Tensor,
-        text_lengths: torch.Tensor,
-        punc_lengths: torch.Tensor,
-        vad_indexes: Optional[torch.Tensor] = None,
-        vad_indexes_lengths: Optional[torch.Tensor] = None,
-    ):
-        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
-        ntokens = y_lengths.sum()
-        loss = nll.sum() / ntokens
-        stats = dict(loss=loss.detach())
-    
-        # force_gatherable: to-device and to-tensor if scalar for DataParallel
-        loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
-        return loss, stats, weight
     
     def generate(self,
                  data_in,
@@ -217,22 +56,20 @@
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
+                 cache: dict = {},
                  **kwargs,
                  ):
         assert len(data_in) == 1
+        
+        if len(cache) == 0:
+            cache["pre_text"] = []
         text = load_audio_text_image_video(data_in, data_type=kwargs.get("data_type", "text"))[0]
-        vad_indexes = kwargs.get("vad_indexes", None)
-        # text = data_in[0]
-        # text_lengths = data_lengths[0] if data_lengths is not None else None
+        text = "".join(cache["pre_text"]) + " " + text
+
+
         split_size = kwargs.get("split_size", 20)
 
-        jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
-        if jieba_usr_dict and isinstance(jieba_usr_dict, str):
-            import jieba
-            jieba.load_userdict(jieba_usr_dict)
-            jieba_usr_dict = jieba
-            kwargs["jieba_usr_dict"] = "jieba_usr_dict"
-        tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
+        tokens = split_words(text)
         tokens_int = tokenizer.encode(tokens)
 
         mini_sentences = split_to_mini_sentence(tokens, split_size)
@@ -240,8 +77,9 @@
         assert len(mini_sentences) == len(mini_sentences_id)
         cache_sent = []
         cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
-        new_mini_sentence = ""
-        new_mini_sentence_punc = []
+        skip_num = 0
+        sentence_punc_list = []
+        sentence_words_list = []
         cache_pop_trigger_limit = 200
         results = []
         meta_data = {}
@@ -254,6 +92,7 @@
             data = {
                 "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                 "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
+                "vad_indexes": torch.from_numpy(np.array([len(cache["pre_text"])], dtype='int32')),
             }
             data = to_device(data, kwargs["device"])
             # y, _ = self.wrapped_model(**data)
@@ -288,52 +127,42 @@
             #    continue
 
             punctuations_np = punctuations.cpu().numpy()
-            new_mini_sentence_punc += [int(x) for x in punctuations_np]
-            words_with_punc = []
-            for i in range(len(mini_sentence)):
-                if (i==0 or self.punc_list[punctuations[i-1]] == "。" or self.punc_list[punctuations[i-1]] == "？") and len(mini_sentence[i][0].encode()) == 1:
-                    mini_sentence[i] = mini_sentence[i].capitalize()
-                if i == 0:
-                    if len(mini_sentence[i][0].encode()) == 1:
-                        mini_sentence[i] = " " + mini_sentence[i]
-                if i > 0:
-                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
-                        mini_sentence[i] = " " + mini_sentence[i]
-                words_with_punc.append(mini_sentence[i])
-                if self.punc_list[punctuations[i]] != "_":
-                    punc_res = self.punc_list[punctuations[i]]
-                    if len(mini_sentence[i][0].encode()) == 1:
-                        if punc_res == "，":
-                            punc_res = ","
-                        elif punc_res == "。":
-                            punc_res = "."
-                        elif punc_res == "？":
-                            punc_res = "?"
-                    words_with_punc.append(punc_res)
-            new_mini_sentence += "".join(words_with_punc)
-            # Add Period for the end of the sentence
-            new_mini_sentence_out = new_mini_sentence
-            new_mini_sentence_punc_out = new_mini_sentence_punc
-            if mini_sentence_i == len(mini_sentences) - 1:
-                if new_mini_sentence[-1] == "，" or new_mini_sentence[-1] == "、":
-                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
-                elif new_mini_sentence[-1] == ",":
-                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
-                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "？" and len(new_mini_sentence[-1].encode()) != 1:
-                    new_mini_sentence_out = new_mini_sentence + "。"
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
-                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode()) == 1:
-                    new_mini_sentence_out = new_mini_sentence + "."
-                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
-            # keep a punctuations array for punc segment
-            if punc_array is None:
-                punc_array = punctuations
+            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
+            sentence_words_list += mini_sentence
+
+        assert len(sentence_punc_list) == len(sentence_words_list)
+        words_with_punc = []
+        sentence_punc_list_out = []
+        for i in range(0, len(sentence_words_list)):
+            if i > 0:
+                if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
+                    sentence_words_list[i] = " " + sentence_words_list[i]
+            if skip_num < len(cache["pre_text"]):
+                skip_num += 1
             else:
-                punc_array = torch.cat([punc_array, punctuations], dim=0)
+                words_with_punc.append(sentence_words_list[i])
+            if skip_num >= len(cache["pre_text"]):
+                sentence_punc_list_out.append(sentence_punc_list[i])
+                if sentence_punc_list[i] != "_":
+                    words_with_punc.append(sentence_punc_list[i])
+        sentence_out = "".join(words_with_punc)
+
+        sentenceEnd = -1
+        for i in range(len(sentence_punc_list) - 2, 1, -1):
+            if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "？":
+                sentenceEnd = i
+                break
+        cache["pre_text"] = sentence_words_list[sentenceEnd + 1:]
+        if sentence_out[-1] in self.punc_list:
+            sentence_out = sentence_out[:-1]
+            sentence_punc_list_out[-1] = "_"
+        # keep a punctuations array for punc segment
+        if punc_array is None:
+            punc_array = punctuations
+        else:
+            punc_array = torch.cat([punc_array, punctuations], dim=0)
         
-        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
+        result_i = {"key": key[0], "text": sentence_out, "punc_array": punc_array}
         results.append(result_i)
     
         return results, meta_data
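
Note: the reworked generate() above carries streaming state in cache["pre_text"]:
every word after the last sentence-final mark ("。" or "？") is cached and
re-fed together with the next VAD segment, while skip_num keeps those words from
being emitted twice. A simplified, self-contained sketch of that carry-over (the
real code also skips the final position when searching backwards and strips one
trailing punctuation mark from the streaming output):

    def carry_over(words, puncs, cache):
        n_skip = len(cache.get("pre_text", []))
        out = "".join(w + (p if p != "_" else "")
                      for w, p in list(zip(words, puncs))[n_skip:])
        last_end = max((i for i, p in enumerate(puncs[:-1]) if p in ("。", "？")),
                       default=-1)
        cache["pre_text"] = words[last_end + 1:]
        return out

    cache = {}
    carry_over(list("今天下雨了我们"), ["_", "_", "_", "_", "。", "_", "_"], cache)
    # -> "今天下雨了。我们", with cache["pre_text"] == ["我", "们"]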
diff --git a/funasr/models/ct_transformer_streaming/template.yaml b/funasr/models/ct_transformer_streaming/template.yaml
index c20a098..2477ac2 100644
--- a/funasr/models/ct_transformer_streaming/template.yaml
+++ b/funasr/models/ct_transformer_streaming/template.yaml
@@ -27,13 +27,13 @@
         - 1.0
     sentence_end_id: 3
 
-encoder: SANMEncoder
+encoder: SANMVadEncoder
 encoder_conf:
     input_size: 256
     output_size: 256
     attention_heads: 8
     linear_units: 1024
-    num_blocks: 4
+    num_blocks: 3
     dropout_rate: 0.1
     positional_dropout_rate: 0.1
     attention_dropout_rate: 0.0
@@ -41,13 +41,10 @@
     pos_enc_class: SinusoidalPositionEncoder
     normalize_before: true
     kernel_size: 11
-    sanm_shfit: 0
+    sanm_shfit: 5
     selfattention_layer_type: sanm
     padding_idx: 0
 
 tokenizer: CharTokenizer
 tokenizer_conf:
-  unk_symbol: <unk>
-
-
-
+  unk_symbol: <unk>
\ No newline at end of file
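
Note: a sketch of how the encoder/encoder_conf entries above are resolved when
the model is built, mirroring the registry lookup that this patch moves into the
shared CTTransformer base class (the lowercase key and the config subset shown
here are assumptions, not a verbatim excerpt):

    from funasr.register import tables

    encoder_conf = dict(input_size=256, output_size=256, attention_heads=8,
                        linear_units=1024, num_blocks=3, kernel_size=11,
                        sanm_shfit=5)
    encoder_class = tables.encoder_classes.get("SANMVadEncoder".lower())
    encoder = encoder_class(**encoder_conf)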
diff --git a/funasr/models/ct_transformer_streaming/utils.py b/funasr/models/ct_transformer_streaming/utils.py
deleted file mode 100644
index 917f2e0..0000000
--- a/funasr/models/ct_transformer_streaming/utils.py
+++ /dev/null
@@ -1,111 +0,0 @@
-import re
-
-def split_to_mini_sentence(words: list, word_limit: int = 20):
-    assert word_limit > 1
-    if len(words) <= word_limit:
-        return [words]
-    sentences = []
-    length = len(words)
-    sentence_len = length // word_limit
-    for i in range(sentence_len):
-        sentences.append(words[i * word_limit:(i + 1) * word_limit])
-    if length % word_limit > 0:
-        sentences.append(words[sentence_len * word_limit:])
-    return sentences
-
-
-# def split_words(text: str, **kwargs):
-#     words = []
-#     segs = text.split()
-#     for seg in segs:
-#         # There is no space in seg.
-#         current_word = ""
-#         for c in seg:
-#             if len(c.encode()) == 1:
-#                 # This is an ASCII char.
-#                 current_word += c
-#             else:
-#                 # This is a Chinese char.
-#                 if len(current_word) > 0:
-#                     words.append(current_word)
-#                     current_word = ""
-#                 words.append(c)
-#         if len(current_word) > 0:
-#             words.append(current_word)
-#
-#     return words
-
-def split_words(text: str, jieba_usr_dict=None, **kwargs):
-    if jieba_usr_dict:
-        input_list = text.split()
-        token_list_all = []
-        language_list = []
-        token_list_tmp = []
-        language_flag = None
-        for token in input_list:
-            if isEnglish(token) and language_flag == 'Chinese':
-                token_list_all.append(token_list_tmp)
-                language_list.append('Chinese')
-                token_list_tmp = []
-            elif not isEnglish(token) and language_flag == 'English':
-                token_list_all.append(token_list_tmp)
-                language_list.append('English')
-                token_list_tmp = []
-
-            token_list_tmp.append(token)
-
-            if isEnglish(token):
-                language_flag = 'English'
-            else:
-                language_flag = 'Chinese'
-
-        if token_list_tmp:
-            token_list_all.append(token_list_tmp)
-            language_list.append(language_flag)
-
-        result_list = []
-        for token_list_tmp, language_flag in zip(token_list_all, language_list):
-            if language_flag == 'English':
-                result_list.extend(token_list_tmp)
-            else:
-                seg_list = jieba_usr_dict.cut(join_chinese_and_english(token_list_tmp), HMM=False)
-                result_list.extend(seg_list)
-
-        return result_list
-
-    else:
-        words = []
-        segs = text.split()
-        for seg in segs:
-            # There is no space in seg.
-            current_word = ""
-            for c in seg:
-                if len(c.encode()) == 1:
-                    # This is an ASCII char.
-                    current_word += c
-                else:
-                    # This is a Chinese char.
-                    if len(current_word) > 0:
-                        words.append(current_word)
-                        current_word = ""
-                    words.append(c)
-            if len(current_word) > 0:
-                words.append(current_word)
-        return words
-
-def isEnglish(text:str):
-    if re.search('^[a-zA-Z\']+$', text):
-        return True
-    else:
-        return False
-
-def join_chinese_and_english(input_list):
-    line = ''
-    for token in input_list:
-        if isEnglish(token):
-            line = line + ' ' + token
-        else:
-            line = line + token
-
-    line = line.strip()
-    return line
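
Note: these helpers live on in funasr/models/ct_transformer/utils.py, which the
streaming model imports above. Assuming the shared copies keep the behavior
shown here, the non-jieba path splits ASCII runs into whole tokens and CJK text
into single characters, and mini-sentence chunking then groups them:

    from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words

    words = split_words("我们 test 一下")
    # -> ['我', '们', 'test', '一', '下']
    mini = split_to_mini_sentence(words, word_limit=3)
    # -> [['我', '们', 'test'], ['一', '下']]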
diff --git a/funasr/models/ct_transformer_streaming/vad_realtime_transformer.py b/funasr/models/ct_transformer_streaming/vad_realtime_transformer.py
deleted file mode 100644
index 155057c..0000000
--- a/funasr/models/ct_transformer_streaming/vad_realtime_transformer.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from typing import Any
-from typing import List
-from typing import Tuple
-
-import torch
-import torch.nn as nn
-
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder
-from funasr.models.ct_transformer.sanm_encoder import SANMVadEncoder as Encoder
-
-
-class VadRealtimeTransformer(torch.nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
-    https://arxiv.org/pdf/2003.01309.pdf
-    """
-    def __init__(
-        self,
-        vocab_size: int,
-        punc_size: int,
-        pos_enc: str = None,
-        embed_unit: int = 128,
-        att_unit: int = 256,
-        head: int = 2,
-        unit: int = 1024,
-        layer: int = 4,
-        dropout_rate: float = 0.5,
-        kernel_size: int = 11,
-        sanm_shfit: int = 0,
-    ):
-        super().__init__()
-        if pos_enc == "sinusoidal":
-            #            pos_enc_class = PositionalEncoding
-            pos_enc_class = SinusoidalPositionEncoder
-        elif pos_enc is None:
-
-            def pos_enc_class(*args, **kwargs):
-                return nn.Sequential()  # identity
-
-        else:
-            raise ValueError(f"unknown pos-enc option: {pos_enc}")
-
-        self.embed = nn.Embedding(vocab_size, embed_unit)
-        self.encoder = Encoder(
-            input_size=embed_unit,
-            output_size=att_unit,
-            attention_heads=head,
-            linear_units=unit,
-            num_blocks=layer,
-            dropout_rate=dropout_rate,
-            input_layer="pe",
-            # pos_enc_class=pos_enc_class,
-            padding_idx=0,
-            kernel_size=kernel_size,
-            sanm_shfit=sanm_shfit,
-        )
-        self.decoder = nn.Linear(att_unit, punc_size)
-
-
-#    def _target_mask(self, ys_in_pad):
-#        ys_mask = ys_in_pad != 0
-#        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
-#        return ys_mask.unsqueeze(-2) & m
-
-    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor,
-                vad_indexes: torch.Tensor) -> Tuple[torch.Tensor, None]:
-        """Compute loss value from buffer sequences.
-
-        Args:
-            input (torch.Tensor): Input ids. (batch, len)
-            hidden (torch.Tensor): Target ids. (batch, len)
-
-        """
-        x = self.embed(input)
-        # mask = self._target_mask(input)
-        h, _, _ = self.encoder(x, text_lengths, vad_indexes)
-        y = self.decoder(h)
-        return y, None
-
-    def with_vad(self):
-        return True
-
-    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
-        """Score new token.
-
-        Args:
-            y (torch.Tensor): 1D torch.int64 prefix tokens.
-            state: Scorer state for prefix tokens
-            x (torch.Tensor): encoder feature that generates ys.
-
-        Returns:
-            tuple[torch.Tensor, Any]: Tuple of
-                torch.float32 scores for next token (vocab_size)
-                and next state for ys
-
-        """
-        y = y.unsqueeze(0)
-        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
-        h = self.decoder(h[:, -1])
-        logp = h.log_softmax(dim=-1).squeeze(0)
-        return logp, cache
-
-    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
-        """Score new token batch.
-
-        Args:
-            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
-            states (List[Any]): Scorer states for prefix tokens.
-            xs (torch.Tensor):
-                The encoder feature that generates ys (n_batch, xlen, n_feat).
-
-        Returns:
-            tuple[torch.Tensor, List[Any]]: Tuple of
-                batchified scores for next token with shape of `(n_batch, vocab_size)`
-                and next state list for ys.
-
-        """
-        # merge states
-        n_batch = len(ys)
-        n_layers = len(self.encoder.encoders)
-        if states[0] is None:
-            batch_state = None
-        else:
-            # transpose state of [batch, layer] into [layer, batch]
-            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
-
-        # batch decoding
-        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
-        h = self.decoder(h[:, -1])
-        logp = h.log_softmax(dim=-1)
-
-        # transpose state of [layer, batch] into [batch, layer]
-        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
-        return logp, state_list
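
Note: score()/batch_score() above share one cache-layout convention: per-hypothesis
states indexed [batch][layer] are stacked into per-layer batches for a single
forward_one_step call, then split back out. A minimal, self-contained sketch with
made-up shapes:

    import torch

    n_batch, n_layers = 3, 2
    states = [[torch.randn(4, 8) for _ in range(n_layers)]
              for _ in range(n_batch)]                     # [batch][layer]
    batch_state = [torch.stack([states[b][i] for b in range(n_batch)])
                   for i in range(n_layers)]               # [layer]: (n_batch, 4, 8)
    state_list = [[batch_state[i][b] for i in range(n_layers)]
                  for b in range(n_batch)]                 # back to [batch][layer]
    assert torch.equal(state_list[1][0], states[1][0])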
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index a6596a0..ef9d93a 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -125,3 +125,4 @@
     logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
     dst_state.update(src_state)
     obj.load_state_dict(dst_state)
+    
\ No newline at end of file

--
Gitblit v1.9.1