From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/models/data2vec/quant_noise.py | 20 +++++++-------------
1 files changed, 7 insertions(+), 13 deletions(-)
diff --git a/funasr/models/data2vec/quant_noise.py b/funasr/models/data2vec/quant_noise.py
index 11a82b6..50b2c15 100644
--- a/funasr/models/data2vec/quant_noise.py
+++ b/funasr/models/data2vec/quant_noise.py
@@ -40,7 +40,7 @@
# 2D matrix
if not is_conv:
assert (
- module.weight.size(1) % block_size == 0
+ module.weight.size(1) % block_size == 0
), "Input features must be a multiple of block sizes"
# 4D matrix
@@ -48,7 +48,7 @@
# 1x1 convolutions
if module.kernel_size == (1, 1):
assert (
- module.in_channels % block_size == 0
+ module.in_channels % block_size == 0
), "Input channels must be a multiple of block sizes"
# regular convolutions
else:
@@ -65,9 +65,7 @@
out_features = weight.size(0)
# split weight matrix into blocks and randomly drop selected blocks
- mask = torch.zeros(
- in_features // block_size * out_features, device=weight.device
- )
+ mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
@@ -86,20 +84,16 @@
mask.bernoulli_(p)
mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
else:
- mask = torch.zeros(
- weight.size(0), weight.size(1), device=weight.device
- )
+ mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
mask.bernoulli_(p)
mask = (
mask.unsqueeze(2)
- .unsqueeze(3)
- .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+ .unsqueeze(3)
+ .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
)
# scale weights and apply mask
- mask = mask.to(
- torch.bool
- ) # x.bool() is not currently supported in TorchScript
+ mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript
s = 1 / (1 - p)
mod.weight.data = s * weight.masked_fill(mask, 0)
--
Gitblit v1.9.1