zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/data2vec/data2vec.py
@@ -16,6 +16,7 @@
# from funasr.models.base_model import FunASRModel
# from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.frontends.abs_frontend import AbsFrontend
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.train_utils.device_funcs import force_gatherable
@@ -61,10 +62,7 @@
            speech_lengths: (Batch, )
        """
        # Check that batch_size is unified
        assert (
                speech.shape[0]
                == speech_lengths.shape[0]
        ), (speech.shape, speech_lengths.shape)
        assert speech.shape[0] == speech_lengths.shape[0], (speech.shape, speech_lengths.shape)
        self.encoder.set_num_updates(self.num_updates)
@@ -91,9 +89,7 @@
        return loss, stats, weight
    def collect_feats(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}