游雁
2023-12-19 0e622e694e6cb4459955f1e5942a7c53349ce640
funasr/models/data2vec/data2vec.py
@@ -10,13 +10,14 @@
from typing import Tuple
import torch
import torch.nn as nn
from funasr.layers.abs_normalize import AbsNormalize
from funasr.models.base_model import FunASRModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.specaug.abs_specaug import AbsSpecAug
# from funasr.layers.abs_normalize import AbsNormalize
# from funasr.models.base_model import FunASRModel
# from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.frontends.abs_frontend import AbsFrontend
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.train_utils.device_funcs import force_gatherable
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
@@ -28,16 +29,16 @@
        yield
class Data2VecPretrainModel(FunASRModel):
class Data2VecPretrainModel(nn.Module):
    """Data2Vec Pretrain model"""
    def __init__(
            self,
            frontend: Optional[AbsFrontend],
            specaug: Optional[AbsSpecAug],
            normalize: Optional[AbsNormalize],
            encoder: AbsEncoder,
            preencoder: Optional[AbsPreEncoder] = None,
            frontend = None,
            specaug = None,
            normalize = None,
            encoder = None,
            preencoder = None,
    ):
        super().__init__()