| | |
| | | self.feed_forward = model.feed_forward |
| | | self.norm1 = model.norm1 |
| | | self.norm2 = model.norm2 |
| | | self.in_size = model.in_size |
| | | self.size = model.size |
| | | |
| | | def forward(self, x, mask): |
| | |
| | | residual = x |
| | | x = self.norm1(x) |
| | | x = self.self_attn(x, mask) |
| | | if x.size(2) == residual.size(2): |
| | | if self.in_size == self.size: |
| | | x = x + residual |
| | | residual = x |
| | | x = self.norm2(x) |
| | | x = self.feed_forward(x) |
| | | if x.size(2) == residual.size(2): |
| | | x = x + residual |
| | | x = x + residual |
| | | |
| | | return x, mask |
| | | |
| | |
| | | if self.feed_forward_macaron is not None: |
| | | residual = x |
| | | x = self.norm_ff_macaron(x) |
| | | x = residual + self.feed_forward_macaron(x) |
| | | x = residual + self.feed_forward_macaron(x) * 0.5 |
| | | |
| | | residual = x |
| | | x = self.norm_mha(x) |
| | |
| | | |
| | | residual = x |
| | | x = self.norm_ff(x) |
| | | x = residual + self.feed_forward(x) |
| | | x = residual + self.feed_forward(x) * 0.5 |
| | | |
| | | x = self.norm_final(x) |
| | | |