From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/paraformer/decoder.py |  409 ++++++++++++++++++++++++++++------------------------------
 1 files changed, 197 insertions(+), 212 deletions(-)

diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index f08e97b..7edd91a 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -17,7 +17,10 @@
 from funasr.models.transformer.decoder import BaseTransformerDecoder
 from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import (
+    MultiHeadedAttentionSANMDecoder,
+    MultiHeadedAttentionCrossAtt,
+)
 
 
 class DecoderLayerSANM(torch.nn.Module):
@@ -69,7 +72,7 @@
         if self.concat_after:
             self.concat_linear1 = torch.nn.Linear(size + size, size)
             self.concat_linear2 = torch.nn.Linear(size + size, size)
-        self.reserve_attn=False
+        self.reserve_attn = False
         self.attn_mat = []
 
     def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
@@ -116,7 +119,7 @@
             # x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
 
         return x, tgt_mask, memory, memory_mask, cache
-    
+
     def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
         residual = tgt
         tgt = self.norm1(tgt)
@@ -173,10 +176,11 @@
 
             x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
 
-
         return x, tgt_mask, memory, memory_mask, cache
 
-    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+    def forward_chunk(
+        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
+    ):
         """Compute decoded features.
 
         Args:
@@ -224,6 +228,7 @@
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2006.01713
     """
+
     def __init__(
         self,
         vocab_size: int,
@@ -303,7 +308,13 @@
                     attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                 ),
                 MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
+                    attention_heads,
+                    attention_dim,
+                    src_attention_dropout_rate,
+                    lora_list,
+                    lora_rank,
+                    lora_alpha,
+                    lora_dropout,
                 ),
                 PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
@@ -351,9 +362,9 @@
         hlens: torch.Tensor,
         ys_in_pad: torch.Tensor,
         ys_in_lens: torch.Tensor,
-        return_hidden: bool = False,
-        return_both: bool= False,
         chunk_mask: torch.Tensor = None,
+        return_hidden: bool = False,
+        return_both: bool = False,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
 
