From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat self_attention_encoder.py (wrap long lines, restyle conditional expression, PEP 8 operator spacing)

---
 funasr/models/sond/encoder/self_attention_encoder.py |   57 +++++++++++++++++++++++++++++++--------------------------
 1 file changed, 31 insertions(+), 26 deletions(-)

diff --git a/funasr/models/sond/encoder/self_attention_encoder.py b/funasr/models/sond/encoder/self_attention_encoder.py
index f3c4736..2e979b1 100644
--- a/funasr/models/sond/encoder/self_attention_encoder.py
+++ b/funasr/models/sond/encoder/self_attention_encoder.py
@@ -87,7 +87,9 @@
             x = self.norm1(x)
 
         if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            x_concat = torch.cat(
+                (x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1
+            )
             if self.in_size == self.size:
                 x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
             else:
@@ -207,32 +209,36 @@
 
         self.encoders = repeat(
             num_blocks,
-            lambda lnum: EncoderLayer(
-                output_size,
-                output_size,
-                MultiHeadSelfAttention(
-                    attention_heads,
+            lambda lnum: (
+                EncoderLayer(
                     output_size,
                     output_size,
-                    attention_dropout_rate,
-                ),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ) if lnum > 0 else EncoderLayer(
-                input_size,
-                output_size,
-                MultiHeadSelfAttention(
-                    attention_heads,
-                    input_size if input_layer == "pe" or input_layer == "null" else output_size,
+                    MultiHeadSelfAttention(
+                        attention_heads,
+                        output_size,
+                        output_size,
+                        attention_dropout_rate,
+                    ),
+                    positionwise_layer(*positionwise_layer_args),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                )
+                if lnum > 0
+                else EncoderLayer(
+                    input_size,
                     output_size,
-                    attention_dropout_rate,
-                ),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
+                    MultiHeadSelfAttention(
+                        attention_heads,
+                        input_size if input_layer == "pe" or input_layer == "null" else output_size,
+                        output_size,
+                        attention_dropout_rate,
+                    ),
+                    positionwise_layer(*positionwise_layer_args),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                )
             ),
         )
         if self.normalize_before:
@@ -270,7 +276,7 @@
             position embedded tensor and mask
         """
         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad = xs_pad * self.output_size()**0.5
+        xs_pad = xs_pad * self.output_size() ** 0.5
         if self.embed is None:
             xs_pad = xs_pad
         elif (
@@ -325,4 +331,3 @@
         if len(intermediate_outs) > 0:
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
-

--
Gitblit v1.9.1