import math
import torch
import torch.nn.functional as F
from torch import einsum

def _pre_hook(
    state_dict, prefix, local_metadata, strict,
    missing_keys, unexpected_keys, error_msgs,
):
    # Standard PyTorch load_state_dict pre-hook signature; the original
    # body was lost in extraction.
    ...


class PositionalEncoding(torch.nn.Module):  # hypothetical name; the enclosing class was lost in extraction
    def forward(self, x):
        pos_enc = ...  # construction of pos_enc elided in the source
        pos_enc = self.dropout(pos_enc)
        return pos_enc

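
# A minimal sketch of how a hook with this signature is typically wired up
# (an assumption, not from the original source): torch.nn.Module's private
# _register_load_state_dict_pre_hook runs the hook on this module's slice
# of the state dict before load_state_dict consumes it, which is the usual
# way to remap legacy checkpoint keys. _LegacyCheckpointModule is a
# hypothetical example module, not part of the original code.
class _LegacyCheckpointModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._register_load_state_dict_pre_hook(_pre_hook)
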
class ScaledSinuEmbedding(torch.nn.Module):
    """Sinusoidal positional embedding with a single learned scale."""

    def __init__(self, dim):
        super().__init__()
        # Learned scalar multiplier applied to the fixed sinusoidal table.
        self.scale = torch.nn.Parameter(torch.ones(1,))
        # Standard transformer frequencies: 1 / 10000^(2i / dim).
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, x):
        n, device = x.shape[1], x.device
        # Positions 0..n-1, in the same dtype as the frequency buffer.
        t = torch.arange(n, device = device).type_as(self.inv_freq)
        # Outer product of positions and inverse frequencies: (n, dim // 2).
        sinu = einsum('i , j -> i j', t, self.inv_freq)
        # Concatenate sin and cos halves into an (n, dim) table.
        emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
        return emb * self.scale
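
# A minimal usage sketch (the shapes are assumptions, not from the original
# source): the module reads only seq_len and device from its input and
# returns an (seq_len, dim) positional table scaled by the learned parameter.
if __name__ == '__main__':
    pos_emb = ScaledSinuEmbedding(dim = 512)
    x = torch.randn(2, 128, 512)   # (batch, seq_len, dim)
    pos = pos_emb(x)               # -> (128, 512) == (seq_len, dim)
    print(pos.shape)               # torch.Size([128, 512])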