@@ -374,7 +385,7 @@
         """
         tgt = ys_in_pad
         tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-        
+
         memory = hs_pad
         memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
         if chunk_mask is not None:
@@ -383,16 +394,10 @@
                 memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
 
         x = tgt
-        x, tgt_mask, memory, memory_mask, _ = self.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
         if self.decoders2 is not None:
-            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
-                x, tgt_mask, memory, memory_mask
-            )
-        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
-            x, tgt_mask, memory, memory_mask
-        )
+            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
+        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
         if self.normalize_before:
             hidden = self.after_norm(x)
 
@@ -407,12 +412,12 @@
 
     def score(self, ys, state, x):
         """Score."""
-        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
-        logp, state = self.forward_one_step(
-            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
-        )
+        ys_mask = myutils.sequence_mask(
+            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
+        )[:, :, None]
+        logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
         return logp.squeeze(0), state
-    
+
     def forward_asf2(
         self,
         hs_pad: torch.Tensor,
@@ -430,7 +435,7 @@
         tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
         attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
         return attn_mat
-    
+
     def forward_asf6(
         self,
         hs_pad: torch.Tensor,
@@ -494,22 +499,24 @@
         for i in range(self.att_layer_num):
             decoder = self.decoders[i]
             x, memory, fsmn_cache[i], opt_cache[i] = decoder.forward_chunk(
-                x, memory, fsmn_cache=fsmn_cache[i], opt_cache=opt_cache[i],
-                chunk_size=cache["chunk_size"], look_back=cache["decoder_chunk_look_back"]
+                x,
+                memory,
+                fsmn_cache=fsmn_cache[i],
+                opt_cache=opt_cache[i],
+                chunk_size=cache["chunk_size"],
+                look_back=cache["decoder_chunk_look_back"],
             )
 
         if self.num_blocks - self.att_layer_num > 1:
             for i in range(self.num_blocks - self.att_layer_num):
                 j = i + self.att_layer_num
                 decoder = self.decoders2[i]
-                x, memory, fsmn_cache[j], _  = decoder.forward_chunk(
+                x, memory, fsmn_cache[j], _ = decoder.forward_chunk(
                     x, memory, fsmn_cache=fsmn_cache[j]
                 )
 
         for decoder in self.decoders3:
-            x, memory, _, _ = decoder.forward_chunk(
-                x, memory
-            )
+            x, memory, _, _ = decoder.forward_chunk(x, memory)
         if self.normalize_before:
             x = self.after_norm(x)
         if self.output_layer is not None:
@@ -581,21 +588,18 @@
 
         return y, new_cache
 
+
 class DecoderLayerSANMExport(torch.nn.Module):
 
-    def __init__(
-        self,
-        model
-    ):
+    def __init__(self, model):
         super().__init__()
         self.self_attn = model.self_attn
         self.src_attn = model.src_attn
         self.feed_forward = model.feed_forward
         self.norm1 = model.norm1
-        self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
-        self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
+        self.norm2 = model.norm2 if hasattr(model, "norm2") else None
+        self.norm3 = model.norm3 if hasattr(model, "norm3") else None
         self.size = model.size
-
 
     def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
 
@@ -614,9 +618,8 @@
             x = self.norm3(x)
             x = residual + self.src_attn(x, memory, memory_mask)
 
-
         return x, tgt_mask, memory, memory_mask, cache
-    
+
     def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
         residual = tgt
         tgt = self.norm1(tgt)
@@ -636,45 +639,39 @@
 
 @tables.register("decoder_classes", "ParaformerSANMDecoderExport")
 class ParaformerSANMDecoderExport(torch.nn.Module):
-    def __init__(self, model,
-                 max_seq_len=512,
-                 model_name='decoder',
-                 onnx: bool = True,
-                 **kwargs
-                 ):
+    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
         super().__init__()
         # self.embed = model.embed #Embedding(model.embed, max_seq_len)
 
         from funasr.utils.torch_function import sequence_mask
-        
+
         self.model = model
 
         self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-        
+
         from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
         from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
-        
-        
+
         for i, d in enumerate(self.model.decoders):
             if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                 d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
             if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                 d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
             self.model.decoders[i] = DecoderLayerSANMExport(d)
-        
+
         if self.model.decoders2 is not None:
             for i, d in enumerate(self.model.decoders2):
                 if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                     d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
                 self.model.decoders2[i] = DecoderLayerSANMExport(d)
-        
+
         for i, d in enumerate(self.model.decoders3):
             self.model.decoders3[i] = DecoderLayerSANMExport(d)
-        
+
         self.output_layer = model.output_layer
         self.after_norm = model.after_norm
         self.model_name = model_name
-    
+
     def prepare_mask(self, mask):
         mask_3d_btd = mask[:, :, None]
         if len(mask.shape) == 2:
@@ -682,9 +679,9 @@
         elif len(mask.shape) == 3:
             mask_4d_bhlt = 1 - mask[:, None, :]
         mask_4d_bhlt = mask_4d_bhlt * -10000.0
-        
+
         return mask_3d_btd, mask_4d_bhlt
-    
+
     def forward(
         self,
         hs_pad: torch.Tensor,
@@ -694,31 +691,27 @@
         return_hidden: bool = False,
         return_both: bool = False,
     ):
-        
+
         tgt = ys_in_pad
         tgt_mask = self.make_pad_mask(ys_in_lens)
         tgt_mask, _ = self.prepare_mask(tgt_mask)
         # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-        
+
         memory = hs_pad
         memory_mask = self.make_pad_mask(hlens)
         _, memory_mask = self.prepare_mask(memory_mask)
         # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-        
+
         x = tgt
-        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(x, tgt_mask, memory, memory_mask)
         if self.model.decoders2 is not None:
             x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                 x, tgt_mask, memory, memory_mask
             )
-        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
-            x, tgt_mask, memory, memory_mask
-        )
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
         hidden = self.after_norm(x)
         # x = self.output_layer(x)
-        
+
         if self.output_layer is not None and return_hidden is False:
             x = self.output_layer(hidden)
             return x, ys_in_lens
@@ -726,7 +719,7 @@
             x = self.output_layer(hidden)
             return x, hidden, ys_in_lens
         return hidden, ys_in_lens
-    
+
     def forward_asf2(
         self,
         hs_pad: torch.Tensor,
@@ -742,10 +735,12 @@
         memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
         _, memory_mask = self.prepare_mask(memory_mask)
 
-        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
+            tgt, tgt_mask, memory, memory_mask
+        )
         attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
         return attn_mat
-    
+
     def forward_asf6(
         self,
         hs_pad: torch.Tensor,
@@ -761,15 +756,25 @@
         memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
         _, memory_mask = self.prepare_mask(memory_mask)
 
-        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
-        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](tgt, tgt_mask, memory, memory_mask)
-        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](tgt, tgt_mask, memory, memory_mask)
-        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](tgt, tgt_mask, memory, memory_mask)
-        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](tgt, tgt_mask, memory, memory_mask)
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](
+            tgt, tgt_mask, memory, memory_mask
+        )
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](
+            tgt, tgt_mask, memory, memory_mask
+        )
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](
+            tgt, tgt_mask, memory, memory_mask
+        )
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](
+            tgt, tgt_mask, memory, memory_mask
+        )
+        tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](
+            tgt, tgt_mask, memory, memory_mask
+        )
         attn_mat = self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
         return attn_mat
-    
-    '''
+
+    """
     def get_dummy_inputs(self, enc_size):
         tgt = torch.LongTensor([0]).unsqueeze(0)
         memory = torch.randn(1, 100, enc_size)
