
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return (
            super(LayerNorm, self)
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )
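

# A minimal usage sketch for the transposing LayerNorm above. Only the tail of
# its ``forward`` is visible here, so the constructor call below is an
# assumption: the class is assumed to subclass ``nn.LayerNorm``, storing the
# axis to normalize in ``self.dim`` and delegating to the parent after moving
# that axis last (the ``dim == -1`` fast path skips the transpose).
def _example_layer_norm():
    norm = LayerNorm(64, dim=1)  # hypothetical signature: (normalized size, axis)
    x = torch.randn(8, 64, 100)  # (batch, channels, time)
    return norm(x)  # normalized along dim=1, shape preserved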


class GlobalLayerNorm(nn.Module):
    """Global layer normalization (gLN): statistics are computed jointly
    over all non-batch dimensions (channels and time)."""

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        # NOTE: this __init__ is reconstructed; the exact signature is an
        # assumption. ``shape`` (3 or 4) picks affine parameter shapes that
        # broadcast against (batch, channel, time) or (batch, channel, K, S)
        # inputs, matching the two branches of ``forward`` below.
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            if shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        if x.dim() == 3:
            mean = torch.mean(x, (1, 2), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)

        if x.dim() == 4:
            mean = torch.mean(x, (1, 2, 3), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        return x
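

# A minimal usage sketch for GlobalLayerNorm. The constructor arguments rely
# on the reconstructed __init__ above and are therefore an assumption: ``dim``
# is the channel size and ``shape`` the expected input rank, so the affine
# parameters broadcast against the per-branch statistics.
def _example_global_layer_norm():
    gln = GlobalLayerNorm(64, shape=3)  # assumes the reconstructed signature
    x = torch.randn(8, 64, 100)  # (batch, channels, time)
    return gln(x)  # mean/variance shared across channels and time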


class ScaleNorm(nn.Module):
    """Scale normalization: divide by the L2 norm of the last axis (scaled
    by 1/sqrt(dim)) and rescale with a single learned gain ``g``."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g
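

# Usage sketch for ScaleNorm, which is fully defined above: unlike LayerNorm's
# per-feature affine, a single scalar gain ``g`` is learned, and the input is
# normalized by its last-axis L2 norm scaled by 1/sqrt(dim). The tensor shape
# below is illustrative.
def _example_scale_norm():
    norm = ScaleNorm(dim=512)
    x = torch.randn(8, 100, 512)  # (batch, tokens, features)
    return norm(x)  # last-axis L2 norm equals g * sqrt(dim) after normalization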