| | |
| | | positions = torch.arange(1, timesteps+1)[None, :] |
| | | position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device) |
| | | |
| | | return x + position_encoding |
| | | return x + position_encoding |
| | | |
def forward_chunk(self, x, cache=None):
    """Add positional encodings to one chunk of a streamed sequence.

    Args:
        x: Tensor of shape (batch, timesteps, input_dim) holding the chunk.
        cache: Optional dict. ``cache["start_idx"]`` is the number of frames
            that preceded this chunk. ``None`` means the chunk starts at the
            beginning of the sequence. The cache is only read here, never
            updated.

    Returns:
        ``x`` plus the encodings for the 1-based positions
        ``start_idx + 1 .. start_idx + timesteps``.
    """
    _, timesteps, input_dim = x.size()
    start_idx = 0 if cache is None else cache["start_idx"]
    # Encode only this chunk's positions. The previous code encoded
    # 1..start_idx+timesteps, copied all of it to the device and then sliced,
    # so per-chunk cost grew with the position in the stream.
    # NOTE(review): assumes self.encode maps each position independently
    # (as sinusoidal or table-lookup encodings do) — confirm.
    positions = torch.arange(start_idx + 1, start_idx + timesteps + 1)[None, :]
    position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
    return x + position_encoding