Author: kongdeqiang
Date: 2026-03-13
Commit: 28ccfbfc51068a663a80764e14074df5edf2b5ba
File: funasr/models/lcbnet/encoder.py
@@ -21,6 +21,7 @@
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables
class EncoderLayer(nn.Module):
    """Encoder layer module.
@@ -44,14 +45,14 @@
    """
    def __init__(
-            self,
-            size,
-            self_attn,
-            feed_forward,
-            dropout_rate,
-            normalize_before=True,
-            concat_after=False,
-            stochastic_depth_rate=0.0,
+        self,
+        size,
+        self_attn,
+        feed_forward,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
+        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
@@ -109,9 +110,7 @@
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
-            x = residual + stoch_layer_coeff * self.dropout(
-                self.self_attn(x_q, x, x, mask)
-            )
+            x = residual + stoch_layer_coeff * self.dropout(self.self_attn(x_q, x, x, mask))
        if not self.normalize_before:
            x = self.norm1(x)
@@ -126,6 +125,7 @@
            x = torch.cat([cache, x], dim=1)
        return x, mask
@tables.register("encoder_classes", "TransformerTextEncoder")
class TransformerTextEncoder(nn.Module):
@@ -154,18 +154,18 @@
    """
    def __init__(
-            self,
-            input_size: int,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.0,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        pos_enc_class=PositionalEncoding,
+        normalize_before: bool = True,
+        concat_after: bool = False,
    ):
        super().__init__()
        self._output_size = output_size
@@ -187,9 +187,7 @@
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
-                MultiHeadedAttention(
-                    attention_heads, output_size, attention_dropout_rate
-                ),
+                MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
@@ -203,9 +201,9 @@
        return self._output_size
    def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.
@@ -225,8 +223,6 @@
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
@tables.register("encoder_classes", "FusionSANEncoder")
@@ -251,25 +247,32 @@
    """
    def __init__(
-            self,
-            size,
-            attention_heads,
-            attention_dim,
-            linear_units,
-            self_attention_dropout_rate,
-            src_attention_dropout_rate,
-            positional_dropout_rate,
-            dropout_rate,
-            normalize_before=True,
-            concat_after=False,
+        self,
+        size,
+        attention_heads,
+        attention_dim,
+        linear_units,
+        self_attention_dropout_rate,
+        src_attention_dropout_rate,
+        positional_dropout_rate,
+        dropout_rate,
+        normalize_before=True,
+        concat_after=False,
    ):
        """Construct an SelfSrcAttention object."""
        super(SelfSrcAttention, self).__init__()
        self.size = size
-        self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
-        self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
-        self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
+        self.self_attn = MultiHeadedAttention(
+            attention_heads, attention_dim, self_attention_dropout_rate
+        )
+        self.src_attn = MultiHeadedAttentionReturnWeight(
+            attention_heads, attention_dim, src_attention_dropout_rate
+        )
+        self.feed_forward = PositionwiseFeedForward(
+            attention_dim, linear_units, positional_dropout_rate
+        )
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
@@ -319,9 +322,7 @@
                tgt_q_mask = tgt_mask[:, -1:, :]
        if self.concat_after:
-            tgt_concat = torch.cat(
-                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
-            )
+            tgt_concat = torch.cat((tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1)
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
@@ -332,9 +333,7 @@
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
-            x_concat = torch.cat(
-                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
-            )
+            x_concat = torch.cat((x, self.src_attn(x, memory, memory, memory_mask)), dim=-1)
            x = residual + self.concat_linear2(x_concat)
        else:
            x, score = self.src_attn(x, memory, memory, memory_mask)
@@ -357,7 +356,15 @@
@tables.register("encoder_classes", "ConvBiasPredictor")
class ConvPredictor(nn.Module):
-    def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048):
+    def __init__(
+        self,
+        size=256,
+        l_order=3,
+        r_order=3,
+        attention_heads=4,
+        attention_dropout_rate=0.1,
+        linear_units=2048,
+    ):
        super().__init__()
        self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
        self.norm1 = LayerNorm(size)
@@ -367,17 +374,16 @@
        self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size)
        self.output_linear = nn.Linear(size, 1)
    def forward(self, text_enc, asr_enc):
        # stage1 cross-attention
        residual = text_enc
        text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None)
        # stage2 FFN
        residual = text_enc
        text_enc = self.norm1(text_enc)
        text_enc = residual + self.feed_forward(text_enc)
        # stage Conv predictor
        text_enc = self.norm2(text_enc)
        context = text_enc.transpose(1, 2)
@@ -387,6 +393,6 @@
        output = output.transpose(1, 2)
        output = torch.relu(output)
        output = self.output_linear(output)
-        if output.dim()==3:
-          output = output.squeeze(2)
+        if output.dim() == 3:
+            output = output.squeeze(2)
        return output
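For orientation, a minimal usage sketch of the TransformerTextEncoder touched by this commit. The constructor and forward signatures are taken from the hunks above; the assumption that xs_pad holds padded float features of shape (batch, time, input_size) is mine and is not confirmed by this diff.

# Minimal sketch, not part of the commit. Assumes xs_pad is a padded float
# tensor of shape (batch, time, input_size); the input/embedding layer
# internals fall outside the hunks shown above.
import torch
from funasr.models.lcbnet.encoder import TransformerTextEncoder

encoder = TransformerTextEncoder(input_size=256, output_size=256, num_blocks=6)
xs_pad = torch.randn(2, 50, 256)  # (batch, time, input_size)
ilens = torch.tensor([50, 42])    # valid lengths before padding
xs_out, olens, _ = encoder(xs_pad, ilens)
print(xs_out.shape, olens)  # expected: torch.Size([2, 50, 256]) tensor([50, 42])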