commit 861147c7308b91068ffa02724fdf74ee623a909e
Author: zhifu gao
Date:   2024-04-24
File:   funasr/models/sond/encoder/self_attention_encoder.py
@@ -87,7 +87,9 @@
             x = self.norm1(x)
         if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            x_concat = torch.cat(
+                (x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1
+            )
             if self.in_size == self.size:
                 x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
             else:
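
This hunk only rewraps an over-long line; the behavior it formats is the `concat_after` attention variant, where the layer input is concatenated with the self-attention output and projected back to the model dimension before the residual add. A minimal runnable sketch of that pattern (this module is an illustrative stand-in, not FunASR's `EncoderLayer`; masks and the stochastic-depth coefficient are omitted):

```python
import torch
import torch.nn as nn

# Sketch of the "concat_after" pattern wrapped in the hunk above: the
# self-attention output is concatenated with the layer input along the
# feature dimension, then projected back to the model size before the
# residual add. Illustrative stand-in, not FunASR's EncoderLayer.
class ConcatAfterBlock(nn.Module):
    def __init__(self, size: int, num_heads: int = 4):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(size, num_heads, batch_first=True)
        self.concat_linear = nn.Linear(size + size, size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        attn_out, _ = self.self_attn(x, x, x)
        x_concat = torch.cat((x, attn_out), dim=-1)  # doubles the feature dim
        return residual + self.concat_linear(x_concat)  # project back, add residual

block = ConcatAfterBlock(size=256)
y = block(torch.randn(2, 10, 256))  # (batch, frames, feats) -> same shape
```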
@@ -207,7 +209,8 @@
         self.encoders = repeat(
             num_blocks,
-            lambda lnum: EncoderLayer(
+            lambda lnum: (
+                EncoderLayer(
                 output_size,
                 output_size,
                 MultiHeadSelfAttention(
@@ -220,7 +223,9 @@
                 dropout_rate,
                 normalize_before,
                 concat_after,
-            ) if lnum > 0 else EncoderLayer(
+                )
+                if lnum > 0
+                else EncoderLayer(
                 input_size,
                 output_size,
                 MultiHeadSelfAttention(
@@ -233,6 +238,7 @@
                 dropout_rate,
                 normalize_before,
                 concat_after,
+                )
             ),
         )
         if self.normalize_before:
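
These three hunks reformat the layer factory passed to `repeat`: the conditional expression is parenthesized so that block 0 builds an `EncoderLayer` taking `input_size` while every later block takes `output_size`. A runnable sketch of that first-layer-differs pattern, with `nn.Linear` standing in for `EncoderLayer` and a simplified `repeat` helper (FunASR's `repeat` returns a sequential container; this stand-in returns an `nn.ModuleList`):

```python
import torch.nn as nn

# Simplified stand-in for FunASR's repeat() helper, which calls fn(lnum)
# once per block index. nn.Linear stands in for EncoderLayer here.
def repeat(num_blocks, fn):
    return nn.ModuleList(fn(lnum) for lnum in range(num_blocks))

input_size, output_size, num_blocks = 80, 256, 4
encoders = repeat(
    num_blocks,
    lambda lnum: (
        nn.Linear(output_size, output_size)  # lnum > 0: sizes already match
        if lnum > 0
        else nn.Linear(input_size, output_size)  # block 0 adapts input_size
    ),
)
print([tuple(m.weight.shape) for m in encoders])
# [(256, 80), (256, 256), (256, 256), (256, 256)]
```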
@@ -325,4 +331,3 @@
         if len(intermediate_outs) > 0:
             return (xs_pad, intermediate_outs), olens, None
         return xs_pad, olens, None
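
The last hunk's header indicates one removed line that is not shown here; the visible context is the encoder's return convention: `((xs_pad, intermediate_outs), olens, None)` when intermediate layer outputs were collected, and `(xs_pad, olens, None)` otherwise. A toy stand-in (not the FunASR encoder) showing how a caller would unpack it:

```python
import torch

# Toy stand-in mirroring the return convention in the hunk above: the first
# element is itself a tuple when intermediate outputs were collected.
def toy_encode(xs_pad, olens, intermediate_outs):
    if len(intermediate_outs) > 0:
        return (xs_pad, intermediate_outs), olens, None
    return xs_pad, olens, None

xs = torch.zeros(2, 10, 256)   # (batch, frames, feats)
olens = torch.tensor([10, 7])  # valid lengths per utterance
out, olens, _ = toy_encode(xs, olens, [xs])
# Callers must check for the nested-tuple case before using the output.
xs_pad, intermediate_outs = out if isinstance(out, tuple) else (out, [])
```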