From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat branchformer/encoder.py: join wrapped expressions onto single lines
---
funasr/models/branchformer/encoder.py | 50 ++++++++++++++------------------------------------
1 files changed, 14 insertions(+), 36 deletions(-)
diff --git a/funasr/models/branchformer/encoder.py b/funasr/models/branchformer/encoder.py
index 4b5b237..a15a3f2 100644
--- a/funasr/models/branchformer/encoder.py
+++ b/funasr/models/branchformer/encoder.py
@@ -45,6 +45,7 @@
from funasr.register import tables
+
class BranchformerEncoderLayer(torch.nn.Module):
"""Branchformer encoder layer module.
@@ -73,9 +74,7 @@
stochastic_depth_rate: float = 0.0,
):
super().__init__()
- assert (attn is not None) or (
- cgmlp is not None
- ), "At least one branch should be valid"
+ assert (attn is not None) or (cgmlp is not None), "At least one branch should be valid"
self.size = size
self.attn = attn
@@ -111,9 +110,7 @@
self.merge_proj = torch.nn.Linear(size, size)
elif merge_method == "fixed_ave":
- assert (
- 0.0 <= cgmlp_weight <= 1.0
- ), "cgmlp weight should be between 0.0 and 1.0"
+ assert 0.0 <= cgmlp_weight <= 1.0, "cgmlp weight should be between 0.0 and 1.0"
# remove the other branch if only one branch is used
if cgmlp_weight == 0.0:
@@ -223,14 +220,10 @@
) # (batch, 1, time)
if mask is not None:
min_value = float(
- numpy.finfo(
- torch.tensor(0, dtype=score1.dtype).numpy().dtype
- ).min
+ numpy.finfo(torch.tensor(0, dtype=score1.dtype).numpy().dtype).min
)
score1 = score1.masked_fill(mask.eq(0), min_value)
- score1 = torch.softmax(score1, dim=-1).masked_fill(
- mask.eq(0), 0.0
- )
+ score1 = torch.softmax(score1, dim=-1).masked_fill(mask.eq(0), 0.0)
else:
score1 = torch.softmax(score1, dim=-1)
pooled1 = torch.matmul(score1, x1).squeeze(1) # (batch, size)
@@ -242,14 +235,10 @@
) # (batch, 1, time)
if mask is not None:
min_value = float(
- numpy.finfo(
- torch.tensor(0, dtype=score2.dtype).numpy().dtype
- ).min
+ numpy.finfo(torch.tensor(0, dtype=score2.dtype).numpy().dtype).min
)
score2 = score2.masked_fill(mask.eq(0), min_value)
- score2 = torch.softmax(score2, dim=-1).masked_fill(
- mask.eq(0), 0.0
- )
+ score2 = torch.softmax(score2, dim=-1).masked_fill(mask.eq(0), 0.0)
else:
score2 = torch.softmax(score2, dim=-1)
pooled2 = torch.matmul(score2, x2).squeeze(1) # (batch, size)
@@ -259,19 +248,13 @@
merge_weights = torch.softmax(
torch.cat([weight1, weight2], dim=-1), dim=-1
) # (batch, 2)
- merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
- -1
- ) # (batch, 2, 1, 1)
+ merge_weights = merge_weights.unsqueeze(-1).unsqueeze(-1) # (batch, 2, 1, 1)
w1, w2 = merge_weights[:, 0], merge_weights[:, 1] # (batch, 1, 1)
- x = x + stoch_layer_coeff * self.dropout(
- self.merge_proj(w1 * x1 + w2 * x2)
- )
+ x = x + stoch_layer_coeff * self.dropout(self.merge_proj(w1 * x1 + w2 * x2))
elif self.merge_method == "fixed_ave":
x = x + stoch_layer_coeff * self.dropout(
- self.merge_proj(
- (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
- )
+ self.merge_proj((1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2)
)
else:
raise RuntimeError(f"unknown merge method: {self.merge_method}")
@@ -290,6 +273,7 @@
return (x, pos_emb), mask
return x, mask
+
@tables.register("encoder_classes", "BranchformerEncoder")
class BranchformerEncoder(nn.Module):
@@ -345,9 +329,7 @@
elif pos_enc_layer_type == "legacy_rel_pos":
assert attention_layer_type == "legacy_rel_selfattn"
pos_enc_class = LegacyRelPositionalEncoding
- logging.warning(
- "Using legacy_rel_pos and it will be deprecated in the future."
- )
+ logging.warning("Using legacy_rel_pos and it will be deprecated in the future.")
else:
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
@@ -419,9 +401,7 @@
output_size,
attention_dropout_rate,
)
- logging.warning(
- "Using legacy_rel_selfattn and it will be deprecated in the future."
- )
+ logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.")
elif attention_layer_type == "rel_selfattn":
assert pos_enc_layer_type == "rel_pos"
encoder_selfattn_layer = RelPositionMultiHeadedAttention
@@ -480,9 +460,7 @@
num_blocks,
lambda lnum: BranchformerEncoderLayer(
output_size,
- encoder_selfattn_layer(*encoder_selfattn_layer_args)
- if use_attn
- else None,
+ encoder_selfattn_layer(*encoder_selfattn_layer_args) if use_attn else None,
cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
dropout_rate,
merge_method,
--
Gitblit v1.9.1