wanchen.swc
2023-03-06 69ccdd35cda4c8482e189fa350fbcb83997872f2
funasr/export/models/modules/encoder_layer.py
@@ -16,6 +16,7 @@
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.in_size = model.in_size
        self.size = model.size
    def forward(self, x, mask):
@@ -23,13 +24,12 @@
        residual = x
        x = self.norm1(x)
        x = self.self_attn(x, mask)
        if x.size(2) == residual.size(2):
        if self.in_size == self.size:
            x = x + residual
        residual = x
        x = self.norm2(x)
        x = self.feed_forward(x)
        if x.size(2) == residual.size(2):
            x = x + residual
        x = x + residual
        return x, mask