From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/data2vec/data2vec_encoder.py |  204 ++++++++++++++++++++++++---------------------------
 1 files changed, 96 insertions(+), 108 deletions(-)

diff --git a/funasr/models/data2vec/data2vec_encoder.py b/funasr/models/data2vec/data2vec_encoder.py
index 1bcb639..f591dd6 100644
--- a/funasr/models/data2vec/data2vec_encoder.py
+++ b/funasr/models/data2vec/data2vec_encoder.py
@@ -11,7 +11,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.data2vec.data_utils import compute_mask_indices
 from funasr.models.data2vec.ema_module import EMAModule
 from funasr.models.data2vec.grad_multiply import GradMultiply
@@ -28,73 +27,73 @@
     return end - r * pct_remaining
 
 
-class Data2VecEncoder(AbsEncoder):
+class Data2VecEncoder(nn.Module):
     def __init__(
-            self,
-            # for ConvFeatureExtractionModel
-            input_size: int = None,
-            extractor_mode: str = None,
-            conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
-            # for Transformer Encoder
-            ## model architecture
-            layer_type: str = "transformer",
-            layer_norm_first: bool = False,
-            encoder_layers: int = 12,
-            encoder_embed_dim: int = 768,
-            encoder_ffn_embed_dim: int = 3072,
-            encoder_attention_heads: int = 12,
-            activation_fn: str = "gelu",
-            ## dropouts
-            dropout: float = 0.1,
-            attention_dropout: float = 0.1,
-            activation_dropout: float = 0.0,
-            encoder_layerdrop: float = 0.0,
-            dropout_input: float = 0.0,
-            dropout_features: float = 0.0,
-            ## grad settings
-            feature_grad_mult: float = 1.0,
-            ## masking
-            mask_prob: float = 0.65,
-            mask_length: int = 10,
-            mask_selection: str = "static",
-            mask_other: int = 0,
-            no_mask_overlap: bool = False,
-            mask_min_space: int = 1,
-            require_same_masks: bool = True,  # if set as True, collate_fn should be clipping
-            mask_dropout: float = 0.0,
-            ## channel masking
-            mask_channel_length: int = 10,
-            mask_channel_prob: float = 0.0,
-            mask_channel_before: bool = False,
-            mask_channel_selection: str = "static",
-            mask_channel_other: int = 0,
-            no_mask_channel_overlap: bool = False,
-            mask_channel_min_space: int = 1,
-            ## positional embeddings
-            conv_pos: int = 128,
-            conv_pos_groups: int = 16,
-            pos_conv_depth: int = 1,
-            max_positions: int = 100000,
-            # EMA module
-            average_top_k_layers: int = 8,
-            layer_norm_target_layer: bool = False,
-            instance_norm_target_layer: bool = False,
-            instance_norm_targets: bool = False,
-            layer_norm_targets: bool = False,
-            batch_norm_target_layer: bool = False,
-            group_norm_target_layer: bool = False,
-            ema_decay: float = 0.999,
-            ema_end_decay: float = 0.9999,
-            ema_anneal_end_step: int = 100000,
-            ema_transformer_only: bool = True,
-            ema_layers_only: bool = True,
-            min_target_var: float = 0.1,
-            min_pred_var: float = 0.01,
-            # Loss
-            loss_beta: float = 0.0,
-            loss_scale: float = None,
-            # FP16 optimization
-            required_seq_len_multiple: int = 2,
+        self,
+        # for ConvFeatureExtractionModel
+        input_size: int = None,
+        extractor_mode: str = None,
+        conv_feature_layers: str = "[(512,2,2)] + [(512,2,2)]",
+        # for Transformer Encoder
+        ## model architecture
+        layer_type: str = "transformer",
+        layer_norm_first: bool = False,
+        encoder_layers: int = 12,
+        encoder_embed_dim: int = 768,
+        encoder_ffn_embed_dim: int = 3072,
+        encoder_attention_heads: int = 12,
+        activation_fn: str = "gelu",
+        ## dropouts
+        dropout: float = 0.1,
+        attention_dropout: float = 0.1,
+        activation_dropout: float = 0.0,
+        encoder_layerdrop: float = 0.0,
+        dropout_input: float = 0.0,
+        dropout_features: float = 0.0,
+        ## grad settings
+        feature_grad_mult: float = 1.0,
+        ## masking
+        mask_prob: float = 0.65,
+        mask_length: int = 10,
+        mask_selection: str = "static",
+        mask_other: int = 0,
+        no_mask_overlap: bool = False,
+        mask_min_space: int = 1,
+        require_same_masks: bool = True,  # if set as True, collate_fn should be clipping
+        mask_dropout: float = 0.0,
+        ## channel masking
+        mask_channel_length: int = 10,
+        mask_channel_prob: float = 0.0,
+        mask_channel_before: bool = False,
+        mask_channel_selection: str = "static",
+        mask_channel_other: int = 0,
+        no_mask_channel_overlap: bool = False,
+        mask_channel_min_space: int = 1,
+        ## positional embeddings
+        conv_pos: int = 128,
+        conv_pos_groups: int = 16,
+        pos_conv_depth: int = 1,
+        max_positions: int = 100000,
+        # EMA module
+        average_top_k_layers: int = 8,
+        layer_norm_target_layer: bool = False,
+        instance_norm_target_layer: bool = False,
+        instance_norm_targets: bool = False,
+        layer_norm_targets: bool = False,
+        batch_norm_target_layer: bool = False,
+        group_norm_target_layer: bool = False,
+        ema_decay: float = 0.999,
+        ema_end_decay: float = 0.9999,
+        ema_anneal_end_step: int = 100000,
+        ema_transformer_only: bool = True,
+        ema_layers_only: bool = True,
+        min_target_var: float = 0.1,
+        min_pred_var: float = 0.01,
+        # Loss
+        loss_beta: float = 0.0,
+        loss_scale: float = None,
+        # FP16 optimization
+        required_seq_len_multiple: int = 2,
     ):
         super().__init__()
 
