funasr/export/models/modules/encoder_layer.py
@@ -16,6 +16,7 @@ self.feed_forward = model.feed_forward self.norm1 = model.norm1 self.norm2 = model.norm2 self.in_size = model.in_size self.size = model.size def forward(self, x, mask): @@ -23,13 +24,12 @@ residual = x x = self.norm1(x) x = self.self_attn(x, mask) if x.size(2) == residual.size(2): if self.in_size == self.size: x = x + residual residual = x x = self.norm2(x) x = self.feed_forward(x) if x.size(2) == residual.size(2): x = x + residual x = x + residual return x, mask