游雁
2024-05-17 d3ff05837bbc14749d09f44947633b87e8f2db0e
funasr/models/sense_voice/decoder.py
@@ -15,6 +15,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from funasr.models.transformer.utils.mask import subsequent_mask
class LayerNorm(nn.LayerNorm):
@@ -31,275 +32,575 @@
        )
def sense_voice_decode_forward(
   self,
   x: torch.Tensor,
   xa: torch.Tensor,
   kv_cache: Optional[dict] = None,
   **kwargs,
    self,
    x: torch.Tensor,
    xa: torch.Tensor,
    kv_cache: Optional[dict] = None,
    **kwargs,
):
   """Forward decoder.
    """Forward decoder.
   Args:
      hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
      hlens: (batch)
      ys_in_pad:
         input token ids, int64 (batch, maxlen_out)
         if input_layer == "embed"
         input tensor (batch, maxlen_out, #mels) in the other cases
      ys_in_lens: (batch)
   Returns:
      (tuple): tuple containing:
    Args:
            hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                    input token ids, int64 (batch, maxlen_out)
                    if input_layer == "embed"
                    input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
    Returns:
            (tuple): tuple containing:
      x: decoded token score before softmax (batch, maxlen_out, token)
         if use_output_layer is True,
      olens: (batch, )
   """
   # import pdb;pdb.set_trace()
   use_padmask = self.use_padmask
   hlens = kwargs.get("hlens", None)
   ys_in_lens = kwargs.get("ys_in_lens", None)
   offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
   tgt, memory = x, xa
   tgt[tgt == -1] = 0
   tgt = (
      self.token_embedding(tgt)
      + self.positional_embedding[offset: offset + tgt.size(1)]
   )
   # tgt = self.dropout(tgt)
   x = tgt.to(memory.dtype)
   if use_padmask and hlens is not None:
      memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
   else:
      memory_mask = None
   for layer, block in enumerate(self.blocks):
      x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
   x = self.ln(x)
   x = (
      x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
   ).float()
   return x
            x: decoded token score before softmax (batch, maxlen_out, token)
                    if use_output_layer is True,
            olens: (batch, )
    """
    # import pdb;pdb.set_trace()
    use_padmask = self.use_padmask
    hlens = kwargs.get("hlens", None)
    ys_in_lens = kwargs.get("ys_in_lens", None)
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    tgt, memory = x, xa
    tgt[tgt == -1] = 0
    tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
    # tgt = self.dropout(tgt)
    x = tgt.to(memory.dtype)
    if use_padmask and hlens is not None:
        memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
    else:
        memory_mask = None
    for layer, block in enumerate(self.blocks):
        x = block(
            x,
            memory,
            mask=self.mask,
            memory_mask=memory_mask,
            is_pad_mask=False,
            is_pad_memory_mask=True,
        )
    x = self.ln(x)
    x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
    return x
class MultiHeadAttention(nn.Module):
   def __init__(self, n_state: int, n_head: int):
      super().__init__()
      self.n_head = n_head
      self.query = Linear(n_state, n_state)
      self.key = Linear(n_state, n_state, bias=False)
      self.value = Linear(n_state, n_state)
      self.out = Linear(n_state, n_state)
   def forward(
      self,
      x: Tensor,
      xa: Optional[Tensor] = None,
      mask: Optional[Tensor] = None,
      kv_cache: Optional[dict] = None,
      **kwargs,
   ):
      is_pad_mask = kwargs.get("is_pad_mask", False)
      q = self.query(x)
      if kv_cache is None or xa is None or self.key not in kv_cache:
         # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
         # otherwise, perform key/value projections for self- or cross-attention as usual.
         k = self.key(x if xa is None else xa)
         v = self.value(x if xa is None else xa)
      else:
         # for cross-attention, calculate keys and values once and reuse in subsequent calls.
         k = kv_cache[self.key]
         v = kv_cache[self.value]
      wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
      return self.out(wv), qk
   def qkv_attention(
      self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
   ):
      is_pad_mask = kwargs.get("is_pad_mask", False)
      n_batch, n_ctx, n_state = q.shape
      scale = (n_state // self.n_head) ** -0.25
      q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
      k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
      v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
      qk = q @ k
      if mask is not None:
         if not is_pad_mask:
            qk = qk + mask[:n_ctx, :n_ctx]
         else:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = float(
               np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
            )
            qk = qk.masked_fill(mask, min_value)
      qk = qk.float()
      w = F.softmax(qk, dim=-1).to(q.dtype)
      if mask is not None and is_pad_mask:
         w = w.masked_fill(mask, 0.0)
      return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)
    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        q = self.query(x)
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
        return self.out(wv), qk
    def qkv_attention(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        qk = q @ k
        if mask is not None:
            if not is_pad_mask:
                qk = qk + mask[:n_ctx, :n_ctx]
            else:
                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
                min_value = -float(
                    "inf"
                )  # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                qk = qk.masked_fill(mask, min_value)
        qk = qk.float()
        w = F.softmax(qk, dim=-1).to(q.dtype)
        if mask is not None and is_pad_mask:
            w = w.masked_fill(mask, 0.0)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
from omegaconf import OmegaConf
class ResidualAttentionBlockRWKV(nn.Module):
   def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, layer_id=0, **kwargs):
      super().__init__()
      rwkv_cfg = kwargs.get("rwkv_cfg", {})
      args = OmegaConf.create(rwkv_cfg)
      self.attn = RWKVLayer(args=args, layer_id=layer_id)
      if args.get("datatype", "bf16") == "bf16":
         self.attn.to(torch.bfloat16)
      self.ln0 = None
      if layer_id == 0 and not args.get("ln0", True):
         self.ln0 = LayerNorm(args.n_embd)
         if args.get("init_rwkv", True):
            print("init_rwkv")
            layer_id = 0
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            nn.init.constant_(self.ln0.weight, scale)
      self.layer_id = layer_id
      self.args = args
      self.ln1 = None
      if not args.get("ln1", True):
         self.ln1 = LayerNorm(args.n_embd)
         # init
         if args.get("init_rwkv", True):
            print("init_rwkv")
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            nn.init.constant_(self.ln1.weight, scale)
      self.cross_attn = (
         MultiHeadAttention(n_state, n_head) if cross_attention else None
      )
      self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
      n_mlp = n_state * 4
      self.mlp = nn.Sequential(
         Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
      )
      self.mlp_ln = LayerNorm(n_state)
   def forward(
      self,
      x: Tensor,
      xa: Optional[Tensor] = None,
      mask: Optional[Tensor] = None,
      kv_cache: Optional[dict] = None,
      **kwargs,
   ):
      is_pad_mask = kwargs.get("is_pad_mask", False)
      is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
      if self.layer_id == 0 and self.ln0 is not None:
         x = self.ln0(x)
      if self.ln1 is None:
         x = x + self.attn(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
      else:
         x = x + self.attn(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
      if self.cross_attn:
         x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
      x = x + self.mlp(self.mlp_ln(x))
      return x
    def __init__(
        self, n_state: int, n_head: int, cross_attention: bool = False, layer_id=0, **kwargs
    ):
        super().__init__()
        rwkv_cfg = kwargs.get("rwkv_cfg", {})
        args = OmegaConf.create(rwkv_cfg)
        if args.get("version", "v4") == "v4":
            from funasr.models.sense_voice.rwkv_v4 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
        elif args.get("version", "v5") == "v5":
            from funasr.models.sense_voice.rwkv_v5 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
        else:
            from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
        # self.att = RWKVLayer(args=args, layer_id=layer_id)
        self.att = RWKV_Tmix(args, layer_id=layer_id)
        if args.get("init_rwkv", True):
            print("init_rwkv")
            nn.init.orthogonal_(self.att.receptance.weight, gain=1)
            nn.init.orthogonal_(self.att.key.weight, gain=0.1)
            nn.init.orthogonal_(self.att.value.weight, gain=1)
            nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
            nn.init.zeros_(self.att.output.weight)
        self.ln0 = None
        if layer_id == 0 and not args.get("ln0", True):
            self.ln0 = LayerNorm(args.n_embd)
            if args.get("init_rwkv", True):
                print("init_rwkv")
                layer_id = 0
                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                nn.init.constant_(self.ln0.weight, scale)
        self.layer_id = layer_id
        self.args = args
        self.ln1 = None
        if not args.get("ln1", True):
            self.ln1 = LayerNorm(args.n_embd)
            # init
            if args.get("init_rwkv", True):
                print("init_rwkv")
                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                nn.init.constant_(self.ln1.weight, scale)
        if args.get("datatype", "bf16") == "bf16":
            self.att.to(torch.bfloat16)
            # if self.ln1 is not None:
            #     self.ln1.to(torch.bfloat16)
        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)
    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)
        if self.args.get("datatype", "bf16") == "bf16":
            x = x.bfloat16()
        if self.ln1 is None:
            x = x + self.att(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        else:
            x = x + self.att(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        if self.args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)
        if self.cross_attn:
            x = (
                x
                + self.cross_attn(
                    self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
                )[0]
            )
        x = x + self.mlp(self.mlp_ln(x))
        return x
@tables.register("decoder_classes", "SenseVoiceDecoder")
class SenseVoiceDecoder(nn.Module):
   def __init__(
      self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs
   ):
      super().__init__()
      self.token_embedding = nn.Embedding(n_vocab, n_state)
      self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        self.blocks = nn.ModuleList(
            [
                ResidualAttentionBlockRWKV(
                    n_state, n_head, cross_attention=True, layer_id=i, **kwargs
                )
                for i in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)
        self.use_padmask = kwargs.get("use_padmask", True)
    def forward(
        self,
        x: torch.Tensor,
        xa: torch.Tensor,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        """Forward decoder.
        Args:
                hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
                hlens: (batch)
                ys_in_pad:
                        input token ids, int64 (batch, maxlen_out)
                        if input_layer == "embed"
                        input tensor (batch, maxlen_out, #mels) in the other cases
                ys_in_lens: (batch)
        Returns:
                (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                        if use_output_layer is True,
                olens: (batch, )
        """
        # import pdb;pdb.set_trace()
        use_padmask = self.use_padmask
        hlens = kwargs.get("hlens", None)
        ys_in_lens = kwargs.get("ys_in_lens", None)
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        tgt, memory = x, xa
        tgt[tgt == -1] = 0
        tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
        # tgt = self.dropout(tgt)
        x = tgt.to(memory.dtype)
        if use_padmask and hlens is not None:
            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
        else:
            memory_mask = None
        for layer, block in enumerate(self.blocks):
            x = block(
                x,
                memory,
                mask=self.mask,
                memory_mask=memory_mask,
                is_pad_mask=False,
                is_pad_memory_mask=True,
            )
        x = self.ln(x)
        x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
        return x
    def init_state(self, x):
        state = {}
        return state
    def final_score(self, state) -> float:
        """Score eos (optional).
        Args:
            state: Scorer state for prefix tokens
        Returns:
            float: final score
        """
        return 0.0
    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
        return logp.squeeze(0)[-1, :], state
      self.blocks = nn.ModuleList(
         [
            ResidualAttentionBlockRWKV(n_state, n_head, cross_attention=True, layer_id=i, **kwargs)
            for i in range(n_layer)
         ]
      )
      self.ln = LayerNorm(n_state)
      mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
      self.register_buffer("mask", mask, persistent=False)
      self.use_padmask = kwargs.get("use_padmask", True)
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """Multi-Head Attention layer.
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """
    def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.fsmn_block = nn.Conv1d(
            n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
        )
        # padding
        # padding
        left_padding = (kernel_size - 1) // 2
        if sanm_shfit > 0:
            left_padding = left_padding + sanm_shfit
        right_padding = kernel_size - 1 - left_padding
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size
    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None, **kwargs):
        """
        :param x: (#batch, time1, size).
        :param mask: Mask tensor (#batch, 1, time)
        :return:
        """
        # print("in fsmn, inputs", inputs.size())
        b, t, d = inputs.size()
        # logging.info(
        #     "mask: {}".format(mask.size()))
        if mask is not None:
            mask = torch.reshape(mask, (b, -1, 1))
            # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
            if mask_shfit_chunk is not None:
                # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
                mask = mask * mask_shfit_chunk
            # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
            # print("in fsmn, mask", mask.size())
            # print("in fsmn, inputs", inputs.size())
            inputs = inputs * mask
        x = inputs.transpose(1, 2)
        b, d, t = x.size()
        if cache is None:
            # print("in fsmn, cache is None, x", x.size())
            x = self.pad_fn(x)
            if not self.training:
                cache = x
        else:
            # print("in fsmn, cache is not None, x", x.size())
            # x = torch.cat((x, cache), dim=2)[:, :, :-1]
            # if t < self.kernel_size:
            #     x = self.pad_fn(x)
            x = torch.cat((cache[:, :, 1:], x), dim=2)
            x = x[:, :, -(self.kernel_size + t - 1) :]
            # print("in fsmn, cache is not None, x_cat", x.size())
            cache = x
        x = self.fsmn_block(x)
        x = x.transpose(1, 2)
        # print("in fsmn, fsmn_out", x.size())
        if x.size(1) != inputs.size(1):
            inputs = inputs[:, -1, :]
        x = x + inputs
        x = self.dropout(x)
        if mask is not None:
            x = x * mask
        return x, cache
   def forward(
      self,
      x: torch.Tensor,
      xa: torch.Tensor,
      kv_cache: Optional[dict] = None,
      **kwargs,
   ):
      """Forward decoder.
      Args:
         hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
         hlens: (batch)
         ys_in_pad:
            input token ids, int64 (batch, maxlen_out)
            if input_layer == "embed"
            input tensor (batch, maxlen_out, #mels) in the other cases
         ys_in_lens: (batch)
      Returns:
         (tuple): tuple containing:
         x: decoded token score before softmax (batch, maxlen_out, token)
            if use_output_layer is True,
         olens: (batch, )
      """
      # import pdb;pdb.set_trace()
      use_padmask = self.use_padmask
      hlens = kwargs.get("hlens", None)
      ys_in_lens = kwargs.get("ys_in_lens", None)
      offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
      tgt, memory = x, xa
      tgt[tgt == -1] = 0
      tgt = (
         self.token_embedding(tgt)
         + self.positional_embedding[offset: offset + tgt.size(1)]
      )
      # tgt = self.dropout(tgt)
      x = tgt.to(memory.dtype)
      if use_padmask and hlens is not None:
         memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
      else:
         memory_mask = None
      for layer, block in enumerate(self.blocks):
         x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
      x = self.ln(x)
      x = (
         x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
      ).float()
      return x
class ResidualAttentionBlockFSMN(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, **kwargs):
        super().__init__()
        self.attn = MultiHeadedAttentionSANMDecoder(
            n_state,
            kwargs.get("self_attention_dropout_rate"),
            kwargs.get("kernel_size", 20),
            kwargs.get("sanm_shfit", 10),
        )
        self.attn_ln = LayerNorm(n_state)
        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)
    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        cache = kwargs.get("cache", {})
        layer = kwargs.get("layer", 0)
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None
        # if fsmn_cache is not None:
        #     x = x[:, -1:]
        att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
        # if len(cache)>1:
        #     cache[layer]["fsmn_cache"] = fsmn_cache
        #     x = x[:, -1:]
        x = x + att_res
        if self.cross_attn:
            x = (
                x
                + self.cross_attn(
                    self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
                )[0]
            )
        x = x + self.mlp(self.mlp_ln(x))
        return x
@tables.register("decoder_classes", "SenseVoiceDecoderFSMN")
class SenseVoiceDecoderFSMN(nn.Module):
    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        self.blocks = nn.ModuleList(
            [
                ResidualAttentionBlockFSMN(
                    n_state, n_head, cross_attention=True, layer_id=i, **kwargs
                )
                for i in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)
        self.use_padmask = kwargs.get("use_padmask", True)
    def forward(
        self,
        x: torch.Tensor,
        xa: torch.Tensor,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        """Forward decoder.
        Args:
                hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
                hlens: (batch)
                ys_in_pad:
                        input token ids, int64 (batch, maxlen_out)
                        if input_layer == "embed"
                        input tensor (batch, maxlen_out, #mels) in the other cases
                ys_in_lens: (batch)
        Returns:
                (tuple): tuple containing:
                x: decoded token score before softmax (batch, maxlen_out, token)
                        if use_output_layer is True,
                olens: (batch, )
        """
        # import pdb;pdb.set_trace()
        use_padmask = self.use_padmask
        hlens = kwargs.get("hlens", None)
        ys_in_lens = kwargs.get("ys_in_lens", None)
        tgt, memory = x, xa
        tgt[tgt == -1] = 0
        tgt = self.token_embedding(tgt) + self.positional_embedding[: tgt.size(1)]
        # tgt = self.dropout(tgt)
        x = tgt.to(memory.dtype)
        if use_padmask and hlens is not None:
            memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
        else:
            memory_mask = None
        for layer, block in enumerate(self.blocks):
            x = block(
                x,
                memory,
                mask=self.mask,
                memory_mask=memory_mask,
                is_pad_mask=False,
                is_pad_memory_mask=True,
                cache=kwargs.get("cache", None),
                layer=layer,
            )
        x = self.ln(x)
        x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
        return x
    def init_state(self, x):
        state = {}
        for layer, block in enumerate(self.blocks):
            state[layer] = {
                "fsmn_cache": None,
                "memory_key": None,
                "memory_value": None,
            }
        return state
    def final_score(self, state) -> float:
        """Score eos (optional).
        Args:
            state: Scorer state for prefix tokens
        Returns:
            float: final score
        """
        return 0.0
    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=None)
        logp = torch.log_softmax(logp, dim=-1)
        return logp.squeeze(0)[-1, :], state