From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/models/data2vec/data2vec.py | 36 ++++++++++++++++--------------------
 1 file changed, 16 insertions(+), 20 deletions(-)
diff --git a/funasr/models/data2vec/data2vec.py b/funasr/models/data2vec/data2vec.py
index c77cedf..ed58d1c 100644
--- a/funasr/models/data2vec/data2vec.py
+++ b/funasr/models/data2vec/data2vec.py
@@ -16,6 +16,7 @@
# from funasr.models.base_model import FunASRModel
# from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.frontends.abs_frontend import AbsFrontend
+
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.train_utils.device_funcs import force_gatherable
@@ -33,12 +34,12 @@
"""Data2Vec Pretrain model"""
def __init__(
- self,
- frontend = None,
- specaug = None,
- normalize = None,
- encoder = None,
- preencoder = None,
+ self,
+ frontend=None,
+ specaug=None,
+ normalize=None,
+ encoder=None,
+ preencoder=None,
):
super().__init__()
@@ -51,9 +52,9 @@
self.num_updates = 0
def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Calc loss
Args:
@@ -61,10 +62,7 @@
speech_lengths: (Batch, )
"""
# Check that batch_size is unified
- assert (
- speech.shape[0]
- == speech_lengths.shape[0]
- ), (speech.shape, speech_lengths.shape)
+ assert speech.shape[0] == speech_lengths.shape[0], (speech.shape, speech_lengths.shape)
self.encoder.set_num_updates(self.num_updates)
@@ -91,17 +89,15 @@
return loss, stats, weight
def collect_feats(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Dict[str, torch.Tensor]:
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
return {"feats": feats, "feats_lengths": feats_lengths}
def encode(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
):
"""Frontend + Encoder.
Args:
@@ -132,7 +128,7 @@
return encoder_out
def _extract_feats(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
assert speech_lengths.dim() == 1, speech_lengths.shape
--
Gitblit v1.9.1