From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] data2vec: drop FunASR abstract-base-class dependencies, inherit from nn.Module
---
funasr/models/data2vec/data2vec.py | 51 ++++++++++++++++++++++++---------------------------
1 file changed, 24 insertions(+), 27 deletions(-)
diff --git a/funasr/models/data2vec/data2vec.py b/funasr/models/data2vec/data2vec.py
index 19c5612..ed58d1c 100644
--- a/funasr/models/data2vec/data2vec.py
+++ b/funasr/models/data2vec/data2vec.py
@@ -10,13 +10,15 @@
from typing import Tuple
import torch
+import torch.nn as nn
-from funasr.layers.abs_normalize import AbsNormalize
-from funasr.models.base_model import FunASRModel
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.specaug.abs_specaug import AbsSpecAug
+# from funasr.layers.abs_normalize import AbsNormalize
+# from funasr.models.base_model import FunASRModel
+# from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.frontends.abs_frontend import AbsFrontend
+
+# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.train_utils.device_funcs import force_gatherable
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -28,16 +30,16 @@
yield
-class Data2VecPretrainModel(FunASRModel):
+class Data2VecPretrainModel(nn.Module):
"""Data2Vec Pretrain model"""
def __init__(
- self,
- frontend: Optional[AbsFrontend],
- specaug: Optional[AbsSpecAug],
- normalize: Optional[AbsNormalize],
- encoder: AbsEncoder,
- preencoder: Optional[AbsPreEncoder] = None,
+ self,
+ frontend=None,
+ specaug=None,
+ normalize=None,
+ encoder=None,
+ preencoder=None,
):
super().__init__()
@@ -50,9 +52,9 @@
self.num_updates = 0
def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Calc loss
Args:
@@ -60,10 +62,7 @@
speech_lengths: (Batch, )
"""
# Check that batch_size is unified
- assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- ), (speech.shape, speech_lengths.shape)
+ assert speech.shape[0] == speech_lengths.shape[0], (speech.shape, speech_lengths.shape)
self.encoder.set_num_updates(self.num_updates)
@@ -90,17 +89,15 @@
return loss, stats, weight
def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
):
"""Frontend + Encoder.
Args:
@@ -131,7 +128,7 @@
return encoder_out
def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
--
Gitblit v1.9.1