import logging

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.data2vec.data_utils import compute_mask_indices
from funasr.models.data2vec.ema_module import EMAModule
from funasr.models.data2vec.grad_multiply import GradMultiply


def get_annealed_rate(start, end, curr_step, total_steps):
    # Linearly anneal a rate from `start` to `end` over `total_steps` updates
    # (used to move the EMA teacher decay from ema_decay towards ema_end_decay).
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining
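
# Illustrative note (not from the original source): with the EMA defaults used below
# (ema_decay=0.999, ema_end_decay=0.9999, ema_anneal_end_step=100000), the teacher
# decay half-way through the anneal is
#   get_annealed_rate(0.999, 0.9999, curr_step=50000, total_steps=100000)  # == 0.99945
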
class Data2VecEncoder(AbsEncoder):
    def __init__(
        self,
        # for ConvFeatureExtractionModel
        input_size: int = None,
        extractor_mode: str = None,
        conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
        # for Transformer Encoder
        ## model architecture
        layer_type: str = "transformer",
        layer_norm_first: bool = False,
        encoder_layers: int = 12,
        encoder_embed_dim: int = 768,
        encoder_ffn_embed_dim: int = 3072,
        encoder_attention_heads: int = 12,
        activation_fn: str = "gelu",
        ## dropouts
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        encoder_layerdrop: float = 0.0,
        dropout_input: float = 0.0,
        dropout_features: float = 0.0,
        ## grad settings
        feature_grad_mult: float = 1.0,
        ## masking
        mask_prob: float = 0.65,
        mask_length: int = 10,
        mask_selection: str = "static",
        mask_other: int = 0,
        no_mask_overlap: bool = False,
        mask_min_space: int = 1,
        require_same_masks: bool = True,  # if True, the collate_fn should clip samples to equal length
        mask_dropout: float = 0.0,
        ## channel masking
        mask_channel_length: int = 10,
        mask_channel_prob: float = 0.0,
        mask_channel_before: bool = False,
        mask_channel_selection: str = "static",
        mask_channel_other: int = 0,
        no_mask_channel_overlap: bool = False,
        mask_channel_min_space: int = 1,
        ## positional embeddings
        conv_pos: int = 128,
        conv_pos_groups: int = 16,
        pos_conv_depth: int = 1,
        max_positions: int = 100000,
        # EMA module
        average_top_k_layers: int = 8,
        layer_norm_target_layer: bool = False,
        instance_norm_target_layer: bool = False,
        instance_norm_targets: bool = False,
        layer_norm_targets: bool = False,
        batch_norm_target_layer: bool = False,
        group_norm_target_layer: bool = False,
        ema_decay: float = 0.999,
        ema_end_decay: float = 0.9999,
        ema_anneal_end_step: int = 100000,
        ema_transformer_only: bool = True,
        ema_layers_only: bool = True,
        min_target_var: float = 0.1,
        min_pred_var: float = 0.01,
        # Loss
        loss_beta: float = 0.0,
        loss_scale: float = None,
        # FP16 optimization
        required_seq_len_multiple: int = 2,
    ):
        super().__init__()

        self.mask_other = mask_other
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        self.require_same_masks = (
            require_same_masks  # if True, the collate_fn should clip samples to equal length
        )
        self.mask_dropout = mask_dropout
        ## channel masking
        self.mask_channel_length = mask_channel_length

    def set_num_updates(self, num_updates):
        # Anneals the EMA teacher decay via get_annealed_rate() and steps the EMA
        # module; only the final bookkeeping assignment is shown in this excerpt.
        self.num_updates = num_updates

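    # apply_mask: applies time-step masking (and, when configured, channel masking)
    # to the encoder input, sampling mask positions with compute_mask_indices.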
    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape


                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0


                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x[mask_channel_indices] = 0


        return input_lengths.to(torch.long)

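    # forward: extracts convolutional features, optionally applies masking, and runs
    # the transformer encoder; when features_only is False it additionally builds the
    # EMA-teacher targets used for the data2vec regression loss.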
    def forward(
        self,
        xs_pad,
        ilens=None,
        mask=False,
        features_only=True,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        # create padding_mask from ilens
        if ilens is not None:
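            # Normalize the EMA teacher's per-layer outputs (batch / instance /
            # group / layer norm, depending on the configuration) before averaging
            # them into the regression target y.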
            if self.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(tl.float(), running_mean=None, running_var=None, training=True)
                    for tl in target_layer_results
                ]

            if self.instance_norm_target_layer:
                target_layer_results = [F.instance_norm(tl.float()) for tl in target_layer_results]

            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]

            if self.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:]) for tl in target_layer_results
                ]

            if self.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:]) for tl in target_layer_results
                ]

            y = sum(target_layer_results) / len(target_layer_results)

| | | f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting" |
| | | ) |
| | | if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var: |
| | | logging.error( |
| | | f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting" |
| | | ) |
| | | logging.error(f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting") |
| | | raise Exception( |
| | | f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting" |
| | | ) |
| | |
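        # Compute the mean standard deviation of y over its feature dimensions. In the
        # distributed case the (unbiased) variance is derived from globally reduced
        # count / sum / sum-of-squares statistics, so every worker gets the same value.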
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y**2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

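    # extract_features: thin inference-style wrapper that delegates to forward()
    # (whose features_only flag defaults to True) and returns its result.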
    def extract_features(self, xs_pad, ilens, mask=False, layer=None):
        res = self.forward(
            xs_pad,
            ilens,

        )

    def output_size(self) -> int:
        return self.encoder_embed_dim
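

# A minimal usage sketch (illustrative, not from the original file): the constructor
# arguments, tensor shapes, and the call pattern below are assumptions chosen only to
# show how the encoder would typically be driven; they are not guaranteed to match
# the elided parts of this module.
if __name__ == "__main__":
    encoder = Data2VecEncoder(input_size=80, extractor_mode="layer_norm")  # assumed args
    xs_pad = torch.randn(2, 200, 80)  # (batch, frames, input_size) -- assumed layout
    ilens = torch.tensor([200, 160])  # valid frame counts per utterance
    res = encoder.extract_features(xs_pad, ilens)  # assumed return value structure
    print(type(res), encoder.output_size())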