Create tensors on the target device to avoid high CPU occupation. (#695) Previously, the timescale and position tensors were allocated on the CPU and copied to the target device afterwards; building them directly on the input's device removes that per-call CPU work and host-to-device transfer.
```diff
 def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
     batch_size = positions.size(0)
     positions = positions.type(dtype)
-    log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype)) / (depth / 2 - 1)
-    inv_timescales = torch.exp(torch.arange(depth / 2).type(dtype) * (-log_timescale_increment))
+    device = positions.device
+    log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (depth / 2 - 1)
+    inv_timescales = torch.exp(torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment))
     inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
     scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
     encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
     return encoding

 def forward(self, x):
     batch_size, timesteps, input_dim = x.size()
-    positions = torch.arange(1, timesteps + 1)[None, :]
+    positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
     position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)

     return x + position_encoding
```
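
For reference, a minimal self-contained sketch of the patched module follows. The `PositionalEncoding` class name and the demo tensor shapes are assumptions for illustration, not part of the original patch, and the `[batch_size, -1]` reshape is omitted here because the subsequent `[1, 1, -1]` reshape flattens it again:

```python
import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (class name assumed for this sketch)."""

    def encode(self, positions: torch.Tensor, depth: int, dtype: torch.dtype = torch.float32):
        positions = positions.type(dtype)
        # Create every intermediate tensor on the same device as `positions`,
        # so the hot path performs no CPU allocations or host-to-device copies.
        device = positions.device
        log_timescale_increment = torch.log(
            torch.tensor([10000], dtype=dtype, device=device)) / (depth / 2 - 1)
        inv_timescales = torch.exp(
            torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment))
        # (1, timesteps, 1) * (1, 1, depth/2) -> (1, timesteps, depth/2)
        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
        return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)

    def forward(self, x):
        _, timesteps, input_dim = x.size()
        # Positions are likewise created directly on x's device.
        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
        return x + self.encode(positions, input_dim, x.dtype)


x = torch.zeros(2, 5, 8)  # (batch, timesteps, input_dim); input_dim must be even
print(PositionalEncoding()(x).shape)  # torch.Size([2, 5, 8])
```

Passing `device=` at construction time keeps the work on the accelerator; creating the tensor on the CPU and calling `.to(device)` afterwards materializes it on the host first and then copies it, which is exactly the overhead this patch removes.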