| | |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.layers.inversible_interface import InversibleInterface |
| | | |
| | | class GlobalMVN(torch.nn.Module): |
| | | |
| | | class GlobalMVN(AbsNormalize, InversibleInterface): |
| | | """Apply global mean and variance normalization |
| | | |
| | | TODO(kamo): Make this class portable somehow |
| | | |
| | | Args: |
| | | stats_file: npy file |
| | | norm_means: Apply mean normalization |
| | |
| | | self, x: torch.Tensor, ilens: torch.Tensor = None |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Forward function |
| | | |
| | | Args: |
| | | x: (B, L, ...) |
| | | ilens: (B,) |
| | |
| | | if norm_means: |
| | | x += self.mean |
| | | x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0) |
| | | return x, ilens |
| | | return x, ilens |