From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/lcbnet/encoder.py |  120 +++++++++++++++++++++++++++++++----------------------------
 1 file changed, 63 insertions(+), 57 deletions(-)

diff --git a/funasr/models/lcbnet/encoder.py b/funasr/models/lcbnet/encoder.py
index c65823c..f5f2497 100644
--- a/funasr/models/lcbnet/encoder.py
+++ b/funasr/models/lcbnet/encoder.py
@@ -21,6 +21,7 @@
 from funasr.models.transformer.utils.repeat import repeat
 from funasr.register import tables
 
+
 class EncoderLayer(nn.Module):
     """Encoder layer module.
 
@@ -44,14 +45,14 @@
     """
 
     def __init__(
-            self,
-            size,
-            self_attn,
-            feed_forward,
-            dropout_rate,
-            normalize_before=True,
-            concat_after=False,
-            stochastic_depth_rate=0.0,
+        self,
+        size,
+        self_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        stochastic_depth_rate=0.0,
     ):
         """Construct an EncoderLayer object."""
         super(EncoderLayer, self).__init__()
@@ -109,9 +110,7 @@
             x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
             x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
         else:
-            x = residual + stoch_layer_coeff * self.dropout(
-                self.self_attn(x_q, x, x, mask)
-            )
+            x = residual + stoch_layer_coeff * self.dropout(self.self_attn(x_q, x, x, mask))
         if not self.normalize_before:
             x = self.norm1(x)
 
@@ -126,6 +125,7 @@
             x = torch.cat([cache, x], dim=1)
 
         return x, mask
+
 
 @tables.register("encoder_classes", "TransformerTextEncoder")
 class TransformerTextEncoder(nn.Module):
@@ -154,18 +154,18 @@
     """
 
     def __init__(
-            self,
-            input_size: int,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.0,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
     ):
         super().__init__()
         self._output_size = output_size
@@ -187,9 +187,7 @@
             num_blocks,
             lambda lnum: EncoderLayer(
                 output_size,
-                MultiHeadedAttention(
-                    attention_heads, output_size, attention_dropout_rate
-                ),
+                MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate),
                 positionwise_layer(*positionwise_layer_args),
                 dropout_rate,
                 normalize_before,
@@ -203,9 +201,9 @@
         return self._output_size
 
     def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """Embed positions in tensor.
 
@@ -225,8 +223,6 @@
 
         olens = masks.squeeze(1).sum(1)
         return xs_pad, olens, None
-
-
 
 
 @tables.register("encoder_classes", "FusionSANEncoder")
@@ -251,25 +247,32 @@
 
 
     """
+
     def __init__(
-            self,
-            size,
-            attention_heads,
-            attention_dim,
-            linear_units,
-            self_attention_dropout_rate,
-            src_attention_dropout_rate,
-            positional_dropout_rate,
-            dropout_rate,
-            normalize_before=True,
-            concat_after=False,
+        self,
+        size,
+        attention_heads,
+        attention_dim,
+        linear_units,
+        self_attention_dropout_rate,
+        src_attention_dropout_rate,
+        positional_dropout_rate,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
     ):
         """Construct an SelfSrcAttention object."""
         super(SelfSrcAttention, self).__init__()
         self.size = size
-        self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
-        self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
-        self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
+        self.self_attn = MultiHeadedAttention(
+            attention_heads, attention_dim, self_attention_dropout_rate
+        )
+        self.src_attn = MultiHeadedAttentionReturnWeight(
+            attention_heads, attention_dim, src_attention_dropout_rate
+        )
+        self.feed_forward = PositionwiseFeedForward(
+            attention_dim, linear_units, positional_dropout_rate
+        )
         self.norm1 = LayerNorm(size)
         self.norm2 = LayerNorm(size)
         self.norm3 = LayerNorm(size)
@@ -319,9 +322,7 @@
                 tgt_q_mask = tgt_mask[:, -1:, :]
 
         if self.concat_after:
-            tgt_concat = torch.cat(
-                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
-            )
+            tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
             x = residual + self.concat_linear1(tgt_concat)
         else:
             x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
@@ -332,9 +333,7 @@
         if self.normalize_before:
             x = self.norm2(x)
         if self.concat_after:
-            x_concat = torch.cat(
-                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
-            )
+            x_concat = torch.cat((x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
             x = residual + self.concat_linear2(x_concat)
         else:
             x, score = self.src_attn(x, memory, memory, memory_mask)
@@ -357,7 +356,15 @@
 
 @tables.register("encoder_classes", "ConvBiasPredictor")
 class ConvPredictor(nn.Module):
-    def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048):
+    def __init__(
+        self,
+        size=256,
+        l_order=3,
+        r_order=3,
+        attention_heads=4,
+        attention_dropout_rate=0.1,
+        linear_units=2048,
+    ):
         super().__init__()
         self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
         self.norm1 = LayerNorm(size)
@@ -367,17 +374,16 @@
         self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size)
         self.output_linear = nn.Linear(size, 1)
 
-
     def forward(self, text_enc, asr_enc):
         # stage1 cross-attention
         residual = text_enc
         text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None)
-        
+
         # stage2 FFN
         residual = text_enc
         text_enc = self.norm1(text_enc)
         text_enc = residual + self.feed_forward(text_enc)
-        
+
         # stage Conv predictor
         text_enc = self.norm2(text_enc)
         context = text_enc.transpose(1, 2)
@@ -387,6 +393,6 @@
         output = output.transpose(1, 2)
         output = torch.relu(output)
         output = self.output_linear(output)
-        if output.dim()==3:
-          output = output.squeeze(2)
+        if output.dim() == 3:
+            output = output.squeeze(2)
         return output

--
Gitblit v1.9.1