kongdeqiang
2026-03-13 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/sond/encoder/self_attention_encoder.py
@@ -87,7 +87,9 @@
             x = self.norm1(x)
         if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            x_concat = torch.cat(
+                (x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1
+            )
             if self.in_size == self.size:
                 x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
             else:
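
For reference, a minimal sketch of what the concat_after branch in this hunk computes, assuming x of shape (batch, time, size), a self-attention output of the same shape, concat_linear = nn.Linear(2 * size, size), and stoch_layer_coeff = 1.0 when stochastic depth is disabled (attn_out is a stand-in for self.self_attn(x, mask, ...)):

import torch
import torch.nn as nn

batch, time, size = 2, 5, 8
x = torch.randn(batch, time, size)
residual = x
attn_out = torch.randn(batch, time, size)  # stand-in for self.self_attn(x, mask, ...)
concat_linear = nn.Linear(2 * size, size)
stoch_layer_coeff = 1.0

# concat_after: concatenate the attention input and output, project back to `size`,
# then add the residual (the in_size == size case shown in the hunk).
x_concat = torch.cat((x, attn_out), dim=-1)                  # (batch, time, 2 * size)
x = residual + stoch_layer_coeff * concat_linear(x_concat)  # (batch, time, size)
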
@@ -207,32 +209,36 @@
         self.encoders = repeat(
             num_blocks,
-            lambda lnum: EncoderLayer(
-                output_size,
-                output_size,
-                MultiHeadSelfAttention(
-                    attention_heads,
+            lambda lnum: (
+                EncoderLayer(
                     output_size,
                     output_size,
-                    attention_dropout_rate,
-                ),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ) if lnum > 0 else EncoderLayer(
-                input_size,
-                output_size,
-                MultiHeadSelfAttention(
-                    attention_heads,
-                    input_size if input_layer == "pe" or input_layer == "null" else output_size,
+                    MultiHeadSelfAttention(
+                        attention_heads,
+                        output_size,
+                        output_size,
+                        attention_dropout_rate,
+                    ),
+                    positionwise_layer(*positionwise_layer_args),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                )
+                if lnum > 0
+                else EncoderLayer(
+                    input_size,
                     output_size,
-                    attention_dropout_rate,
-                ),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
+                    MultiHeadSelfAttention(
+                        attention_heads,
+                        input_size if input_layer == "pe" or input_layer == "null" else output_size,
+                        output_size,
+                        attention_dropout_rate,
+                    ),
+                    positionwise_layer(*positionwise_layer_args),
+                    dropout_rate,
+                    normalize_before,
+                    concat_after,
+                )
             ),
         )
         if self.normalize_before:
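
The reformatted lambda keeps the original behavior: block 0 is built with input_size as its in_size (and its attention operates on input_size features only when input_layer is "pe" or "null"), while every later block maps output_size to output_size. A simplified sketch of that pattern, with nn.Linear standing in for EncoderLayer and a local repeat that is only a stand-in for the library helper:

import torch.nn as nn

def repeat(num_blocks, layer_fn):
    # Stand-in for the library helper: build one module per block index.
    return nn.Sequential(*[layer_fn(lnum) for lnum in range(num_blocks)])

input_size, output_size, num_blocks = 80, 256, 4
encoders = repeat(
    num_blocks,
    lambda lnum: (
        nn.Linear(output_size, output_size)      # blocks 1..N-1: output_size -> output_size
        if lnum > 0
        else nn.Linear(input_size, output_size)  # block 0 consumes the raw feature size
    ),
)
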
@@ -270,7 +276,7 @@
             position embedded tensor and mask
         """
         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad = xs_pad * self.output_size()**0.5
+        xs_pad = xs_pad * self.output_size() ** 0.5
         if self.embed is None:
             xs_pad = xs_pad
         elif (
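
The only change in this hunk is black's spacing around **; the input is still scaled by sqrt(output_size) before the embedding / positional-encoding step. A minimal sketch of that scaling and of the (batch, 1, time) non-padding mask, where make_pad_mask_sketch is a simplified stand-in for the library's make_pad_mask rather than its actual implementation:

import torch

def make_pad_mask_sketch(ilens: torch.Tensor) -> torch.Tensor:
    # True where a frame is padding, shape (batch, max_len).
    max_len = int(ilens.max())
    return torch.arange(max_len)[None, :] >= ilens[:, None]

output_size = 256
xs_pad = torch.randn(3, 10, output_size)
ilens = torch.tensor([10, 7, 4])

masks = (~make_pad_mask_sketch(ilens))[:, None, :]  # (batch, 1, time), True on valid frames
xs_pad = xs_pad * output_size ** 0.5                # scale features by sqrt(d_model)
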
@@ -325,4 +331,3 @@
         if len(intermediate_outs) > 0:
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
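
The last hunk appears to drop only a trailing line; the return contract stays the same: the first element is (xs_pad, intermediate_outs) when intermediate layer outputs were collected, and the bare tensor otherwise. A hedged caller-side helper (unpack_encoder_output is illustrative, not part of the library) that handles both shapes:

from typing import List, Optional, Tuple, Union
import torch

EncoderOut = Union[torch.Tensor, Tuple[torch.Tensor, List]]

def unpack_encoder_output(out: EncoderOut) -> Tuple[torch.Tensor, Optional[List]]:
    # Returns (xs_pad, intermediate_outs); intermediate_outs is None when the
    # encoder returned only the padded output tensor.
    if isinstance(out, tuple):
        xs_pad, intermediate_outs = out
        return xs_pad, intermediate_outs
    return out, None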