| | |
| | | weight_new = torch.zeros(B * H * T * (T + k - 1), dtype=weight.dtype) |
| | | weight_new = weight_new.view(B, H, T, T + k - 1).fill_(float("-inf")) |
| | | weight_new = weight_new.to(x.device) # B x H x T x T+k-1 |
| | | weight_new.as_strided( |
| | | (B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1) |
| | | ).copy_(weight) |
| | | weight_new.as_strided((B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1)).copy_( |
| | | weight |
| | | ) |
| | | weight_new = weight_new.narrow(-1, int((k - 1) / 2), T) # B x H x T x T(k) |
| | | if self.use_kernel_mask: |
| | | kernel_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0) |