| | |
| | | self.act = nn.GLU() |
| | | |
| | | # lightconv related |
| | | self.weight = nn.Parameter( |
| | | torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1) |
| | | ) |
| | | self.weight = nn.Parameter(torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1)) |
| | | self.use_bias = use_bias |
| | | if self.use_bias: |
| | | self.bias = nn.Parameter(torch.Tensor(n_feat)) |
| | |
| | | self.kernel_mask = self.kernel_mask.to(x.device) |
| | | weight = weight.masked_fill(self.kernel_mask == 0.0, float("-inf")) |
| | | weight = F.softmax(weight, dim=-1) |
| | | x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view( |
| | | B, C, T |
| | | ) |
| | | x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view(B, C, T) |
| | | if self.use_bias: |
| | | x = x + self.bias.view(1, -1, 1) |
| | | x = x.transpose(1, 2) # B x T x C |