@@ -134,7 +133,9 @@
         self.mask_other = mask_other
         self.no_mask_overlap = no_mask_overlap
         self.mask_min_space = mask_min_space
-        self.require_same_masks = require_same_masks  # if set as True, collate_fn should be clipping
+        self.require_same_masks = (
+            require_same_masks  # if set as True, collate_fn should be clipping
+        )
         self.mask_dropout = mask_dropout
         ## channel masking
         self.mask_channel_length = mask_channel_length
@@ -240,11 +241,11 @@
         self.num_updates = num_updates
 
     def apply_mask(
-            self,
-            x,
-            padding_mask,
-            mask_indices=None,
-            mask_channel_indices=None,
+        self,
+        x,
+        padding_mask,
+        mask_indices=None,
+        mask_channel_indices=None,
     ):
         B, T, C = x.shape
 
@@ -260,10 +261,7 @@
                 min_space=self.mask_channel_min_space,
             )
             mask_channel_indices = (
-                torch.from_numpy(mask_channel_indices)
-                    .to(x.device)
-                    .unsqueeze(1)
-                    .expand(-1, T, -1)
+                torch.from_numpy(mask_channel_indices).to(x.device).unsqueeze(1).expand(-1, T, -1)
             )
             x[mask_channel_indices] = 0
 
@@ -301,9 +299,9 @@
                 )
                 mask_channel_indices = (
                     torch.from_numpy(mask_channel_indices)
-                        .to(x.device)
-                        .unsqueeze(1)
-                        .expand(-1, T, -1)
+                    .to(x.device)
+                    .unsqueeze(1)
+                    .expand(-1, T, -1)
                 )
             x[mask_channel_indices] = 0
 
@@ -327,15 +325,15 @@
         return input_lengths.to(torch.long)
 
     def forward(
-            self,
-            xs_pad,
-            ilens=None,
-            mask=False,
-            features_only=True,
-            layer=None,
-            mask_indices=None,
-            mask_channel_indices=None,
-            padding_count=None,
+        self,
+        xs_pad,
+        ilens=None,
+        mask=False,
+        features_only=True,
+        layer=None,
+        mask_indices=None,
+        mask_channel_indices=None,
+        padding_count=None,
     ):
         # create padding_mask by ilens
         if ilens is not None:
@@ -447,16 +445,12 @@
 
             if self.batch_norm_target_layer:
                 target_layer_results = [
-                    F.batch_norm(
-                        tl.float(), running_mean=None, running_var=None, training=True
-                    )
+                    F.batch_norm(tl.float(), running_mean=None, running_var=None, training=True)
                     for tl in target_layer_results
                 ]
 
             if self.instance_norm_target_layer:
-                target_layer_results = [
-                    F.instance_norm(tl.float()) for tl in target_layer_results
-                ]
+                target_layer_results = [F.instance_norm(tl.float()) for tl in target_layer_results]
 
             if permuted:
                 target_layer_results = [
@@ -465,14 +459,12 @@
 
             if self.group_norm_target_layer:
                 target_layer_results = [
-                    F.layer_norm(tl.float(), tl.shape[-2:])
-                    for tl in target_layer_results
+                    F.layer_norm(tl.float(), tl.shape[-2:]) for tl in target_layer_results
                 ]
 
             if self.layer_norm_target_layer:
                 target_layer_results = [
-                    F.layer_norm(tl.float(), tl.shape[-1:])
-                    for tl in target_layer_results
+                    F.layer_norm(tl.float(), tl.shape[-1:]) for tl in target_layer_results
                 ]
 
             y = sum(target_layer_results) / len(target_layer_results)
@@ -522,9 +514,7 @@
                 f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting"
             )
         if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var:
-            logging.error(
-                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
-            )
+            logging.error(f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting")
             raise Exception(
                 f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
             )
@@ -540,20 +530,18 @@
         if dist.is_initialized():
             zc = torch.tensor(y.size(0)).cuda()
             zs = y.sum(dim=0)
-            zss = (y ** 2).sum(dim=0)
+            zss = (y**2).sum(dim=0)
 
             dist.all_reduce(zc)
             dist.all_reduce(zs)
             dist.all_reduce(zss)
 
-            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
+            var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1))
             return torch.sqrt(var + 1e-6).mean()
         else:
             return torch.sqrt(y.var(dim=0) + 1e-6).mean()
 
-    def extract_features(
-            self, xs_pad, ilens, mask=False, layer=None
-    ):
+    def extract_features(self, xs_pad, ilens, mask=False, layer=None):
         res = self.forward(
             xs_pad,
             ilens,
@@ -572,4 +560,4 @@
             )
 
     def output_size(self) -> int:
-        return self.encoder_embed_dim
\ No newline at end of file
+        return self.encoder_embed_dim

--
Gitblit v1.9.1