        return self.w_2(self.dropout(self.activation(self.w_1(x))))


class PositionwiseFeedForwardDecoderSANM(torch.nn.Module):
| | | """Positionwise feed forward layer. |
| | | |
| | | Args: |
| | | idim (int): Input dimenstion. |
| | | hidden_units (int): The number of hidden units. |
| | | dropout_rate (float): Dropout rate. |
| | | |
| | | """ |
| | | |
| | | def __init__(self, idim, hidden_units, dropout_rate, adim=None, activation=torch.nn.ReLU()): |
| | | """Construct an PositionwiseFeedForward object.""" |
| | | super(PositionwiseFeedForwardDecoderSANM, self).__init__() |
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        # Bias-free output projection: maps to ``adim`` when given, else back to ``idim``.
        self.w_2 = torch.nn.Linear(hidden_units, idim if adim is None else adim, bias=False)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation
        # Unlike the standard FFN above, the hidden state is layer-normalized
        # before the output projection.
        self.norm = LayerNorm(hidden_units)

    def forward(self, x):
        """Forward function."""
        return self.w_2(self.norm(self.dropout(self.activation(self.w_1(x)))))
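# Usage sketch (illustrative addition, not from the original source): checks the
# shape contract of PositionwiseFeedForwardDecoderSANM. The sizes below are
# arbitrary; assumes ``torch`` and the module-local ``LayerNorm`` are in scope.
if __name__ == "__main__":
    ffn = PositionwiseFeedForwardDecoderSANM(
        idim=256, hidden_units=2048, dropout_rate=0.1, adim=320
    )
    x = torch.randn(4, 50, 256)  # (batch, time, idim)
    y = ffn(x)
    print(y.shape)  # torch.Size([4, 50, 320]); ``adim`` sets the output dim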