From ea2c102e6162c924c682aabfe8a052ce9a766a4d Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 10 八月 2023 20:17:53 +0800
Subject: [PATCH] Merge pull request #832 from alibaba-damo-academy/dev_lhn

---
 funasr/modules/layer_norm.py |  135 +++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 135 insertions(+), 0 deletions(-)

diff --git a/funasr/modules/layer_norm.py b/funasr/modules/layer_norm.py
index 6e934e6..8683230 100644
--- a/funasr/modules/layer_norm.py
+++ b/funasr/modules/layer_norm.py
@@ -7,6 +7,7 @@
 """Layer normalization module."""
 
 import torch
+import torch.nn as nn
 
 
 class LayerNorm(torch.nn.LayerNorm):
@@ -40,3 +41,137 @@
             .forward(x.transpose(self.dim, -1))
             .transpose(self.dim, -1)
         )
+
+
+class GlobalLayerNorm(nn.Module):
+    """Calculate Global Layer Normalization.
+
+    Arguments
+    ---------
+       dim : (int or list or torch.Size)
+           Input shape from an expected input of size.
+       eps : float
+           A value added to the denominator for numerical stability.
+       elementwise_affine : bool
+          A boolean value that when set to True,
+          this module has learnable per-element affine parameters
+          initialized to ones (for weights) and zeros (for biases).
+
+    Example
+    -------
+    >>> x = torch.randn(5, 10, 20)
+    >>> GLN = GlobalLayerNorm(10, 3)
+    >>> x_norm = GLN(x)
+    """
+
+    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
+        super(GlobalLayerNorm, self).__init__()
+        self.dim = dim
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+
+        if self.elementwise_affine:
+            if shape == 3:
+                self.weight = nn.Parameter(torch.ones(self.dim, 1))
+                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
+            if shape == 4:
+                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
+                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
+        else:
+            self.register_parameter("weight", None)
+            self.register_parameter("bias", None)
+
+    def forward(self, x):
+        """Returns the normalized tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of size [N, C, K, S] or [N, C, L].
+        """
+        # x = N x C x K x S or N x C x L
+        # N x 1 x 1
+        # cln: mean,var N x 1 x K x S
+        # gln: mean,var N x 1 x 1
+        if x.dim() == 3:
+            mean = torch.mean(x, (1, 2), keepdim=True)
+            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
+            if self.elementwise_affine:
+                x = (
+                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
+                    + self.bias
+                )
+            else:
+                x = (x - mean) / torch.sqrt(var + self.eps)
+
+        if x.dim() == 4:
+            mean = torch.mean(x, (1, 2, 3), keepdim=True)
+            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
+            if self.elementwise_affine:
+                x = (
+                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
+                    + self.bias
+                )
+            else:
+                x = (x - mean) / torch.sqrt(var + self.eps)
+        return x
+
+
+class CumulativeLayerNorm(nn.LayerNorm):
+    """Calculate Cumulative Layer Normalization.
+
+       Arguments
+       ---------
+       dim : int
+        Dimension that you want to normalize.
+       elementwise_affine : True
+        Learnable per-element affine parameters.
+
+    Example
+    -------
+    >>> x = torch.randn(5, 10, 20)
+    >>> CLN = CumulativeLayerNorm(10)
+    >>> x_norm = CLN(x)
+    """
+
+    def __init__(self, dim, elementwise_affine=True):
+        super(CumulativeLayerNorm, self).__init__(
+            dim, elementwise_affine=elementwise_affine, eps=1e-8
+        )
+
+    def forward(self, x):
+        """Returns the normalized tensor.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor size [N, C, K, S] or [N, C, L]
+        """
+        # x: N x C x K x S or N x C x L
+        # N x K x S x C
+        if x.dim() == 4:
+            x = x.permute(0, 2, 3, 1).contiguous()
+            # N x K x S x C == only channel norm
+            x = super().forward(x)
+            # N x C x K x S
+            x = x.permute(0, 3, 1, 2).contiguous()
+        if x.dim() == 3:
+            x = torch.transpose(x, 1, 2)
+            # N x L x C == only channel norm
+            x = super().forward(x)
+            # N x C x L
+            x = torch.transpose(x, 1, 2)
+        return x
+
+
+class ScaleNorm(nn.Module):
+    def __init__(self, dim, eps = 1e-5):
+        super().__init__()
+        self.scale = dim ** -0.5
+        self.eps = eps
+        self.g = nn.Parameter(torch.ones(1))
+
+    def forward(self, x):
+        norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
+        return x / norm.clamp(min = self.eps) * self.g
+

--
Gitblit v1.9.1