| | |
| | | return encoding.type(dtype) |
| | | |
def forward(self, x, cache=None):
    """Add absolute positional encoding to ``x``, with optional streaming state.

    Args:
        x: Input tensor of shape ``(batch, time, dim)``.
        cache: Optional streaming state dict with keys ``"start_idx"``
            (position offset of this chunk within the stream), ``"left"``
            and ``"right"`` (time-axis padding amounts). As a side effect,
            ``cache["start_idx"]`` is advanced by the chunk length so the
            next chunk resumes at the correct position.

    Returns:
        Tensor of shape ``(batch, time + left + right, dim)``: the input plus
        its positional encoding, zero-padded along the time axis.
    """
    start_idx = 0
    pad_left = 0
    pad_right = 0
    batch_size, timesteps, input_dim = x.size()
    if cache is not None:
        start_idx = cache["start_idx"]
        pad_left = cache["left"]
        pad_right = cache["right"]
        # Advance the stream cursor in place for the next chunk.
        cache["start_idx"] += timesteps
    # Positions are 1-based and span the whole stream seen so far; only the
    # slice belonging to the current chunk is added below.
    positions = torch.arange(1, timesteps + start_idx + 1)[None, :]
    position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
    outputs = x + position_encoding[:, start_idx : start_idx + timesteps]
    # F.pad pads the LAST dimension with (left, right), so move time to the
    # last axis, pad, and move it back.
    outputs = outputs.transpose(1, 2)
    outputs = F.pad(outputs, (pad_left, pad_right))
    outputs = outputs.transpose(1, 2)
    return outputs
| | | |
| | | class StreamingRelPositionalEncoding(torch.nn.Module): |
| | | """Relative positional encoding. |