kongdeqiang
5 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/paraformer/decoder.py
@@ -1,25 +1,29 @@
from typing import List
from typing import Tuple
import logging
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple
from funasr.register import tables
from funasr.models.scama import utils as myutils
from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.decoder import DecoderLayer
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.register import tables
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.sanm.attention import (
    MultiHeadedAttentionSANMDecoder,
    MultiHeadedAttentionCrossAtt,
)
class DecoderLayerSANM(nn.Module):
class DecoderLayerSANM(torch.nn.Module):
    """Single decoder layer module.
    Args:
@@ -62,13 +66,13 @@
            self.norm2 = LayerNorm(size)
        if src_attn is not None:
            self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)
        self.reserve_attn=False
            self.concat_linear1 = torch.nn.Linear(size + size, size)
            self.concat_linear2 = torch.nn.Linear(size + size, size)
        self.reserve_attn = False
        self.attn_mat = []
    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -116,6 +120,22 @@
        return x, tgt_mask, memory, memory_mask, cache
    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x
        residual = x
        x = self.norm3(x)
        x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
        return attn_mat
    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features.
@@ -156,10 +176,11 @@
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache
    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
    def forward_chunk(
        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
    ):
        """Compute decoded features.
        Args:
@@ -207,6 +228,7 @@
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """
    def __init__(
        self,
        vocab_size: int,
@@ -286,7 +308,13 @@
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
                    attention_heads,
                    attention_dim,
                    src_attention_dropout_rate,
                    lora_list,
                    lora_rank,
                    lora_alpha,
                    lora_dropout,
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
@@ -334,9 +362,9 @@
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        return_hidden: bool = False,
        return_both: bool= False,
        chunk_mask: torch.Tensor = None,
        return_hidden: bool = False,
        return_both: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.
@@ -357,7 +385,7 @@
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
@@ -366,16 +394,10 @@
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            hidden = self.after_norm(x)
@@ -390,11 +412,51 @@
    def score(self, ys, state, x):
        """Score."""
        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        ys_mask = myutils.sequence_mask(
            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
        )[:, :, None]
        logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
        return logp.squeeze(0), state
    def forward_asf2(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
        attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    def forward_asf6(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[1](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[2](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[3](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[4](tgt, tgt_mask, memory, memory_mask)
        attn_mat = self.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    def forward_chunk(
        self,
@@ -437,22 +499,24 @@
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
                x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
                chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
                x,
                memory,
                fsmn_cache=fsmn_cache[i],
                opt_cache=opt_cache[i],
                chunk_size=cache["chunk_size"],
                look_back=cache["decoder_chunk_look_back"],
            )
        if self.num_blocks - self.att_layer_num > 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                x, memory, fsmn_cache[j], _  = decoder.forward_chunk(
                x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
                    x, memory, fsmn_cache=fsmn_cache[j]
                )
        for decoder in self.decoders3:
            x, memory, _, _ = decoder.forward_chunk(
                x, memory
            )
            x, memory, _, _ = decoder.forward_chunk(x, memory)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
@@ -525,6 +589,395 @@
        return y, new_cache
class DecoderLayerSANMExport(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2 if hasattr(model, "norm2") else None
        self.norm3 = model.norm3 if hasattr(model, "norm3") else None
        self.size = model.size
    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x
        if self.src_attn is not None:
            residual = x
            x = self.norm3(x)
            x = residual + self.src_attn(x, memory, memory_mask)
        return x, tgt_mask, memory, memory_mask, cache
    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x
        residual = x
        x = self.norm3(x)
        x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
        return attn_mat
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
class ParaformerSANMDecoderExport(torch.nn.Module):
    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        from funasr.utils.torch_function import sequence_mask
        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANMExport(d)
        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANMExport(d)
        for i, d in enumerate(self.model.decoders3):
            self.model.decoders3[i] = DecoderLayerSANMExport(d)
        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name
    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        return_hidden: bool = False,
        return_both: bool = False,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(x, tgt_mask, memory, memory_mask)
        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
        hidden = self.after_norm(x)
        # x = self.output_layer(x)
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(hidden)
            return x, ys_in_lens
        if return_both:
            x = self.output_layer(hidden)
            return x, hidden, ys_in_lens
        return hidden, ys_in_lens
    def forward_asf2(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
            tgt, tgt_mask, memory, memory_mask
        )
        attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    def forward_asf6(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](
            tgt, tgt_mask, memory, memory_mask
        )
        attn_mat = self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    """
    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)
    def is_optimizable(self):
        return True
    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(cache_num)]
    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
               + ['out_cache_%d' % i for i in range(cache_num)]
    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret
    """
@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        from funasr.utils.torch_function import sequence_mask
        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANMExport(d)
        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANMExport(d)
        for i, d in enumerate(self.model.decoders3):
            self.model.decoders3[i] = DecoderLayerSANMExport(d)
        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name
    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        *args,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        x = tgt
        out_caches = list()
        for i, decoder in enumerate(self.model.decoders):
            in_cache = args[i]
            x, tgt_mask, memory, memory_mask, out_cache = decoder(
                x, tgt_mask, memory, memory_mask, cache=in_cache
            )
            out_caches.append(out_cache)
        if self.model.decoders2 is not None:
            for i, decoder in enumerate(self.model.decoders2):
                in_cache = args[i + len(self.model.decoders)]
                x, tgt_mask, memory, memory_mask, out_cache = decoder(
                    x, tgt_mask, memory, memory_mask, cache=in_cache
                )
                out_caches.append(out_cache)
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
        x = self.after_norm(x)
        x = self.output_layer(x)
        return x, out_caches
    def get_dummy_inputs(self, enc_size):
        enc = torch.randn(2, 100, enc_size).type(torch.float32)
        enc_len = torch.tensor([30, 100], dtype=torch.int32)
        acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
        acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
        cache_num = len(self.model.decoders)
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        cache = [
            torch.zeros(
                (2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
                dtype=torch.float32,
            )
            for _ in range(cache_num)
        ]
        return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
    def get_input_names(self):
        cache_num = len(self.model.decoders)
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        return ["enc", "enc_len", "acoustic_embeds", "acoustic_embeds_len"] + [
            "in_cache_%d" % i for i in range(cache_num)
        ]
    def get_output_names(self):
        cache_num = len(self.model.decoders)
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        return ["logits", "sample_ids"] + ["out_cache_%d" % i for i in range(cache_num)]
    def get_dynamic_axes(self):
        ret = {
            "enc": {0: "batch_size", 1: "enc_length"},
            "acoustic_embeds": {0: "batch_size", 1: "token_length"},
            "enc_len": {
                0: "batch_size",
            },
            "acoustic_embeds_len": {
                0: "batch_size",
            },
        }
        cache_num = len(self.model.decoders)
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        ret.update(
            {
                "in_cache_%d"
                % d: {
                    0: "batch_size",
                }
                for d in range(cache_num)
            }
        )
        ret.update(
            {
                "out_cache_%d"
                % d: {
                    0: "batch_size",
                }
                for d in range(cache_num)
            }
        )
        return ret
@tables.register("decoder_classes", "ParaformerSANDecoder")
class ParaformerSANDecoder(BaseTransformerDecoder):
    """
