From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/data2vec/data2vec.py |   51 ++++++++++++++++++++++++---------------------------
 1 files changed, 24 insertions(+), 27 deletions(-)

diff --git a/funasr/models/data2vec/data2vec.py b/funasr/models/data2vec/data2vec.py
index 19c5612..ed58d1c 100644
--- a/funasr/models/data2vec/data2vec.py
+++ b/funasr/models/data2vec/data2vec.py
@@ -10,13 +10,15 @@
 from typing import Tuple
 
 import torch
+import torch.nn as nn
 
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.models.base_model import FunASRModel
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
+# from funasr.layers.abs_normalize import AbsNormalize
+# from funasr.models.base_model import FunASRModel
+# from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.frontends.abs_frontend import AbsFrontend
+
+# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+# from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.train_utils.device_funcs import force_gatherable
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -28,16 +30,16 @@
         yield
 
 
-class Data2VecPretrainModel(FunASRModel):
+class Data2VecPretrainModel(nn.Module):
     """Data2Vec Pretrain model"""
 
     def __init__(
-            self,
-            frontend: Optional[AbsFrontend],
-            specaug: Optional[AbsSpecAug],
-            normalize: Optional[AbsNormalize],
-            encoder: AbsEncoder,
-            preencoder: Optional[AbsPreEncoder] = None,
+        self,
+        frontend=None,
+        specaug=None,
+        normalize=None,
+        encoder=None,
+        preencoder=None,
     ):
 
         super().__init__()
@@ -50,9 +52,9 @@
         self.num_updates = 0
 
     def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Calc loss
         Args:
@@ -60,10 +62,7 @@
             speech_lengths: (Batch, )
         """
         # Check that batch_size is unified
-        assert (
-                speech.shape[0]
-                == speech_lengths.shape[0]
-        ), (speech.shape, speech_lengths.shape)
+        assert speech.shape[0] == speech_lengths.shape[0], (speech.shape, speech_lengths.shape)
 
         self.encoder.set_num_updates(self.num_updates)
 
@@ -90,17 +89,15 @@
         return loss, stats, weight
 
     def collect_feats(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Dict[str, torch.Tensor]:
         feats, feats_lengths = self._extract_feats(speech, speech_lengths)
         return {"feats": feats, "feats_lengths": feats_lengths}
 
     def encode(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
     ):
         """Frontend + Encoder.
         Args:
@@ -131,7 +128,7 @@
         return encoder_out
 
     def _extract_feats(
-            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         assert speech_lengths.dim() == 1, speech_lengths.shape
 

--
Gitblit v1.9.1