From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/data2vec/quant_noise.py |   20 +++++++-------------
 1 files changed, 7 insertions(+), 13 deletions(-)

diff --git a/funasr/models/data2vec/quant_noise.py b/funasr/models/data2vec/quant_noise.py
index 11a82b6..50b2c15 100644
--- a/funasr/models/data2vec/quant_noise.py
+++ b/funasr/models/data2vec/quant_noise.py
@@ -40,7 +40,7 @@
     # 2D matrix
     if not is_conv:
         assert (
-                module.weight.size(1) % block_size == 0
+            module.weight.size(1) % block_size == 0
         ), "Input features must be a multiple of block sizes"
 
     # 4D matrix
@@ -48,7 +48,7 @@
         # 1x1 convolutions
         if module.kernel_size == (1, 1):
             assert (
-                    module.in_channels % block_size == 0
+                module.in_channels % block_size == 0
             ), "Input channels must be a multiple of block sizes"
         # regular convolutions
         else:
@@ -65,9 +65,7 @@
                 out_features = weight.size(0)
 
                 # split weight matrix into blocks and randomly drop selected blocks
-                mask = torch.zeros(
-                    in_features // block_size * out_features, device=weight.device
-                )
+                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
                 mask.bernoulli_(p)
                 mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
 
@@ -86,20 +84,16 @@
                     mask.bernoulli_(p)
                     mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                 else:
-                    mask = torch.zeros(
-                        weight.size(0), weight.size(1), device=weight.device
-                    )
+                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
                     mask.bernoulli_(p)
                     mask = (
                         mask.unsqueeze(2)
-                            .unsqueeze(3)
-                            .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
+                        .unsqueeze(3)
+                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                     )
 
             # scale weights and apply mask
-            mask = mask.to(
-                torch.bool
-            )  # x.bool() is not currently supported in TorchScript
+            mask = mask.to(torch.bool)  # x.bool() is not currently supported in TorchScript
             s = 1 / (1 - p)
             mod.weight.data = s * weight.masked_fill(mask, 0)
 

--
Gitblit v1.9.1