@@ -532,23 +985,24 @@
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """
    def __init__(
            self,
            vocab_size: int,
            encoder_output_size: int,
            attention_heads: int = 4,
            linear_units: int = 2048,
            num_blocks: int = 6,
            dropout_rate: float = 0.1,
            positional_dropout_rate: float = 0.1,
            self_attention_dropout_rate: float = 0.0,
            src_attention_dropout_rate: float = 0.0,
            input_layer: str = "embed",
            use_output_layer: bool = True,
            pos_enc_class=PositionalEncoding,
            normalize_before: bool = True,
            concat_after: bool = False,
            embeds_id: int = -1,
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        embeds_id: int = -1,
    ):
        super().__init__(
            vocab_size=vocab_size,
@@ -566,12 +1020,8 @@
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
@@ -582,11 +1032,11 @@
        self.attention_dim = attention_dim
    def forward(
            self,
            hs_pad: torch.Tensor,
            hlens: torch.Tensor,
            ys_in_pad: torch.Tensor,
            ys_in_lens: torch.Tensor,
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.
@@ -609,23 +1059,17 @@
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )
            memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False)
        # x = self.embed(tgt)
        x = tgt
        embeds_outputs = None
        for layer_id, decoder in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask
            )
            x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, memory_mask)
            if layer_id == self.embeds_id:
                embeds_outputs = x
        if self.normalize_before:
@@ -639,3 +1083,111 @@
        else:
            return x, olens
@tables.register("decoder_classes", "ParaformerDecoderSANExport")
class ParaformerDecoderSANExport(torch.nn.Module):
    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name="decoder",
        onnx: bool = True,
    ):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        from funasr.utils.torch_function import sequence_mask
        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.transformer.decoder import DecoderLayerExport
        from funasr.models.transformer.attention import MultiHeadedAttentionExport
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.src_attn, MultiHeadedAttention):
                d.src_attn = MultiHeadedAttentionExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerExport(d)
        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name
    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        x = tgt
        x, tgt_mask, memory, memory_mask = self.model.decoders(x, tgt_mask, memory, memory_mask)
        x = self.after_norm(x)
        x = self.output_layer(x)
        return x, ys_in_lens
    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros(
                (1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size)
            )
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)
    def is_optimizable(self):
        return True
    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ["tgt", "memory", "pre_acoustic_embeds"] + ["cache_%d" % i for i in range(cache_num)]
    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]
    def get_dynamic_axes(self):
        ret = {
            "tgt": {0: "tgt_batch", 1: "tgt_length"},
            "memory": {0: "memory_batch", 1: "memory_length"},
            "pre_acoustic_embeds": {
                0: "acoustic_embeds_batch",
                1: "acoustic_embeds_length",
            },
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update(
            {
                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
                for d in range(cache_num)
            }
        )
        return ret