From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/branchformer/encoder.py |   50 ++++++++++++++------------------------------------
 1 file changed, 14 insertions(+), 36 deletions(-)

diff --git a/funasr/models/branchformer/encoder.py b/funasr/models/branchformer/encoder.py
index 4b5b237..a15a3f2 100644
--- a/funasr/models/branchformer/encoder.py
+++ b/funasr/models/branchformer/encoder.py
@@ -45,6 +45,7 @@
 
 from funasr.register import tables
 
+
 class BranchformerEncoderLayer(torch.nn.Module):
     """Branchformer encoder layer module.
 
@@ -73,9 +74,7 @@
         stochastic_depth_rate: float = 0.0,
     ):
         super().__init__()
-        assert (attn is not None) or (
-            cgmlp is not None
-        ), "At least one branch should be valid"
+        assert (attn is not None) or (cgmlp is not None), "At least one branch should be valid"
 
         self.size = size
         self.attn = attn
@@ -111,9 +110,7 @@
                 self.merge_proj = torch.nn.Linear(size, size)
 
             elif merge_method == "fixed_ave":
-                assert (
-                    0.0 <= cgmlp_weight <= 1.0
-                ), "cgmlp weight should be between 0.0 and 1.0"
+                assert 0.0 <= cgmlp_weight <= 1.0, "cgmlp weight should be between 0.0 and 1.0"
 
                 # remove the other branch if only one branch is used
                 if cgmlp_weight == 0.0:
@@ -223,14 +220,10 @@
                     )  # (batch, 1, time)
                     if mask is not None:
                         min_value = float(
-                            numpy.finfo(
-                                torch.tensor(0, dtype=score1.dtype).numpy().dtype
-                            ).min
+                            numpy.finfo(torch.tensor(0, dtype=score1.dtype).numpy().dtype).min
                         )
                         score1 = score1.masked_fill(mask.eq(0), min_value)
-                        score1 = torch.softmax(score1, dim=-1).masked_fill(
-                            mask.eq(0), 0.0
-                        )
+                        score1 = torch.softmax(score1, dim=-1).masked_fill(mask.eq(0), 0.0)
                     else:
                         score1 = torch.softmax(score1, dim=-1)
                     pooled1 = torch.matmul(score1, x1).squeeze(1)  # (batch, size)
@@ -242,14 +235,10 @@
                     )  # (batch, 1, time)
                     if mask is not None:
                         min_value = float(
-                            numpy.finfo(
-                                torch.tensor(0, dtype=score2.dtype).numpy().dtype
-                            ).min
+                            numpy.finfo(torch.tensor(0, dtype=score2.dtype).numpy().dtype).min
                         )
                         score2 = score2.masked_fill(mask.eq(0), min_value)
-                        score2 = torch.softmax(score2, dim=-1).masked_fill(
-                            mask.eq(0), 0.0
-                        )
+                        score2 = torch.softmax(score2, dim=-1).masked_fill(mask.eq(0), 0.0)
                     else:
                         score2 = torch.softmax(score2, dim=-1)
                     pooled2 = torch.matmul(score2, x2).squeeze(1)  # (batch, size)
@@ -259,19 +248,13 @@
                     merge_weights = torch.softmax(
                         torch.cat([weight1, weight2], dim=-1), dim=-1
                     )  # (batch, 2)
-                    merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
-                        -1
-                    )  # (batch, 2, 1, 1)
+                    merge_weights = merge_weights.unsqueeze(-1).unsqueeze(-1)  # (batch, 2, 1, 1)
                     w1, w2 = merge_weights[:, 0], merge_weights[:, 1]  # (batch, 1, 1)
 
-                x = x + stoch_layer_coeff * self.dropout(
-                    self.merge_proj(w1 * x1 + w2 * x2)
-                )
+                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(w1 * x1 + w2 * x2))
             elif self.merge_method == "fixed_ave":
                 x = x + stoch_layer_coeff * self.dropout(
-                    self.merge_proj(
-                        (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
-                    )
+                    self.merge_proj((1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2)
                 )
             else:
                 raise RuntimeError(f"unknown merge method: {self.merge_method}")
@@ -290,6 +273,7 @@
             return (x, pos_emb), mask
 
         return x, mask
+
 
 @tables.register("encoder_classes", "BranchformerEncoder")
 class BranchformerEncoder(nn.Module):
@@ -345,9 +329,7 @@
         elif pos_enc_layer_type == "legacy_rel_pos":
             assert attention_layer_type == "legacy_rel_selfattn"
             pos_enc_class = LegacyRelPositionalEncoding
-            logging.warning(
-                "Using legacy_rel_pos and it will be deprecated in the future."
-            )
+            logging.warning("Using legacy_rel_pos and it will be deprecated in the future.")
         else:
             raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
 
@@ -419,9 +401,7 @@
                 output_size,
                 attention_dropout_rate,
             )
-            logging.warning(
-                "Using legacy_rel_selfattn and it will be deprecated in the future."
-            )
+            logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.")
         elif attention_layer_type == "rel_selfattn":
             assert pos_enc_layer_type == "rel_pos"
             encoder_selfattn_layer = RelPositionMultiHeadedAttention
@@ -480,9 +460,7 @@
             num_blocks,
             lambda lnum: BranchformerEncoderLayer(
                 output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args)
-                if use_attn
-                else None,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args) if use_attn else None,
                 cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
                 dropout_rate,
                 merge_method,

--
Gitblit v1.9.1