"""Layer normalization module."""

import torch
import torch.nn as nn


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization applied along an arbitrary dimension ``dim``."""

    def __init__(self, normalized_shape, dim=-1):
        super().__init__(normalized_shape)
        self.dim = dim

    def forward(self, x):
        """Returns the normalized tensor."""
        # Move ``self.dim`` to the last axis, normalize, then move it back.
        return (
            super()
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )
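

# A minimal usage sketch for the LayerNorm wrapper above. The constructor
# signature is an assumption read off how ``self.dim`` is used, and the
# helper name ``_example_layer_norm`` is illustrative only.
def _example_layer_norm():
    x = torch.randn(2, 8, 100)  # [N, C, L]
    ln = LayerNorm(8, dim=1)  # normalize across the channel axis (size 8)
    y = ln(x)
    assert y.shape == x.shape  # normalization preserves the shape
    return y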


class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization (gLN).

    Arguments
    ---------
    dim : int
        Size of the channel dimension of the expected input.
    shape : int
        Number of dimensions of the expected input (3 or 4).
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        When set to True, this module has learnable per-element affine
        parameters, initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            elif shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # gLN computes mean/var over all non-batch dimensions, so they have
        # shape [N, 1, 1] for 3-D inputs and [N, 1, 1, 1] for 4-D inputs.
        if x.dim() == 3:
            mean = torch.mean(x, (1, 2), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        elif x.dim() == 4:
            mean = torch.mean(x, (1, 2, 3), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        return x
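

# A minimal sanity-check sketch for GlobalLayerNorm; the helper name
# ``_example_global_layer_norm`` is illustrative only. With the affine
# transform disabled, each batch element should come out with ~zero mean
# over the channel and time dimensions.
def _example_global_layer_norm():
    x = torch.randn(5, 10, 20)  # [N, C, L]
    gln = GlobalLayerNorm(10, 3, elementwise_affine=False)
    y = gln(x)
    assert y.shape == x.shape
    assert torch.allclose(y.mean(dim=(1, 2)), torch.zeros(5), atol=1e-4)
    return y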


class CumulativeLayerNorm(nn.LayerNorm):
    """Calculate Cumulative Layer Normalization.

    Arguments
    ---------
    dim : int
        Dimension that you want to normalize.
    elementwise_affine : bool
        When set to True, learnable per-element affine parameters are used.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(
            dim, elementwise_affine=elementwise_affine, eps=1e-8
        )

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # Normalization is applied over the channel dimension only: the
        # channels are moved to the last axis, normalized, and moved back.
        if x.dim() == 4:
            # [N, C, K, S] -> [N, K, S, C]
            x = x.permute(0, 2, 3, 1).contiguous()
            x = super().forward(x)
            # [N, K, S, C] -> [N, C, K, S]
            x = x.permute(0, 3, 1, 2).contiguous()
        elif x.dim() == 3:
            # [N, C, L] -> [N, L, C]
            x = torch.transpose(x, 1, 2)
            x = super().forward(x)
            # [N, L, C] -> [N, C, L]
            x = torch.transpose(x, 1, 2)
        return x
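

# A minimal sanity-check sketch for CumulativeLayerNorm; the helper name
# ``_example_cumulative_layer_norm`` is illustrative only. With the affine
# transform disabled, every (batch, time) slice should come out with
# ~zero mean across the channel axis.
def _example_cumulative_layer_norm():
    x = torch.randn(5, 10, 20)  # [N, C, L]
    cln = CumulativeLayerNorm(10, elementwise_affine=False)
    y = cln(x)
    assert y.shape == x.shape
    assert torch.allclose(y.mean(dim=1), torch.zeros(5, 20), atol=1e-4)
    return y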


class ScaleNorm(nn.Module):
    """Scale normalization over the last dimension, using a single
    learnable gain ``g`` instead of per-element affine parameters.

    Arguments
    ---------
    dim : int
        Size of the last dimension; sets the scale factor ``dim ** -0.5``.
    eps : float
        Lower bound on the norm, for numerical stability.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> SN = ScaleNorm(20)
    >>> x_norm = SN(x)
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Returns the rescaled tensor."""
        # Rescale each vector along the last axis to L2 norm g * sqrt(dim).
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g
| | | |
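
# A minimal sanity-check sketch for ScaleNorm; the helper name
# ``_example_scale_norm`` is illustrative only. With the default gain
# g = 1, every vector along the last axis should come out with
# L2 norm ~sqrt(dim).
def _example_scale_norm():
    x = torch.randn(5, 10, 20)
    sn = ScaleNorm(20)
    y = sn(x)
    assert y.shape == x.shape
    expected = torch.full((5, 10), 20 ** 0.5)
    assert torch.allclose(y.norm(dim=-1), expected, atol=1e-4)
    return y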