| | |
| | | |
| | | import numpy as np |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | |
| | | |
| | | class GlobalMVN(AbsNormalize, InversibleInterface): |
| | | """Apply global mean and variance normalization |
| | | |
| | | TODO(kamo): Make this class portable somehow |
| | | |
| | | Args: |
| | | stats_file: npy file |
| | | norm_means: Apply mean normalization |
| | |
| | | norm_vars: bool = True, |
| | | eps: float = 1.0e-20, |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | self.norm_means = norm_means |
| | | self.norm_vars = norm_vars |
| | |
| | | self, x: torch.Tensor, ilens: torch.Tensor = None |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Forward function |
| | | |
| | | Args: |
| | | x: (B, L, ...) |
| | | ilens: (B,) |
| | |
| | | if norm_means: |
| | | x += self.mean |
| | | x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0) |
| | | return x, ilens |
| | | return x, ilens |