From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/transformer/decoder.py |  318 +++++++++++++++++++++++++++--------------------------
 1 file changed, 162 insertions(+), 156 deletions(-)

diff --git a/funasr/models/transformer/decoder.py b/funasr/models/transformer/decoder.py
index 820de4a..7d849f3 100644
--- a/funasr/models/transformer/decoder.py
+++ b/funasr/models/transformer/decoder.py
@@ -28,6 +28,7 @@
 
 from funasr.register import tables
 
+
 class DecoderLayer(nn.Module):
     """Single decoder layer module.
 
@@ -51,14 +52,14 @@
     """
 
     def __init__(
-            self,
-            size,
-            self_attn,
-            src_attn,
-            feed_forward,
-            dropout_rate,
-            normalize_before=True,
-            concat_after=False,
+        self,
+        size,
+        self_attn,
+        src_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
     ):
         """Construct an DecoderLayer object."""
         super(DecoderLayer, self).__init__()
@@ -115,9 +116,7 @@
                 tgt_q_mask = tgt_mask[:, -1:, :]
 
         if self.concat_after:
-            tgt_concat = torch.cat(
-                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
-            )
+            tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
             x = residual + self.concat_linear1(tgt_concat)
         else:
             x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
@@ -128,9 +127,7 @@
         if self.normalize_before:
             x = self.norm2(x)
         if self.concat_after:
-            x_concat = torch.cat(
-                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
-            )
+            x_concat = torch.cat((x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
             x = residual + self.concat_linear2(x_concat)
         else:
             x = residual + self.dropout(self.src_attn(x, memory, memory, memory_mask))
@@ -146,6 +143,35 @@
 
         if cache is not None:
             x = torch.cat([cache, x], dim=1)
+
+        return x, tgt_mask, memory, memory_mask
+
+
+class DecoderLayerExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.self_attn = model.self_attn
+        self.src_attn = model.src_attn
+        self.feed_forward = model.feed_forward
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2
+        self.norm3 = model.norm3
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+        residual = tgt
+        tgt = self.norm1(tgt)
+        tgt_q = tgt
+        tgt_q_mask = tgt_mask
+        x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
+
+        residual = x
+        x = self.norm2(x)
+
+        x = residual + self.src_attn(x, memory, memory, memory_mask)
+
+        residual = x
+        x = self.norm3(x)
+        x = residual + self.feed_forward(x)
 
         return x, tgt_mask, memory, memory_mask
 
@@ -173,15 +199,15 @@
     """
 
     def __init__(
-            self,
-            vocab_size: int,
-            encoder_output_size: int,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            input_layer: str = "embed",
-            use_output_layer: bool = True,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
     ):
         super().__init__()
         attention_dim = encoder_output_size
@@ -214,11 +240,11 @@
         self.decoders = None
 
     def forward(
-            self,
-            hs_pad: torch.Tensor,
-            hlens: torch.Tensor,
-            ys_in_pad: torch.Tensor,
-            ys_in_lens: torch.Tensor,
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Forward decoder.
 
@@ -246,20 +272,14 @@
         tgt_mask = tgt_mask & m
 
         memory = hs_pad
-        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(
-            memory.device
-        )
+        memory_mask = (~make_pad_mask(hlens, maxlen=memory.size(1)))[:, None, :].to(memory.device)
         # Padding for Longformer
         if memory_mask.shape[-1] != memory.shape[1]:
             padlen = memory.shape[1] - memory_mask.shape[-1]
-            memory_mask = torch.nn.functional.pad(
-                memory_mask, (0, padlen), "constant", False
-            )
+            memory_mask = torch.nn.functional.pad(memory_mask, (0, padlen), "constant", False)
 
         x = self.embed(tgt)
-        x, tgt_mask, memory, memory_mask = self.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
+        x, tgt_mask, memory, memory_mask = self.decoders(x, tgt_mask, memory, memory_mask)
         if self.normalize_before:
             x = self.after_norm(x)
         if self.output_layer is not None:
@@ -269,11 +289,11 @@
         return x, olens
 
     def forward_one_step(
-            self,
-            tgt: torch.Tensor,
-            tgt_mask: torch.Tensor,
-            memory: torch.Tensor,
-            cache: List[torch.Tensor] = None,
+        self,
+        tgt: torch.Tensor,
+        tgt_mask: torch.Tensor,
+        memory: torch.Tensor,
+        cache: List[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
         """Forward one step.
 
@@ -293,9 +313,7 @@
             cache = [None] * len(self.decoders)
         new_cache = []
         for c, decoder in zip(cache, self.decoders):
-            x, tgt_mask, memory, memory_mask = decoder(
-                x, tgt_mask, memory, None, cache=c
-            )
+            x, tgt_mask, memory, memory_mask = decoder(x, tgt_mask, memory, None, cache=c)
             new_cache.append(x)
 
         if self.normalize_before:
@@ -310,13 +328,11 @@
     def score(self, ys, state, x):
         """Score."""
         ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
-        logp, state = self.forward_one_step(
-            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
-        )
+        logp, state = self.forward_one_step(ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state)
         return logp.squeeze(0), state
 
     def batch_score(
-            self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
+        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
     ) -> Tuple[torch.Tensor, List[Any]]:
         """Score new token batch.
 
@@ -340,8 +356,7 @@
         else:
             # transpose state of [batch, layer] into [layer, batch]
             batch_state = [
-                torch.stack([states[b][i] for b in range(n_batch)])
-                for i in range(n_layers)
+                torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)
             ]
 
         # batch decoding
@@ -352,24 +367,25 @@
         state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
         return logp, state_list
 
+
 @tables.register("decoder_classes", "TransformerDecoder")
 class TransformerDecoder(BaseTransformerDecoder):
     def __init__(
-            self,
-            vocab_size: int,
-            encoder_output_size: int,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            self_attention_dropout_rate: float = 0.0,
-            src_attention_dropout_rate: float = 0.0,
-            input_layer: str = "embed",
-            use_output_layer: bool = True,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
     ):
         super().__init__(
             vocab_size=vocab_size,
@@ -387,12 +403,8 @@
             num_blocks,
             lambda lnum: DecoderLayer(
                 attention_dim,
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, self_attention_dropout_rate
-                ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
+                MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate),
+                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                 PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
                 normalize_before,
@@ -404,24 +416,24 @@
 @tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder")
 class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
     def __init__(
-            self,
-            vocab_size: int,
-            encoder_output_size: int,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            self_attention_dropout_rate: float = 0.0,
-            src_attention_dropout_rate: float = 0.0,
-            input_layer: str = "embed",
-            use_output_layer: bool = True,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            conv_wshare: int = 4,
-            conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
-            conv_usebias: int = False,
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        conv_wshare: int = 4,
+        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+        conv_usebias: int = False,
     ):
         if len(conv_kernel_length) != num_blocks:
             raise ValueError(
@@ -452,9 +464,7 @@
                     use_kernel_mask=True,
                     use_bias=conv_usebias,
                 ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
+                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                 PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
                 normalize_before,
@@ -462,27 +472,28 @@
             ),
         )
 
+
 @tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder")
 class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
     def __init__(
-            self,
-            vocab_size: int,
-            encoder_output_size: int,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            self_attention_dropout_rate: float = 0.0,
-            src_attention_dropout_rate: float = 0.0,
-            input_layer: str = "embed",
-            use_output_layer: bool = True,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            conv_wshare: int = 4,
-            conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
-            conv_usebias: int = False,
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        conv_wshare: int = 4,
+        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+        conv_usebias: int = False,
     ):
         if len(conv_kernel_length) != num_blocks:
             raise ValueError(
@@ -513,9 +524,7 @@
                     use_kernel_mask=True,
                     use_bias=conv_usebias,
                 ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
+                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                 PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
                 normalize_before,
@@ -527,24 +536,24 @@
 @tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder")
 class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
     def __init__(
-            self,
-            vocab_size: int,
-            encoder_output_size: int,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            self_attention_dropout_rate: float = 0.0,
-            src_attention_dropout_rate: float = 0.0,
-            input_layer: str = "embed",
-            use_output_layer: bool = True,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            conv_wshare: int = 4,
-            conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
-            conv_usebias: int = False,
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        conv_wshare: int = 4,
+        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+        conv_usebias: int = False,
     ):
         if len(conv_kernel_length) != num_blocks:
             raise ValueError(
@@ -575,9 +584,7 @@
                     use_kernel_mask=True,
                     use_bias=conv_usebias,
                 ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
+                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                 PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
                 normalize_before,
@@ -585,27 +592,28 @@
             ),
         )
 
+
 @tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder")
 class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
     def __init__(
-            self,
-            vocab_size: int,
-            encoder_output_size: int,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            self_attention_dropout_rate: float = 0.0,
-            src_attention_dropout_rate: float = 0.0,
-            input_layer: str = "embed",
-            use_output_layer: bool = True,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            conv_wshare: int = 4,
-            conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
-            conv_usebias: int = False,
+        self,
+        vocab_size: int,
+        encoder_output_size: int,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
+        input_layer: str = "embed",
+        use_output_layer: bool = True,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        conv_wshare: int = 4,
+        conv_kernel_length: Sequence[int] = (11, 11, 11, 11, 11, 11),
+        conv_usebias: int = False,
     ):
         if len(conv_kernel_length) != num_blocks:
             raise ValueError(
@@ -636,9 +644,7 @@
                     use_kernel_mask=True,
                     use_bias=conv_usebias,
                 ),
-                MultiHeadedAttention(
-                    attention_heads, attention_dim, src_attention_dropout_rate
-                ),
+                MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
                 PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
                 normalize_before,

--
Gitblit v1.9.1