| | |
| | | |
| | | |
class PositionwiseFeedForwardDecoderSANMExport(torch.nn.Module):
    """Wrapper that re-exposes a position-wise feed-forward module.

    The submodules ``w_1``, ``w_2``, ``activation`` and ``norm`` are taken by
    reference from ``model``, not copied. They are therefore shared with the
    source model, and are registered as children of this wrapper because they
    are assigned as attributes after ``Module.__init__``.

    NOTE(review): the "Export" suffix suggests this exists to provide a
    trace/export-friendly ``forward``. Confirm against the call site.
    """

    def __init__(self, model):
        """Capture the feed-forward submodules of ``model``.

        Args:
            model: Object exposing ``w_1``, ``w_2``, ``activation`` and
                ``norm`` attributes. Each must be callable on a tensor.
        """
        super().__init__()
        self.w_1 = model.w_1
        self.w_2 = model.w_2
        self.activation = model.activation
        self.norm = model.norm

    def forward(self, x):
        """Return ``w_2(norm(activation(w_1(x))))``.

        The normalization sits between the activation and the output
        projection, not on the input or the output.
        """
        x = self.activation(self.w_1(x))
        x = self.w_2(self.norm(x))
        return x