From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/data2vec/data_utils.py |   32 +++++++++++++-------------------
 1 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/funasr/models/data2vec/data_utils.py b/funasr/models/data2vec/data_utils.py
index abd6d51..f486edf 100644
--- a/funasr/models/data2vec/data_utils.py
+++ b/funasr/models/data2vec/data_utils.py
@@ -11,17 +11,17 @@
 
 
 def compute_mask_indices(
-        shape: Tuple[int, int],
-        padding_mask: Optional[torch.Tensor],
-        mask_prob: float,
-        mask_length: int,
-        mask_type: str = "static",
-        mask_other: float = 0.0,
-        min_masks: int = 0,
-        no_overlap: bool = False,
-        min_space: int = 0,
-        require_same_masks: bool = True,
-        mask_dropout: float = 0.0,
+    shape: Tuple[int, int],
+    padding_mask: Optional[torch.Tensor],
+    mask_prob: float,
+    mask_length: int,
+    mask_type: str = "static",
+    mask_other: float = 0.0,
+    min_masks: int = 0,
+    no_overlap: bool = False,
+    min_space: int = 0,
+    require_same_masks: bool = True,
+    mask_dropout: float = 0.0,
 ) -> np.ndarray:
     """
     Computes random mask spans for a given shape
@@ -123,11 +123,7 @@
             mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
 
             mask_idc = np.asarray(
-                [
-                    mask_idc[j] + offset
-                    for j in range(len(mask_idc))
-                    for offset in range(lengths[j])
-                ]
+                [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]
             )
 
         mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
@@ -138,9 +134,7 @@
             mask_idc = np.random.choice(mask_idc, min_len, replace=False)
         if mask_dropout > 0:
             num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
-            mask_idc = np.random.choice(
-                mask_idc, len(mask_idc) - num_holes, replace=False
-            )
+            mask_idc = np.random.choice(mask_idc, len(mask_idc) - num_holes, replace=False)
 
         mask[i, mask_idc] = True
 

--
Gitblit v1.9.1