kongdeqiang
5 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/paraformer/decoder.py
@@ -17,7 +17,10 @@
from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
from funasr.models.sanm.attention import (
    MultiHeadedAttentionSANMDecoder,
    MultiHeadedAttentionCrossAtt,
)
class DecoderLayerSANM(torch.nn.Module):
@@ -69,7 +72,7 @@
        if self.concat_after:
            self.concat_linear1 = torch.nn.Linear(size + size, size)
            self.concat_linear2 = torch.nn.Linear(size + size, size)
        self.reserve_attn=False
        self.reserve_attn = False
        self.attn_mat = []
    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -117,6 +120,22 @@
        return x, tgt_mask, memory, memory_mask, cache
    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x
        residual = x
        x = self.norm3(x)
        x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
        return attn_mat
    def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        """Compute decoded features.
@@ -157,10 +176,11 @@
            x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
        return x, tgt_mask, memory, memory_mask, cache
    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
    def forward_chunk(
        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
    ):
        """Compute decoded features.
        Args:
@@ -208,6 +228,7 @@
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """
    def __init__(
        self,
        vocab_size: int,
@@ -287,7 +308,13 @@
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
                    attention_heads,
                    attention_dim,
                    src_attention_dropout_rate,
                    lora_list,
                    lora_rank,
                    lora_alpha,
                    lora_dropout,
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
@@ -335,9 +362,9 @@
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        return_hidden: bool = False,
        return_both: bool= False,
        chunk_mask: torch.Tensor = None,
        return_hidden: bool = False,
        return_both: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.
@@ -358,7 +385,7 @@
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        if chunk_mask is not None:
@@ -367,16 +394,10 @@
                memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
        if self.normalize_before:
            hidden = self.after_norm(x)
@@ -391,11 +412,51 @@
    def score(self, ys, state, x):
        """Score."""
        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
        logp, state = self.forward_one_step(
            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
        )
        ys_mask = myutils.sequence_mask(
            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
        )[:, :, None]
        logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
        return logp.squeeze(0), state
    def forward_asf2(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
        attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    def forward_asf6(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[1](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[2](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[3](tgt, tgt_mask, memory, memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.decoders[4](tgt, tgt_mask, memory, memory_mask)
        attn_mat = self.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    def forward_chunk(
        self,
@@ -438,22 +499,24 @@
        for i in range(self.att_layer_num):
            decoder = self.decoders[i]
            x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
                x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
                chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
                x,
                memory,
                fsmn_cache=fsmn_cache[i],
                opt_cache=opt_cache[i],
                chunk_size=cache["chunk_size"],
                look_back=cache["decoder_chunk_look_back"],
            )
        if self.num_blocks - self.att_layer_num > 1:
            for i in range(self.num_blocks - self.att_layer_num):
                j = i + self.att_layer_num
                decoder = self.decoders2[i]
                x, memory, fsmn_cache[j], _  = decoder.forward_chunk(
                x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
                    x, memory, fsmn_cache=fsmn_cache[j]
                )
        for decoder in self.decoders3:
            x, memory, _, _ = decoder.forward_chunk(
                x, memory
            )
            x, memory, _, _ = decoder.forward_chunk(x, memory)
        if self.normalize_before:
            x = self.after_norm(x)
        if self.output_layer is not None:
@@ -526,6 +589,395 @@
        return y, new_cache
class DecoderLayerSANMExport(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2 if hasattr(model, "norm2") else None
        self.norm3 = model.norm3 if hasattr(model, "norm3") else None
        self.size = model.size
    def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x
        if self.src_attn is not None:
            residual = x
            x = self.norm3(x)
            x = residual + self.src_attn(x, memory, memory_mask)
        return x, tgt_mask, memory, memory_mask, cache
    def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
        residual = tgt
        tgt = self.norm1(tgt)
        tgt = self.feed_forward(tgt)
        x = tgt
        if self.self_attn is not None:
            tgt = self.norm2(tgt)
            x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
            x = residual + x
        residual = x
        x = self.norm3(x)
        x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
        return attn_mat
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
class ParaformerSANMDecoderExport(torch.nn.Module):
    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        from funasr.utils.torch_function import sequence_mask
        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANMExport(d)
        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANMExport(d)
        for i, d in enumerate(self.model.decoders3):
            self.model.decoders3[i] = DecoderLayerSANMExport(d)
        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name
    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        return_hidden: bool = False,
        return_both: bool = False,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(x, tgt_mask, memory, memory_mask)
        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
        hidden = self.after_norm(x)
        # x = self.output_layer(x)
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(hidden)
            return x, ys_in_lens
        if return_both:
            x = self.output_layer(hidden)
            return x, hidden, ys_in_lens
        return hidden, ys_in_lens
    def forward_asf2(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
            tgt, tgt_mask, memory, memory_mask
        )
        attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    def forward_asf6(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        _, memory_mask = self.prepare_mask(memory_mask)
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](
            tgt, tgt_mask, memory, memory_mask
        )
        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](
            tgt, tgt_mask, memory, memory_mask
        )
        attn_mat = self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
        return attn_mat
    """
    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)
    def is_optimizable(self):
        return True
    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(cache_num)]
    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
               + ['out_cache_%d' % i for i in range(cache_num)]
    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret
    """
@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        from funasr.utils.torch_function import sequence_mask
        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
        from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANMExport(d)
        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANMExport(d)
        for i, d in enumerate(self.model.decoders3):
            self.model.decoders3[i] = DecoderLayerSANMExport(d)
        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name
    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        *args,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        x = tgt
        out_caches = list()
        for i, decoder in enumerate(self.model.decoders):
            in_cache = args[i]
            x, tgt_mask, memory, memory_mask, out_cache = decoder(
                x, tgt_mask, memory, memory_mask, cache=in_cache
            )
            out_caches.append(out_cache)
        if self.model.decoders2 is not None:
            for i, decoder in enumerate(self.model.decoders2):
                in_cache = args[i + len(self.model.decoders)]
                x, tgt_mask, memory, memory_mask, out_cache = decoder(
                    x, tgt_mask, memory, memory_mask, cache=in_cache
                )
                out_caches.append(out_cache)
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
        x = self.after_norm(x)
        x = self.output_layer(x)
        return x, out_caches
    def get_dummy_inputs(self, enc_size):
        enc = torch.randn(2, 100, enc_size).type(torch.float32)
        enc_len = torch.tensor([30, 100], dtype=torch.int32)
        acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
        acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
        cache_num = len(self.model.decoders)
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        cache = [
            torch.zeros(
                (2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
                dtype=torch.float32,
            )
            for _ in range(cache_num)
        ]
        return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
    def get_input_names(self):
        cache_num = len(self.model.decoders)
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        return ["enc", "enc_len", "acoustic_embeds", "acoustic_embeds_len"] + [
            "in_cache_%d" % i for i in range(cache_num)
        ]
    def get_output_names(self):
        cache_num = len(self.model.decoders)
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        return ["logits", "sample_ids"] + ["out_cache_%d" % i for i in range(cache_num)]
    def get_dynamic_axes(self):
        ret = {
            "enc": {0: "batch_size", 1: "enc_length"},
            "acoustic_embeds": {0: "batch_size", 1: "token_length"},
            "enc_len": {
                0: "batch_size",
            },
            "acoustic_embeds_len": {
                0: "batch_size",
            },
        }
        cache_num = len(self.model.decoders)
        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
            cache_num += len(self.model.decoders2)
        ret.update(
            {
                "in_cache_%d"
                % d: {
                    0: "batch_size",
                }
                for d in range(cache_num)
            }
        )
        ret.update(
            {
                "out_cache_%d"
                % d: {
                    0: "batch_size",
                }
                for d in range(cache_num)
            }
        )
        return ret
@tables.register("decoder_classes", "ParaformerSANDecoder")
class ParaformerSANDecoder(BaseTransformerDecoder):
    """
