| | |
| | | import torch |
| | | |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.utils.register import register_class, registry_tables |
| | | from funasr.register import tables |
| | | |
| | | @register_class("normalize_classes", "GlobalMVN") |
| | | |
| | | @tables.register("normalize_classes", "GlobalMVN") |
| | | class GlobalMVN(torch.nn.Module): |
| | | """Apply global mean and variance normalization |
| | | TODO(kamo): Make this class portable somehow |
| | |
| | | if norm_means: |
| | | x += self.mean |
| | | x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0) |
| | | return x, ilens |
| | | return x, ilens |