import torch
import torch.nn as nn

# PositionwiseFeedForward, PositionalEncoding and MultiHeadSelfAttention are
# assumed to be defined elsewhere in this repository. MultiHeadSelfAttention
# and the head-count parameter `h` in particular are assumptions: only the
# per-layer attribute names ("self_att_{i}") appear in this file.


class EENDOLATransformerEncoder(nn.Module):
    def __init__(self, idim: int, n_layers: int, n_units: int,
                 e_units: int, h: int,
                 dropout_rate: float = 0.1,
                 use_pos_emb: bool = False):
        super(EENDOLATransformerEncoder, self).__init__()
        self.linear_in = nn.Linear(idim, n_units)
        self.lnorm_in = nn.LayerNorm(n_units)
        self.n_layers = n_layers
        self.dropout = nn.Dropout(dropout_rate)
        # Per-layer blocks: layer norm -> self-attention -> layer norm -> feed-forward.
        for i in range(n_layers):
            setattr(self, '{}{:d}'.format("lnorm1_", i),
                    nn.LayerNorm(n_units))
            setattr(self, '{}{:d}'.format("self_att_", i),
                    MultiHeadSelfAttention(n_units, h, dropout_rate))
            setattr(self, '{}{:d}'.format("lnorm2_", i),
                    nn.LayerNorm(n_units))
            setattr(self, '{}{:d}'.format("ff_", i),
                    PositionwiseFeedForward(n_units, e_units, dropout_rate))
        self.lnorm_out = nn.LayerNorm(n_units)
        if use_pos_emb:
            # Input projection followed by positional encoding.
            self.pos_enc = torch.nn.Sequential(
                torch.nn.Linear(idim, n_units),
                torch.nn.LayerNorm(n_units),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                PositionalEncoding(n_units, dropout_rate),
            )
        else:
            self.pos_enc = None

    def __call__(self, x, x_mask=None):
        # x: (B, T, F) ... batch size, sequence length, feature dimension
        BT_size = x.shape[0] * x.shape[1]
        if self.pos_enc is not None:
            e = self.pos_enc(x)
            # flatten to (B*T, n_units)
            e = e.view(BT_size, -1)
        else:
            e = self.linear_in(x.reshape(BT_size, -1))
        # Transformer encoder stack
        for i in range(self.n_layers):
            # layer normalization
            e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
            # self-attention over each sequence (given batch size and optional mask)
            s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask)
            # residual connection
            e = e + self.dropout(s)
            # layer normalization
            e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
            # positionwise feed-forward
            s = getattr(self, '{}{:d}'.format("ff_", i))(e)
            # residual connection
            e = e + self.dropout(s)
        # final layer normalization; output shape (B*T, n_units)
        return self.lnorm_out(e)
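

# A minimal usage sketch (assumptions: the dimensions below are illustrative,
# not values from the original code, and MultiHeadSelfAttention,
# PositionwiseFeedForward and PositionalEncoding must be importable for this
# to run).
if __name__ == "__main__":
    encoder = EENDOLATransformerEncoder(idim=345, n_layers=4, n_units=256,
                                        e_units=1024, h=4,
                                        dropout_rate=0.1, use_pos_emb=False)
    x = torch.randn(8, 500, 345)   # (batch, time, feature)
    out = encoder(x)               # flattened output: (batch * time, n_units)
    print(out.shape)               # expected: torch.Size([4000, 256])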