From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/data2vec/data2vec.py |   36 ++++++++++++++++--------------------
 1 files changed, 16 insertions(+), 20 deletions(-)

diff --git a/funasr/models/data2vec/data2vec.py b/funasr/models/data2vec/data2vec.py
index c77cedf..ed58d1c 100644
--- a/funasr/models/data2vec/data2vec.py
+++ b/funasr/models/data2vec/data2vec.py
@@ -16,6 +16,7 @@
 # from funasr.models.base_model import FunASRModel
 # from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.frontends.abs_frontend import AbsFrontend
+
 # from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
 # from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.train_utils.device_funcs import force_gatherable
@@ -33,12 +34,12 @@
     """Data2Vec Pretrain model"""
 
     def __init__(
-            self,
-            frontend = None,
-            specaug = None,
-            normalize = None,
-            encoder = None,
-            preencoder = None,
+        self,
+        frontend=None,
+        specaug=None,
+        normalize=None,
+        encoder=None,
+        preencoder=None,
     ):
 
         super().__init__()
@@ -51,9 +52,9 @@
         self.num_updates = 0
 
     def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Calc loss
         Args:
@@ -61,10 +62,7 @@
             speech_lengths: (Batch, )
         """
         # Check that batch_size is unified
-        assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape)
+        assert speech.shape[0] == speech_lengths.shape[0], (speech.shape, speech_lengths.shape)
 
         self.encoder.set_num_updates(self.num_updates)
 
@@ -91,17 +89,15 @@
         return loss, stats, weight
 
     def collect_feats(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Dict[str, torch.Tensor]:
         feats, feats_lengths = self._extract_feats(speech, speech_lengths)
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
     ):
         """Frontend + Encoder.
         Args:
@@ -132,7 +128,7 @@
         return encoder_out
 
     def _extract_feats(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert speech_lengths.dim() == 1, speech_lengths.shape
 

--
Gitblit v1.9.1