游雁
2022-11-26 c087854f71960341933a71442583dbc53d9b4e14
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
 
# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
 
"""Layer normalization module."""
 
import torch
 
 
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Thin wrapper around :class:`torch.nn.LayerNorm` that can normalize over an
    arbitrary dimension of the input instead of only the last one.

    Args:
        nout (int): Output dim size, i.e. the size of the dimension that is
            normalized.
        dim (int): Dimension to be normalized.
        eps (float): Value added to the variance for numerical stability.
            Defaults to 1e-12, the value this module has always used.

    """

    def __init__(self, nout, dim=-1, eps=1e-12):
        """Construct a LayerNorm object."""
        super().__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor. Its size along ``self.dim`` is
                expected to equal ``nout``.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as ``x``.

        """
        if self.dim == -1:
            return super().forward(x)
        # torch.nn.LayerNorm only normalizes over trailing dimensions, so move
        # the target dimension to the end, normalize, then move it back.
        return (
            super()
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )