From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/models/data2vec/data_utils.py | 32 +++++++++++++-------------------
1 files changed, 13 insertions(+), 19 deletions(-)
diff --git a/funasr/models/data2vec/data_utils.py b/funasr/models/data2vec/data_utils.py
index abd6d51..f486edf 100644
--- a/funasr/models/data2vec/data_utils.py
+++ b/funasr/models/data2vec/data_utils.py
@@ -11,17 +11,17 @@
def compute_mask_indices(
- shape: Tuple[int, int],
- padding_mask: Optional[torch.Tensor],
- mask_prob: float,
- mask_length: int,
- mask_type: str = "static",
- mask_other: float = 0.0,
- min_masks: int = 0,
- no_overlap: bool = False,
- min_space: int = 0,
- require_same_masks: bool = True,
- mask_dropout: float = 0.0,
+ shape: Tuple[int, int],
+ padding_mask: Optional[torch.Tensor],
+ mask_prob: float,
+ mask_length: int,
+ mask_type: str = "static",
+ mask_other: float = 0.0,
+ min_masks: int = 0,
+ no_overlap: bool = False,
+ min_space: int = 0,
+ require_same_masks: bool = True,
+ mask_dropout: float = 0.0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape
@@ -123,11 +123,7 @@
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray(
- [
- mask_idc[j] + offset
- for j in range(len(mask_idc))
- for offset in range(lengths[j])
- ]
+ [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]
)
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
@@ -138,9 +134,7 @@
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
if mask_dropout > 0:
num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
- mask_idc = np.random.choice(
- mask_idc, len(mask_idc) - num_holes, replace=False
- )
+ mask_idc = np.random.choice(mask_idc, len(mask_idc) - num_holes, replace=False)
mask[i, mask_idc] = True
--
Gitblit v1.9.1