# NOTE(review): fragment of a forward pass — the enclosing method begins
# before this excerpt and its final `if` body continues after it.

# Convolution over the frequency axis: a depthwise kernel of width k is
# predicted from the input for every (batch, time) position.
weight_f = self.linear_weight_f(x).view(B * T, 1, k)  # (B*T) x 1 x k
# Keep a (B, 1, T, k) copy of the predicted kernels, presumably for
# later inspection/visualization — TODO confirm against callers.
self.attn_f = weight_f.view(B, T, k).unsqueeze(1)
# groups=B*T applies each predicted kernel to its own (batch, time) row.
# (A verbatim duplicate of this call was removed — merge artifact that
# recomputed xf with the identical arguments.)
xf = F.conv1d(
    x.view(1, B * T, C), weight_f, padding=self.padding_size, groups=B * T
)
xf = xf.view(B, T, C)  # back to B x T x C

# Build the time-axis convolution kernel as a dense T x T matrix.
# Cells are initialized to -inf so positions outside each kernel's
# k-wide window stay masked (presumably under a later softmax — not
# visible in this fragment).
weight_new = torch.zeros(B * H * T * (T + k - 1), dtype=weight.dtype)
weight_new = weight_new.view(B, H, T, T + k - 1).fill_(float("-inf"))
weight_new = weight_new.to(x.device)  # B x H x T x (T+k-1)
# Scatter the k-wide kernels onto diagonals: a row stride of T+k
# (= (T+k-1) + 1) shifts each successive row's window one step right,
# producing a banded matrix. (A verbatim duplicate of this copy was
# removed — merge artifact; copy_ with identical args is idempotent.)
weight_new.as_strided(
    (B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1)
).copy_(weight)
# Crop the band to T columns centered on the main diagonal.
weight_new = weight_new.narrow(-1, int((k - 1) / 2), T)  # B x H x T x T(k)
if self.use_kernel_mask:
    # Lower-triangular mask: restrict each position to itself and the past.
    kernel_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0)