@@ -533,23 +985,24 @@
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """
    def __init__(
            self,
            vocab_size: int,
            encoder_output_size: int,
            attention_heads: int = 4,
            linear_units: int = 2048,
            num_blocks: int = 6,
            dropout_rate: float = 0.1,
            positional_dropout_rate: float = 0.1,
            self_attention_dropout_rate: float = 0.0,
            src_attention_dropout_rate: float = 0.0,
            input_layer: str = "embed",
            use_output_layer: bool = True,
            pos_enc_class=PositionalEncoding,
            normalize_before: bool = True,
            concat_after: bool = False,
            embeds_id: int = -1,
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        embeds_id: int = -1,
    ):
        super().__init__(
            vocab_size=vocab_size,
@@ -567,12 +1020,8 @@
            num_blocks,
            lambda lnum: DecoderLayer(
                attention_dim,
                MultiHeadedAttention(
                    attention_heads, attention_dim, self_attention_dropout_rate
                ),
                MultiHeadedAttention(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
@@ -583,11 +1032,11 @@
        self.attention_dim = attention_dim
    def forward(
            self,
            hs_pad: torch.Tensor,
            hlens: torch.Tensor,
            ys_in_pad: torch.Tensor,
            ys_in_lens: torch.Tensor,
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.
@@ -610,23 +1059,17 @@
        tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
        memory = hs_pad
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
            memory.device
        )
        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
        # Padding for Longformer
        if memory_mask.shape[-1] != memory.shape[1]:
            padlen = memory.shape[1] - memory_mask.shape[-1]
            memory_mask = torch.nn.functional.pad(
                memory_mask, (0, padlen), "constant", False
            )
            memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False)
        # x = self.embed(tgt)
        x = tgt
        embeds_outputs = None
        for layer_id, decoder in enumerate(self.decoders):
            x, tgt_mask, memory, memory_mask = decoder(
                x, tgt_mask, memory, memory_mask
            )
            x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, memory_mask)
            if layer_id == self.embeds_id:
                embeds_outputs = x
        if self.normalize_before:
@@ -640,3 +1083,111 @@
        else:
            return x, olens
@tables.register("decoder_classes", "ParaformerDecoderSANExport")
class ParaformerDecoderSANExport(torch.nn.Module):
    def __init__(
        self,
        model,
        max_seq_len=512,
        model_name="decoder",
        onnx: bool = True,
    ):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        from funasr.utils.torch_function import sequence_mask
        self.model = model
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.transformer.decoder import DecoderLayerExport
        from funasr.models.transformer.attention import MultiHeadedAttentionExport
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.src_attn, MultiHeadedAttention):
                d.src_attn = MultiHeadedAttentionExport(d.src_attn)
            self.model.decoders[i] = DecoderLayerExport(d)
        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name
    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        x = tgt
        x, tgt_mask, memory, memory_mask = self.model.decoders(x, tgt_mask, memory, memory_mask)
        x = self.after_norm(x)
        x = self.output_layer(x)
        return x, ys_in_lens
    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros(
                (1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size)
            )
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)
    def is_optimizable(self):
        return True
    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ["tgt", "memory", "pre_acoustic_embeds"] + ["cache_%d" % i for i in range(cache_num)]
    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]
    def get_dynamic_axes(self):
        ret = {
            "tgt": {0: "tgt_batch", 1: "tgt_length"},
            "memory": {0: "memory_batch", 1: "memory_length"},
            "pre_acoustic_embeds": {
                0: "acoustic_embeds_batch",
                1: "acoustic_embeds_length",
            },
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update(
            {
                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
                for d in range(cache_num)
            }
        )
        return ret