| | |
| | | |
| | | import math |
| | | import torch |
| | | |
| | | import torch.nn.functional as F |
| | | |
| | | def _pre_hook( |
| | | state_dict, |
| | |
| | | |
def forward_chunk(self, x, cache=None):
    """Add absolute positional encodings to one chunk of input, then pad it.

    Parameters
    ----------
    x : torch.Tensor
        Input chunk of shape (batch, time, dim).
    cache : dict or None, optional
        Streaming state for chunked processing with keys:
        ``"start_idx"`` — position offset of this chunk so consecutive
        chunks continue the same position sequence;
        ``"left"`` / ``"right"`` — amount of zero padding to apply on the
        time axis. When ``None``, the chunk starts at position 0 and no
        padding is applied.

    Returns
    -------
    torch.Tensor
        ``x`` with positional encodings added, zero-padded along the time
        axis to shape (batch, time + left + right, dim).
    """
    start_idx = 0
    pad_left = 0
    pad_right = 0
    batch_size, timesteps, input_dim = x.size()
    if cache is not None:
        start_idx = cache["start_idx"]
        pad_left = cache["left"]
        pad_right = cache["right"]
    # Positions are 1-based and extended by start_idx so the encoding table
    # covers every absolute position this chunk occupies.
    positions = torch.arange(1, timesteps + start_idx + 1)[None, :]
    position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
    # BUG FIX: the original returned here, which made the padding below
    # unreachable dead code and silently ignored cache["left"]/cache["right"].
    outputs = x + position_encoding[:, start_idx: start_idx + timesteps]
    # F.pad pads the last dimension, so move the time axis to the end,
    # pad, and move it back.
    outputs = outputs.transpose(1, 2)
    outputs = F.pad(outputs, (pad_left, pad_right))
    outputs = outputs.transpose(1, 2)
    return outputs
| | | |