| | |
| | | import torch.nn.functional as F |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.data2vec.data_utils import compute_mask_indices |
| | | from funasr.modules.data2vec.ema_module import EMAModule |
| | | from funasr.modules.data2vec.grad_multiply import GradMultiply |
| | |
| | | return end - r * pct_remaining |
| | | |
| | | |
| | | class Data2VecEncoder(torch.nn.Module): |
| | | class Data2VecEncoder(AbsEncoder): |
| | | def __init__( |
| | | self, |
| | | # for ConvFeatureExtractionModel |
| | |
| | | ) |
| | | |
| | | def output_size(self) -> int: |
| | | return self.encoder_embed_dim |
| | | return self.encoder_embed_dim |