@@ -818,14 +823,12 @@
             for d in range(cache_num)
         })
         return ret
-    '''
-    
+    """
+
+
 @tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
 class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
-    def __init__(self, model,
-                 max_seq_len=512,
-                 model_name='decoder',
-                 onnx: bool = True, **kwargs):
+    def __init__(self, model, max_seq_len=512, model_name="decoder", onnx: bool = True, **kwargs):
         super().__init__()
         # self.embed = model.embed #Embedding(model.embed, max_seq_len)
         self.model = model
@@ -833,7 +836,7 @@
         from funasr.utils.torch_function import sequence_mask
 
         self.model = model
-        
+
         self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
 
         from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
@@ -854,11 +857,11 @@
 
         for i, d in enumerate(self.model.decoders3):
             self.model.decoders3[i] = DecoderLayerSANMExport(d)
-        
+
         self.output_layer = model.output_layer
         self.after_norm = model.after_norm
         self.model_name = model_name
-    
+
     def prepare_mask(self, mask):
         mask_3d_btd = mask[:, :, None]
         if len(mask.shape) == 2:
@@ -866,9 +869,9 @@
         elif len(mask.shape) == 3:
             mask_4d_bhlt = 1 - mask[:, None, :]
         mask_4d_bhlt = mask_4d_bhlt * -10000.0
-        
+
         return mask_3d_btd, mask_4d_bhlt
-    
+
     def forward(
         self,
         hs_pad: torch.Tensor,
@@ -877,17 +880,17 @@
         ys_in_lens: torch.Tensor,
         *args,
     ):
-        
+
         tgt = ys_in_pad
         tgt_mask = self.make_pad_mask(ys_in_lens)
         tgt_mask, _ = self.prepare_mask(tgt_mask)
         # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-        
+
         memory = hs_pad
         memory_mask = self.make_pad_mask(hlens)
         _, memory_mask = self.prepare_mask(memory_mask)
         # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-        
+
         x = tgt
         out_caches = list()
         for i, decoder in enumerate(self.model.decoders):
@@ -903,76 +906,75 @@
                     x, tgt_mask, memory, memory_mask, cache=in_cache
                 )
                 out_caches.append(out_cache)
-        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
-            x, tgt_mask, memory, memory_mask
-        )
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(x, tgt_mask, memory, memory_mask)
         x = self.after_norm(x)
         x = self.output_layer(x)
-        
+
         return x, out_caches
-    
+
     def get_dummy_inputs(self, enc_size):
         enc = torch.randn(2, 100, enc_size).type(torch.float32)
         enc_len = torch.tensor([30, 100], dtype=torch.int32)
         acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
         acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
         cache_num = len(self.model.decoders)
-        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
             cache_num += len(self.model.decoders2)
         cache = [
-            torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
-                        dtype=torch.float32)
+            torch.zeros(
+                (2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
+                dtype=torch.float32,
+            )
             for _ in range(cache_num)
         ]
         return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
-    
+
     def get_input_names(self):
         cache_num = len(self.model.decoders)
-        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
             cache_num += len(self.model.decoders2)
-        return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
-               + ['in_cache_%d' % i for i in range(cache_num)]
-    
+        return ["enc", "enc_len", "acoustic_embeds", "acoustic_embeds_len"] + [
+            "in_cache_%d" % i for i in range(cache_num)
+        ]
+
     def get_output_names(self):
         cache_num = len(self.model.decoders)
-        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
             cache_num += len(self.model.decoders2)
-        return ['logits', 'sample_ids'] \
-               + ['out_cache_%d' % i for i in range(cache_num)]
-    
+        return ["logits", "sample_ids"] + ["out_cache_%d" % i for i in range(cache_num)]
+
     def get_dynamic_axes(self):
         ret = {
-            'enc': {
-                0: 'batch_size',
-                1: 'enc_length'
+            "enc": {0: "batch_size", 1: "enc_length"},
+            "acoustic_embeds": {0: "batch_size", 1: "token_length"},
+            "enc_len": {
+                0: "batch_size",
             },
-            'acoustic_embeds': {
-                0: 'batch_size',
-                1: 'token_length'
+            "acoustic_embeds_len": {
+                0: "batch_size",
             },
-            'enc_len': {
-                0: 'batch_size',
-            },
-            'acoustic_embeds_len': {
-                0: 'batch_size',
-            },
-            
         }
         cache_num = len(self.model.decoders)
-        if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+        if hasattr(self.model, "decoders2") and self.model.decoders2 is not None:
             cache_num += len(self.model.decoders2)
-        ret.update({
-            'in_cache_%d' % d: {
-                0: 'batch_size',
+        ret.update(
+            {
+                "in_cache_%d"
+                % d: {
+                    0: "batch_size",
+                }
+                for d in range(cache_num)
             }
-            for d in range(cache_num)
-        })
-        ret.update({
-            'out_cache_%d' % d: {
-                0: 'batch_size',
+        )
+        ret.update(
+            {
+                "out_cache_%d"
+                % d: {
+                    0: "batch_size",
+                }
+                for d in range(cache_num)
             }
-            for d in range(cache_num)
-        })
+        )
         return ret
 
 
@@ -983,23 +985,24 @@
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2006.01713
     """
+
     def __init__(
-            self,
-            vocab_size: int,
-            encoder_output_size: int,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            self_attention_dropout_rate: float = 0.0,
-            src_attention_dropout_rate: float = 0.0,
-            input_layer: str = "embed",
-            use_output_layer: bool = True,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            embeds_id: int = -1,
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        embeds_id: int = -1,
     ):
         super().__init__(
             vocab_size=vocab_size,
@@ -1017,12 +1020,8 @@
             num_blocks,
             lambda lnum: DecoderLayer(
                 attention_dim,
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, self_attention_dropout_rate
-                ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
+                MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
+                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                 PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
                 normalize_before,
@@ -1033,11 +1032,11 @@
         self.attention_dim = attention_dim
 
     def forward(
-            self,
-            hs_pad: torch.Tensor,
-            hlens: torch.Tensor,
-            ys_in_pad: torch.Tensor,
-            ys_in_lens: torch.Tensor,
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
 
@@ -1060,23 +1059,17 @@
         tgt_mask = (~make_pad_mask(ys_in_lens)[:, None, :]).to(tgt.device)
 
         memory = hs_pad
-        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
-            memory.device
-        )
+        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
         # Padding for Longformer
         if memory_mask.shape[-1] != memory.shape[1]:
             padlen = memory.shape[1] - memory_mask.shape[-1]
-            memory_mask = torch.nn.functional.pad(
-                memory_mask, (0, padlen), "constant", False
-            )
+            memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False)
 
         # x = self.embed(tgt)
         x = tgt
         embeds_outputs = None
         for layer_id, decoder in enumerate(self.decoders):
-            x, tgt_mask, memory, memory_mask = decoder(
-                x, tgt_mask, memory, memory_mask
-            )
+            x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, memory_mask)
             if layer_id == self.embeds_id:
                 embeds_outputs = x
         if self.normalize_before:
@@ -1090,16 +1083,19 @@
         else:
             return x, olens
 
+
 @tables.register("decoder_classes", "ParaformerDecoderSANExport")
 class ParaformerDecoderSANExport(torch.nn.Module):
-    def __init__(self, model,
-                 max_seq_len=512,
-                 model_name='decoder',
-                 onnx: bool = True, ):
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        model_name="decoder",
+        onnx: bool = True,
+    ):
         super().__init__()
         # self.embed = model.embed #Embedding(model.embed, max_seq_len)
         self.model = model
-
 
         from funasr.utils.torch_function import sequence_mask
 
@@ -1107,19 +1103,18 @@
 
         self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
 
-
         from funasr.models.transformer.decoder import DecoderLayerExport
         from funasr.models.transformer.attention import MultiHeadedAttentionExport
-        
+
         for i, d in enumerate(self.model.decoders):
             if isinstance(d.src_attn, MultiHeadedAttention):
                 d.src_attn = MultiHeadedAttentionExport(d.src_attn)
             self.model.decoders[i] = DecoderLayerExport(d)
-        
+
         self.output_layer = model.output_layer
         self.after_norm = model.after_norm
         self.model_name = model_name
-    
+
     def prepare_mask(self, mask):
         mask_3d_btd = mask[:, :, None]
         if len(mask.shape) == 2:
@@ -1127,9 +1122,9 @@
         elif len(mask.shape) == 3:
             mask_4d_bhlt = 1 - mask[:, None, :]
         mask_4d_bhlt = mask_4d_bhlt * -10000.0
-        
+
         return mask_3d_btd, mask_4d_bhlt
-    
+
     def forward(
         self,
         hs_pad: torch.Tensor,
@@ -1137,72 +1132,62 @@
         ys_in_pad: torch.Tensor,
         ys_in_lens: torch.Tensor,
     ):
-        
+
         tgt = ys_in_pad
         tgt_mask = self.make_pad_mask(ys_in_lens)
         tgt_mask, _ = self.prepare_mask(tgt_mask)
         # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-        
+
         memory = hs_pad
         memory_mask = self.make_pad_mask(hlens)
         _, memory_mask = self.prepare_mask(memory_mask)
         # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
-        
+
         x = tgt
-        x, tgt_mask, memory, memory_mask = self.model.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
+        x, tgt_mask, memory, memory_mask = self.model.decoders(x, tgt_mask, memory, memory_mask)
         x = self.after_norm(x)
         x = self.output_layer(x)
-        
+
         return x, ys_in_lens
-    
+
     def get_dummy_inputs(self, enc_size):
         tgt = torch.LongTensor([0]).unsqueeze(0)
         memory = torch.randn(1, 100, enc_size)
         pre_acoustic_embeds = torch.randn(1, 1, enc_size)
         cache_num = len(self.model.decoders) + len(self.model.decoders2)
         cache = [
-            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
+            torch.zeros(
+                (1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size)
+            )
             for _ in range(cache_num)
         ]
         return (tgt, memory, pre_acoustic_embeds, cache)
-    
+
     def is_optimizable(self):
         return True
-    
+
     def get_input_names(self):
         cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
-               + ['cache_%d' % i for i in range(cache_num)]
-    
+        return ["tgt", "memory", "pre_acoustic_embeds"] + ["cache_%d" % i for i in range(cache_num)]
+
     def get_output_names(self):
         cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        return ['y'] \
-               + ['out_cache_%d' % i for i in range(cache_num)]
-    
+        return ["y"] + ["out_cache_%d" % i for i in range(cache_num)]
+
     def get_dynamic_axes(self):
         ret = {
-            'tgt': {
-                0: 'tgt_batch',
-                1: 'tgt_length'
+            "tgt": {0: "tgt_batch", 1: "tgt_length"},
+            "memory": {0: "memory_batch", 1: "memory_length"},
+            "pre_acoustic_embeds": {
+                0: "acoustic_embeds_batch",
+                1: "acoustic_embeds_length",
             },
-            'memory': {
-                0: 'memory_batch',
-                1: 'memory_length'
-            },
-            'pre_acoustic_embeds': {
-                0: 'acoustic_embeds_batch',
-                1: 'acoustic_embeds_length',
-            }
         }
         cache_num = len(self.model.decoders) + len(self.model.decoders2)
-        ret.update({
-            'cache_%d' % d: {
-                0: 'cache_%d_batch' % d,
-                2: 'cache_%d_length' % d
+        ret.update(
+            {
+                "cache_%d" % d: {0: "cache_%d_batch" % d, 2: "cache_%d_length" % d}
+                for d in range(cache_num)
             }
-            for d in range(cache_num)
-        })
+        )
         return ret
-    
\ No newline at end of file

--
Gitblit v1.9.1