     LabelSmoothingLoss,  # noqa: H301
 )
 from funasr.models.ctc import CTC
+from funasr.models.frontend.abs_frontend import AbsFrontend
 from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.decoder.abs_decoder import AbsDecoder
 from funasr.models.base_model import FunASRModel

         self,
         vocab_size: int,
         token_list: Union[Tuple[str, ...], List[str]],
-        frontend: Optional[torch.nn.Module],
+        frontend: Optional[AbsFrontend],
         specaug: Optional[torch.nn.Module],
         normalize: Optional[torch.nn.Module],
         encoder: AbsEncoder,
| | |
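Narrowing `frontend` from `torch.nn.Module` to `AbsFrontend` lets the model depend on the frontend's declared interface rather than duck typing; in particular, `output_size()` (defined in the new abstract class further below) is guaranteed to exist. A minimal sketch of what that buys; the helper `infer_encoder_input_size` is illustrative, not code from this diff:

```python
from typing import Optional

from funasr.models.frontend.abs_frontend import AbsFrontend


def infer_encoder_input_size(frontend: Optional[AbsFrontend], default: int = 80) -> int:
    """Hypothetical helper: pick the feature dimension the encoder should expect.

    With the parameter typed as AbsFrontend, output_size() is guaranteed by the
    abstract interface instead of hoping an arbitrary torch.nn.Module has it.
    """
    if frontend is None:
        return default  # e.g. precomputed features that bypass the frontend
    return frontend.output_size()
```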
     RelPositionMultiHeadedAttention,  # noqa: H301
     LegacyRelPositionMultiHeadedAttention,  # noqa: H301
 )
+from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.embedding import (
     PositionalEncoding,  # noqa: H301
     ScaledPositionalEncoding,  # noqa: H301
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
 from funasr.modules.subsampling import Conv2dSubsamplingPad

-class ConvolutionModule(nn.Module):
+class ConvolutionModule(AbsEncoder):
     """ConvolutionModule in Conformer model.

     Args:
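For orientation, the Conformer convolution block operates on a (batch, time, channels) tensor and preserves its shape. The usage sketch below assumes the standard Conformer constructor arguments (`channels`, `kernel_size`) and an import path, since the hunk truncates before the Args section:

```python
import torch

# Assumed import path and signature; the diff does not show them.
from funasr.models.encoder.conformer_encoder import ConvolutionModule

conv = ConvolutionModule(channels=256, kernel_size=15)
x = torch.randn(4, 100, 256)  # (batch, time, channels)
y = conv(x)                   # shape-preserving: (4, 100, 256)
assert y.shape == x.shape
```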
 import logging

 from funasr.models.ctc import CTC
+from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.attention import MultiHeadedAttention
 from funasr.modules.embedding import PositionalEncoding
 from funasr.modules.layer_norm import LayerNorm
 from funasr.modules.subsampling import check_short_utt


-class EncoderLayer(nn.Module):
+class EncoderLayer(AbsEncoder):
     """Encoder layer module.

     Args:
New file: funasr/models/frontend/abs_frontend.py

+from abc import ABC
+from abc import abstractmethod
+from typing import Tuple
+
+import torch
+
+
+class AbsFrontend(torch.nn.Module, ABC):
+    @abstractmethod
+    def output_size(self) -> int:
+        raise NotImplementedError
+
+    @abstractmethod
+    def forward(
+        self, input: torch.Tensor, input_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError
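A concrete frontend only has to implement `output_size()` and the `(input, input_lengths)` forward contract. A minimal sketch, assuming a hypothetical `LinearFrontend` that is not part of this diff:

```python
from typing import Tuple

import torch

from funasr.models.frontend.abs_frontend import AbsFrontend


class LinearFrontend(AbsFrontend):
    """Hypothetical frontend: a single linear projection of input features."""

    def __init__(self, input_dim: int = 80, output_dim: int = 256):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, output_dim)
        self._output_size = output_dim

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # A pure feature projection does not change sequence lengths.
        return self.proj(input), input_lengths


# Usage: (batch, frames, feat_dim) in, (batch, frames, output_size()) out.
feats, lens = LinearFrontend()(torch.randn(2, 50, 80), torch.tensor([50, 42]))
```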
New file: funasr/models/specaug/abs_specaug.py

+from typing import Optional
+from typing import Tuple
+
+import torch
+
+
+class AbsSpecAug(torch.nn.Module):
+    """Abstract class for spectrogram augmentation.
+
+    The process flow:
+        Frontend -> SpecAug -> Normalization -> Encoder -> Decoder
+    """
+
+    def forward(
+        self, x: torch.Tensor, x_lengths: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        raise NotImplementedError
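Because `AbsSpecAug` pins down the `(x, x_lengths) -> (x, x_lengths)` contract, any augmentation that honors it can slot into the Frontend -> SpecAug -> Normalization -> Encoder chain. A minimal conforming sketch, assuming a hypothetical `GaussianNoiseAug` that is not part of FunASR:

```python
from typing import Optional, Tuple

import torch

from funasr.models.specaug.abs_specaug import AbsSpecAug


class GaussianNoiseAug(AbsSpecAug):
    """Hypothetical augmentation: add Gaussian noise to the spectrogram."""

    def __init__(self, std: float = 0.1):
        super().__init__()
        self.std = std

    def forward(
        self, x: torch.Tensor, x_lengths: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.training:  # augment only during training
            x = x + self.std * torch.randn_like(x)
        return x, x_lengths  # lengths pass through unchanged
```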
 from typing import Sequence
 from typing import Union

 import torch.nn

+from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.layers.mask_along_axis import MaskAlongAxis
 from funasr.layers.mask_along_axis import MaskAlongAxisVariableMaxWidth
 from funasr.layers.mask_along_axis import MaskAlongAxisLFR
 from funasr.layers.time_warp import TimeWarp


-class SpecAug(torch.nn.Module):
+class SpecAug(AbsSpecAug):
     """Implementation of SpecAug.

     Reference:
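Typical usage would look like the sketch below. The no-argument construction assumes the constructor provides defaults for the usual SpecAugment knobs (time warping plus frequency and time masking), and the import path is likewise an assumption, since the hunk truncates before the reference and argument list:

```python
import torch

# Assumed import path; the diff does not show the file name.
from funasr.models.specaug.specaug import SpecAug

specaug = SpecAug()  # assumes defaults for warp/mask settings
specaug.train()      # SpecAugment is a training-time augmentation

x = torch.randn(4, 200, 80)                     # (batch, frames, mel bins)
x_lengths = torch.tensor([200, 180, 160, 150])  # valid frames per utterance
x_aug, x_lengths = specaug(x, x_lengths)
```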