kongdeqiang
6 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/data2vec/data2vec.py
@@ -10,13 +10,15 @@
from typing import Tuple
import torch
import torch.nn as nn
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.base_model import FunASRModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
# from funasr.layers.abs_normalize import AbsNormalize
# from funasr.models.base_model import FunASRModel
# from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.frontends.abs_frontend import AbsFrontend
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.train_utils.device_funcs import force_gatherable
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -28,16 +30,16 @@
        yield
class Data2VecPretrainModel(FunASRModel):
class Data2VecPretrainModel(nn.Module):
    """Data2Vec Pretrain model"""
    def __init__(
            self,
            frontend: Optional[AbsFrontend],
            specaug: Optional[AbsSpecAug],
            normalize: Optional[AbsNormalize],
            encoder: AbsEncoder,
            preencoder: Optional[AbsPreEncoder] = None,
        self,
        frontend=None,
        specaug=None,
        normalize=None,
        encoder=None,
        preencoder=None,
    ):
        super().__init__()
@@ -50,9 +52,9 @@
        self.num_updates = 0
    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Calc loss
        Args:
@@ -60,10 +62,7 @@
            speech_lengths: (Batch, )
        """
        # Check that batch_size is unified
        assert (
                speech.shape[0]
                == speech_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape)
        assert speech.shape[0] == speech_lengths.shape[0], (speech.shape, speech_lengths.shape)
        self.encoder.set_num_updates(self.num_updates)
@@ -90,17 +89,15 @@
        return loss, stats, weight
    def collect_feats(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}
    def encode(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        """Frontend + Encoder.
        Args:
@@ -131,7 +128,7 @@
        return encoder_out
    def _extract_feats(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        assert speech_lengths.dim() == 1, speech_lengths.shape