From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/data2vec/wav2vec2.py |  116 +++++++++++++++++++++++++++------------------------------
 1 file changed, 55 insertions(+), 61 deletions(-)

diff --git a/funasr/models/data2vec/wav2vec2.py b/funasr/models/data2vec/wav2vec2.py
index 6631c57..a3e837a 100644
--- a/funasr/models/data2vec/wav2vec2.py
+++ b/funasr/models/data2vec/wav2vec2.py
@@ -18,25 +18,25 @@
 
 class ConvFeatureExtractionModel(nn.Module):
     def __init__(
-            self,
-            conv_layers: List[Tuple[int, int, int]],
-            dropout: float = 0.0,
-            mode: str = "default",
-            conv_bias: bool = False,
-            in_d: int = 1
+        self,
+        conv_layers: List[Tuple[int, int, int]],
+        dropout: float = 0.0,
+        mode: str = "default",
+        conv_bias: bool = False,
+        in_d: int = 1,
     ):
         super().__init__()
 
         assert mode in {"default", "layer_norm"}
 
         def block(
-                n_in,
-                n_out,
-                k,
-                stride,
-                is_layer_norm=False,
-                is_group_norm=False,
-                conv_bias=False,
+            n_in,
+            n_out,
+            k,
+            stride,
+            is_layer_norm=False,
+            is_group_norm=False,
+            conv_bias=False,
         ):
             def make_conv():
                 conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
@@ -44,8 +44,8 @@
                 return conv
 
             assert (
-                           is_layer_norm and is_group_norm
-                   ) == False, "layer norm and group norm are exclusive"
+                is_layer_norm and is_group_norm
+            ) == False, "layer norm and group norm are exclusive"
 
             if is_layer_norm:
                 return nn.Sequential(
@@ -134,25 +134,25 @@
         return layer
 
     def __init__(
-            self,
-            # position
-            dropout,
-            encoder_embed_dim,
-            required_seq_len_multiple,
-            pos_conv_depth,
-            conv_pos,
-            conv_pos_groups,
-            # transformer layers
-            layer_type,
-            encoder_layers,
-            encoder_ffn_embed_dim,
-            encoder_attention_heads,
-            attention_dropout,
-            activation_dropout,
-            activation_fn,
-            layer_norm_first,
-            encoder_layerdrop,
-            max_positions,
+        self,
+        # position
+        dropout,
+        encoder_embed_dim,
+        required_seq_len_multiple,
+        pos_conv_depth,
+        conv_pos,
+        conv_pos_groups,
+        # transformer layers
+        layer_type,
+        encoder_layers,
+        encoder_ffn_embed_dim,
+        encoder_attention_heads,
+        attention_dropout,
+        activation_dropout,
+        activation_fn,
+        layer_norm_first,
+        encoder_layerdrop,
+        max_positions,
     ):
         super().__init__()
 
@@ -185,9 +185,7 @@
                     ]
                 )
 
-            self.pos_conv = make_conv_block(
-                self.embedding_dim, k, conv_pos_groups, num_layers
-            )
+            self.pos_conv = make_conv_block(self.embedding_dim, k, conv_pos_groups, num_layers)
 
         else:
             self.pos_conv = make_conv_pos(
@@ -206,9 +204,7 @@
         self.layer_norm_first = layer_norm_first
         self.layerdrop = encoder_layerdrop
         self.max_positions = max_positions
-        self.layers = nn.ModuleList(
-            [self.build_encoder_layer() for _ in range(encoder_layers)]
-        )
+        self.layers = nn.ModuleList([self.build_encoder_layer() for _ in range(encoder_layers)])
         self.layer_norm = torch.nn.LayerNorm(self.embedding_dim)
 
         self.apply(utils.init_bert_params)
@@ -222,11 +218,11 @@
         return x, layer_results
 
     def extract_features(
-            self,
-            x,
-            padding_mask=None,
-            tgt_layer=None,
-            min_layer=0,
+        self,
+        x,
+        padding_mask=None,
+        tgt_layer=None,
+        min_layer=0,
     ):
 
         if padding_mask is not None:
@@ -240,9 +236,7 @@
             x = self.layer_norm(x)
 
         # pad to the sequence length dimension
-        x, pad_length = utils.pad_to_multiple(
-            x, self.required_seq_len_multiple, dim=-2, value=0
-        )
+        x, pad_length = utils.pad_to_multiple(x, self.required_seq_len_multiple, dim=-2, value=0)
         if pad_length > 0 and padding_mask is None:
             padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
             padding_mask[:, -pad_length:] = True
@@ -304,15 +298,15 @@
     """
 
     def __init__(
-            self,
-            embedding_dim: int = 768,
-            ffn_embedding_dim: int = 3072,
-            num_attention_heads: int = 8,
-            dropout: float = 0.1,
-            attention_dropout: float = 0.1,
-            activation_dropout: float = 0.1,
-            activation_fn: str = "relu",
-            layer_norm_first: bool = False,
+        self,
+        embedding_dim: int = 768,
+        ffn_embedding_dim: int = 3072,
+        num_attention_heads: int = 8,
+        dropout: float = 0.1,
+        attention_dropout: float = 0.1,
+        activation_dropout: float = 0.1,
+        activation_fn: str = "relu",
+        layer_norm_first: bool = False,
     ) -> None:
 
         super().__init__()
@@ -345,10 +339,10 @@
         self.final_layer_norm = torch.nn.LayerNorm(self.embedding_dim)
 
     def forward(
-            self,
-            x: torch.Tensor,  # (T, B, C)
-            self_attn_mask: torch.Tensor = None,
-            self_attn_padding_mask: torch.Tensor = None,
+        self,
+        x: torch.Tensor,  # (T, B, C)
+        self_attn_mask: torch.Tensor = None,
+        self_attn_padding_mask: torch.Tensor = None,
     ):
         """
         LayerNorm is applied either before or after the self-attention/ffn

--
Gitblit v1.9.1