zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/transformer/encoder.py
@@ -30,6 +30,7 @@
from funasr.register import tables
class EncoderLayer(nn.Module):
    """Encoder layer module.
@@ -118,9 +119,7 @@
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(
                self.self_attn(x_q, x, x, mask)
            )
            x = residual + stoch_layer_coeff * self.dropout(self.self_attn(x_q, x, x, mask))
        if not self.normalize_before:
            x = self.norm1(x)
@@ -135,6 +134,7 @@
            x = torch.cat([cache, x], dim=1)
        return x, mask
@tables.register("encoder_classes", "TransformerEncoder")
class TransformerEncoder(nn.Module):
@@ -243,9 +243,7 @@
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                MultiHeadedAttention(
                    attention_heads, output_size, attention_dropout_rate
                ),
                MultiHeadedAttention(attention_heads, output_size, attention_dropout_rate),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
@@ -329,4 +327,3 @@
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None