kongdeqiang
9 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/data2vec/data2vec_encoder.py
@@ -11,7 +11,6 @@
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.data2vec.data_utils import compute_mask_indices
from funasr.models.data2vec.ema_module import EMAModule
from funasr.models.data2vec.grad_multiply import GradMultiply
@@ -28,73 +27,73 @@
    return end - r * pct_remaining
class Data2VecEncoder(AbsEncoder):
class Data2VecEncoder(nn.Module):
    def __init__(
            self,
            # for ConvFeatureExtractionModel
            input_size: int = None,
            extractor_mode: str = None,
            conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
            # for Transformer Encoder
            ## model architecture
            layer_type: str = "transformer",
            layer_norm_first: bool = False,
            encoder_layers: int = 12,
            encoder_embed_dim: int = 768,
            encoder_ffn_embed_dim: int = 3072,
            encoder_attention_heads: int = 12,
            activation_fn: str = "gelu",
            ## dropouts
            dropout: float = 0.1,
            attention_dropout: float = 0.1,
            activation_dropout: float = 0.0,
            encoder_layerdrop: float = 0.0,
            dropout_input: float = 0.0,
            dropout_features: float = 0.0,
            ## grad settings
            feature_grad_mult: float = 1.0,
            ## masking
            mask_prob: float = 0.65,
            mask_length: int = 10,
            mask_selection: str = "static",
            mask_other: int = 0,
            no_mask_overlap: bool = False,
            mask_min_space: int = 1,
            require_same_masks: bool = True,  # if set as True, collate_fn should be clipping
            mask_dropout: float = 0.0,
            ## channel masking
            mask_channel_length: int = 10,
            mask_channel_prob: float = 0.0,
            mask_channel_before: bool = False,
            mask_channel_selection: str = "static",
            mask_channel_other: int = 0,
            no_mask_channel_overlap: bool = False,
            mask_channel_min_space: int = 1,
            ## positional embeddings
            conv_pos: int = 128,
            conv_pos_groups: int = 16,
            pos_conv_depth: int = 1,
            max_positions: int = 100000,
            # EMA module
            average_top_k_layers: int = 8,
            layer_norm_target_layer: bool = False,
            instance_norm_target_layer: bool = False,
            instance_norm_targets: bool = False,
            layer_norm_targets: bool = False,
            batch_norm_target_layer: bool = False,
            group_norm_target_layer: bool = False,
            ema_decay: float = 0.999,
            ema_end_decay: float = 0.9999,
            ema_anneal_end_step: int = 100000,
            ema_transformer_only: bool = True,
            ema_layers_only: bool = True,
            min_target_var: float = 0.1,
            min_pred_var: float = 0.01,
            # Loss
            loss_beta: float = 0.0,
            loss_scale: float = None,
            # FP16 optimization
            required_seq_len_multiple: int = 2,
        self,
        # for ConvFeatureExtractionModel
        input_size: int = None,
        extractor_mode: str = None,
        conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
        # for Transformer Encoder
        ## model architecture
        layer_type: str = "transformer",
        layer_norm_first: bool = False,
        encoder_layers: int = 12,
        encoder_embed_dim: int = 768,
        encoder_ffn_embed_dim: int = 3072,
        encoder_attention_heads: int = 12,
        activation_fn: str = "gelu",
        ## dropouts
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.0,
        encoder_layerdrop: float = 0.0,
        dropout_input: float = 0.0,
        dropout_features: float = 0.0,
        ## grad settings
        feature_grad_mult: float = 1.0,
        ## masking
        mask_prob: float = 0.65,
        mask_length: int = 10,
        mask_selection: str = "static",
        mask_other: int = 0,
        no_mask_overlap: bool = False,
        mask_min_space: int = 1,
        require_same_masks: bool = True,  # if set as True, collate_fn should be clipping
        mask_dropout: float = 0.0,
        ## channel masking
        mask_channel_length: int = 10,
        mask_channel_prob: float = 0.0,
        mask_channel_before: bool = False,
        mask_channel_selection: str = "static",
        mask_channel_other: int = 0,
        no_mask_channel_overlap: bool = False,
        mask_channel_min_space: int = 1,
        ## positional embeddings
        conv_pos: int = 128,
        conv_pos_groups: int = 16,
        pos_conv_depth: int = 1,
        max_positions: int = 100000,
        # EMA module
        average_top_k_layers: int = 8,
        layer_norm_target_layer: bool = False,
        instance_norm_target_layer: bool = False,
        instance_norm_targets: bool = False,
        layer_norm_targets: bool = False,
        batch_norm_target_layer: bool = False,
        group_norm_target_layer: bool = False,
        ema_decay: float = 0.999,
        ema_end_decay: float = 0.9999,
        ema_anneal_end_step: int = 100000,
        ema_transformer_only: bool = True,
        ema_layers_only: bool = True,
        min_target_var: float = 0.1,
        min_pred_var: float = 0.01,
        # Loss
        loss_beta: float = 0.0,
        loss_scale: float = None,
        # FP16 optimization
        required_seq_len_multiple: int = 2,
    ):
        super().__init__()
@@ -134,7 +133,9 @@
        self.mask_other = mask_other
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        self.require_same_masks = require_same_masks  # if set as True, collate_fn should be clipping
        self.require_same_masks = (
            require_same_masks  # if set as True, collate_fn should be clipping
        )
        self.mask_dropout = mask_dropout
        ## channel masking
        self.mask_channel_length = mask_channel_length
@@ -240,11 +241,11 @@
        self.num_updates = num_updates
    def apply_mask(
            self,
            x,
            padding_mask,
            mask_indices=None,
            mask_channel_indices=None,
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape
@@ -260,10 +261,7 @@
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0
@@ -301,9 +299,9 @@
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                        .to(x.device)
                        .unsqueeze(1)
                        .expand(-1, T, -1)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x[mask_channel_indices] = 0
@@ -327,15 +325,15 @@
        return input_lengths.to(torch.long)
    def forward(
            self,
            xs_pad,
            ilens=None,
            mask=False,
            features_only=True,
            layer=None,
            mask_indices=None,
            mask_channel_indices=None,
            padding_count=None,
        self,
        xs_pad,
        ilens=None,
        mask=False,
        features_only=True,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        # create padding_mask by ilens
        if ilens is not None:
@@ -447,16 +445,12 @@
            if self.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    F.batch_norm(tl.float(), running_mean=None, running_var=None, training=True)
                    for tl in target_layer_results
                ]
            if self.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]
                target_layer_results = [F.instance_norm(tl.float()) for tl in target_layer_results]
            if permuted:
                target_layer_results = [
@@ -465,14 +459,12 @@
            if self.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                    F.layer_norm(tl.float(), tl.shape[-2:]) for tl in target_layer_results
                ]
            if self.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                    F.layer_norm(tl.float(), tl.shape[-1:]) for tl in target_layer_results
                ]
            y = sum(target_layer_results) / len(target_layer_results)
@@ -522,9 +514,7 @@
                f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
            )
        if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var:
            logging.error(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )
            logging.error(f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting")
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )
@@ -540,20 +530,18 @@
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)
            zss = (y**2).sum(dim=0)
            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)
            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()
    def extract_features(
            self, xs_pad, ilens, mask=False, layer=None
    ):
    def extract_features(self, xs_pad, ilens, mask=False, layer=None):
        res = self.forward(
            xs_pad,
            ilens,
@@ -572,4 +560,4 @@
            )
    def output_size(self) -> int:
        return self.encoder_embed_dim
        return self.encoder_embed_dim