kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/branchformer/encoder.py
@@ -45,6 +45,7 @@
+from funasr.register import tables
 class BranchformerEncoderLayer(torch.nn.Module):
     """Branchformer encoder layer module.
@@ -73,9 +74,7 @@
         stochastic_depth_rate: float = 0.0,
     ):
         super().__init__()
-        assert (attn is not None) or (
-            cgmlp is not None
-        ), "At least one branch should be valid"
+        assert (attn is not None) or (cgmlp is not None), "At least one branch should be valid"
         self.size = size
         self.attn = attn
@@ -111,9 +110,7 @@
                 self.merge_proj = torch.nn.Linear(size, size)
             elif merge_method == "fixed_ave":
-                assert (
-                    0.0 <= cgmlp_weight <= 1.0
-                ), "cgmlp weight should be between 0.0 and 1.0"
+                assert 0.0 <= cgmlp_weight <= 1.0, "cgmlp weight should be between 0.0 and 1.0"
                 # remove the other branch if only one branch is used
                 if cgmlp_weight == 0.0:
@@ -223,14 +220,10 @@
                     )  # (batch, 1, time)
                     if mask is not None:
                         min_value = float(
-                            numpy.finfo(
-                                torch.tensor(0, dtype=score1.dtype).numpy().dtype
-                            ).min
+                            numpy.finfo(torch.tensor(0, dtype=score1.dtype).numpy().dtype).min
                         )
                         score1 = score1.masked_fill(mask.eq(0), min_value)
-                        score1 = torch.softmax(score1, dim=-1).masked_fill(
-                            mask.eq(0), 0.0
-                        )
+                        score1 = torch.softmax(score1, dim=-1).masked_fill(mask.eq(0), 0.0)
                     else:
                         score1 = torch.softmax(score1, dim=-1)
                     pooled1 = torch.matmul(score1, x1).squeeze(1)  # (batch, size)
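This hunk is the masked attention-pooling used by the learned_ave merge: padded frames are pushed to the dtype's minimum before the softmax and then zeroed exactly afterwards, so they contribute nothing to the pooled vector. A minimal standalone sketch of the pattern (the names x, mask, and pooling_proj, and the 1/sqrt(size) scaling, are illustrative stand-ins for the layer's attributes):

```python
import numpy
import torch

batch, time, size = 2, 5, 8
x = torch.randn(batch, time, size)                   # one branch's output
mask = torch.ones(batch, 1, time, dtype=torch.long)  # (batch, 1, time), 0 = padding
mask[1, 0, 3:] = 0                                   # pad out the tail of sample 1
pooling_proj = torch.nn.Linear(size, 1)

score = pooling_proj(x).transpose(1, 2) / size**0.5  # (batch, 1, time)
# push padded positions to the smallest representable value so the softmax
# gives them ~0 weight, then zero them exactly after normalization
min_value = float(numpy.finfo(torch.tensor(0, dtype=score.dtype).numpy().dtype).min)
score = score.masked_fill(mask.eq(0), min_value)
score = torch.softmax(score, dim=-1).masked_fill(mask.eq(0), 0.0)
pooled = torch.matmul(score, x).squeeze(1)           # (batch, size)
```

The second masked_fill guarantees padded frames carry exactly zero weight even after normalization, rather than the numerically tiny weight the softmax alone would leave.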
@@ -242,14 +235,10 @@
                     )  # (batch, 1, time)
                     if mask is not None:
                         min_value = float(
-                            numpy.finfo(
-                                torch.tensor(0, dtype=score2.dtype).numpy().dtype
-                            ).min
+                            numpy.finfo(torch.tensor(0, dtype=score2.dtype).numpy().dtype).min
                         )
                         score2 = score2.masked_fill(mask.eq(0), min_value)
-                        score2 = torch.softmax(score2, dim=-1).masked_fill(
-                            mask.eq(0), 0.0
-                        )
+                        score2 = torch.softmax(score2, dim=-1).masked_fill(mask.eq(0), 0.0)
                     else:
                         score2 = torch.softmax(score2, dim=-1)
                     pooled2 = torch.matmul(score2, x2).squeeze(1)  # (batch, size)
@@ -259,19 +248,13 @@
                     merge_weights = torch.softmax(
                         torch.cat([weight1, weight2], dim=-1), dim=-1
                     )  # (batch, 2)
-                    merge_weights = merge_weights.unsqueeze(-1).unsqueeze(
-                        -1
-                    )  # (batch, 2, 1, 1)
+                    merge_weights = merge_weights.unsqueeze(-1).unsqueeze(-1)  # (batch, 2, 1, 1)
                     w1, w2 = merge_weights[:, 0], merge_weights[:, 1]  # (batch, 1, 1)
-                x = x + stoch_layer_coeff * self.dropout(
-                    self.merge_proj(w1 * x1 + w2 * x2)
-                )
+                x = x + stoch_layer_coeff * self.dropout(self.merge_proj(w1 * x1 + w2 * x2))
             elif self.merge_method == "fixed_ave":
                 x = x + stoch_layer_coeff * self.dropout(
-                    self.merge_proj(
-                        (1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2
-                    )
+                    self.merge_proj((1.0 - self.cgmlp_weight) * x1 + self.cgmlp_weight * x2)
                 )
             else:
                 raise RuntimeError(f"unknown merge method: {self.merge_method}")
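The two merge rules collapsed here differ only in where the branch weights come from: learned_ave normalizes a pair of per-sample scores across the two branches, while fixed_ave uses the constant cgmlp_weight checked in the constructor. A minimal sketch, with merge_proj and the weights standing in for the layer's attributes (weight1/weight2 would come from the pooled scores above):

```python
import torch

batch, time, size = 2, 5, 8
x1 = torch.randn(batch, time, size)   # attention branch
x2 = torch.randn(batch, time, size)   # cgMLP branch
merge_proj = torch.nn.Linear(size, size)

# learned_ave: per-sample weights normalized across the two branches
weight1, weight2 = torch.randn(batch, 1), torch.randn(batch, 1)
merge_weights = torch.softmax(torch.cat([weight1, weight2], dim=-1), dim=-1)  # (batch, 2)
merge_weights = merge_weights.unsqueeze(-1).unsqueeze(-1)  # (batch, 2, 1, 1)
w1, w2 = merge_weights[:, 0], merge_weights[:, 1]          # (batch, 1, 1), broadcast over time/size
merged = merge_proj(w1 * x1 + w2 * x2)                     # (batch, time, size)

# fixed_ave: a constant interpolation weight in [0.0, 1.0]
cgmlp_weight = 0.5
merged_fixed = merge_proj((1.0 - cgmlp_weight) * x1 + cgmlp_weight * x2)
```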
@@ -290,6 +273,7 @@
             return (x, pos_emb), mask
         return x, mask
+@tables.register("encoder_classes", "BranchformerEncoder")
 class BranchformerEncoder(nn.Module):
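The new decorator registers the class in funasr's tables registry under the encoder_classes group, which is how configs can refer to encoders by name. A sketch of the lookup side, assuming the usual tables.encoder_classes.get accessor and illustrative constructor arguments:

```python
from funasr.register import tables

# resolve the class by its registered name rather than a direct import
encoder_class = tables.encoder_classes.get("BranchformerEncoder")
encoder = encoder_class(input_size=80, output_size=256)
```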
@@ -345,9 +329,7 @@
         elif pos_enc_layer_type == "legacy_rel_pos":
             assert attention_layer_type == "legacy_rel_selfattn"
             pos_enc_class = LegacyRelPositionalEncoding
-            logging.warning(
-                "Using legacy_rel_pos and it will be deprecated in the future."
-            )
+            logging.warning("Using legacy_rel_pos and it will be deprecated in the future.")
         else:
             raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
@@ -419,9 +401,7 @@
                 output_size,
                 attention_dropout_rate,
             )
-            logging.warning(
-                "Using legacy_rel_selfattn and it will be deprecated in the future."
-            )
+            logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.")
         elif attention_layer_type == "rel_selfattn":
             assert pos_enc_layer_type == "rel_pos"
             encoder_selfattn_layer = RelPositionMultiHeadedAttention
@@ -480,9 +460,7 @@
             num_blocks,
             lambda lnum: BranchformerEncoderLayer(
                 output_size,
-                encoder_selfattn_layer(*encoder_selfattn_layer_args)
-                if use_attn
-                else None,
+                encoder_selfattn_layer(*encoder_selfattn_layer_args) if use_attn else None,
                 cgmlp_layer(*cgmlp_layer_args) if use_cgmlp else None,
                 dropout_rate,
                 merge_method,
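These conditionals are what make single-branch stacks possible: with use_attn=False every layer receives attn=None, which the "At least one branch should be valid" assertion at the top of BranchformerEncoderLayer still accepts as long as the cgMLP branch is present. A sketch under the assumption that the constructor keeps the usual input_size/output_size keywords:

```python
from funasr.models.branchformer.encoder import BranchformerEncoder

# cgMLP-only stack: each layer is built with attn=None
encoder = BranchformerEncoder(
    input_size=80,
    output_size=256,
    use_attn=False,
    use_cgmlp=True,
)

# disabling both branches would trip the assertion in every layer:
# BranchformerEncoder(input_size=80, use_attn=False, use_cgmlp=False)  # AssertionError
```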