| | |
| | | import math |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from torch import einsum |
| | | |
| | | def _pre_hook( |
| | | state_dict, |
| | |
def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
    """Build sinusoidal position encodings for ``positions``.

    Args:
        positions: Tensor of position indices. It is flattened before use,
            so the result always has a leading dimension of size 1.
        depth: Size of the last dimension of the encoding. Half of the
            channels hold sines and half hold cosines, so an even value is
            expected.
        dtype: Floating-point dtype used for the computation and the result.

    Returns:
        Tensor of shape ``(1, positions.numel(), depth)`` (for even
        ``depth``) located on ``positions.device``; the first ``depth / 2``
        channels are ``sin`` terms and the remaining ones are ``cos`` terms.
    """
    positions = positions.type(dtype)
    device = positions.device

    num_timescales = depth / 2
    # With depth == 2 there is a single timescale and the spacing below would
    # divide by zero (yielding NaN encodings); fall back to a unit divisor.
    denominator = num_timescales - 1
    if denominator == 0:
        denominator = 1

    # Everything is created directly on the target device so no CPU tensors
    # are allocated and then discarded.
    log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / denominator
    inv_timescales = torch.exp(torch.arange(num_timescales, device=device).type(dtype) * (-log_timescale_increment))

    # Broadcast (1, T, 1) * (1, 1, depth / 2) -> (1, T, depth / 2).
    scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
    encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
    return encoding
| | |
| | | |
def forward(self, x):
    """Add sinusoidal position encodings to ``x``.

    Args:
        x: 3-D tensor; its second dimension is used as the number of
            timesteps and its third as the encoding depth.

    Returns:
        ``x`` plus the encoding produced by ``self.encode`` for positions
        ``1..timesteps``, computed in ``x``'s dtype.
    """
    _, timesteps, input_dim = x.size()
    # Positions are 1-based and created directly on x's device, so no
    # throwaway CPU tensor is allocated first.
    positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
    position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)

    return x + position_encoding
| | |
| | | pos_enc = self.dropout(pos_enc) |
| | | |
| | | return pos_enc |
| | | |
| | | |
class ScaledSinuEmbedding(torch.nn.Module):
    """Sinusoidal position embedding multiplied by a single learnable scale."""

    def __init__(self, dim):
        super().__init__()
        # One trainable scalar, initialised to 1, applied to the whole table.
        self.scale = torch.nn.Parameter(torch.ones(1,))
        # Inverse frequencies for every second channel index in [0, dim).
        exponents = torch.arange(0, dim, 2).float() / dim
        frequencies = 10000 ** exponents
        self.register_buffer('inv_freq', 1. / frequencies)

    def forward(self, x):
        """Return the scaled embedding table for ``x.shape[1]`` positions.

        Only the second dimension and the device of ``x`` are read. The
        result has one row per position (starting at 0), with the sine terms
        followed by the cosine terms along the last dimension.
        """
        seq_len = x.shape[1]
        steps = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        # Outer product of positions and inverse frequencies via broadcasting.
        angles = steps[:, None] * self.inv_freq[None, :]
        table = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return table * self.scale
| | | |