| | |
| | | LabelSmoothingLoss, # noqa: H301 |
| | | ) |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.postencoder.abs_postencoder import AbsPostEncoder |
| | | from funasr.models.preencoder.abs_preencoder import AbsPreEncoder |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.models.base_model import FunASRModel |
| | | from funasr.modules.streaming_utils.chunk_utilis import sequence_mask |
| New file |
| | |
| | | from abc import ABC |
| | | from abc import abstractmethod |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | |
| | | |
class AbsEncoder(torch.nn.Module, ABC):
    """Abstract base class for encoder modules.

    Concrete encoders must implement both ``output_size`` and ``forward``.
    Because the abstract methods are declared via ``abc.abstractmethod``,
    this class itself cannot be instantiated.
    """

    @abstractmethod
    def output_size(self) -> int:
        """Return the encoder's output size.

        NOTE(review): presumably the feature dimension of the encoded output;
        confirm against concrete implementations.
        """
        raise NotImplementedError

    @abstractmethod
    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        # Annotated Optional explicitly: the default is None, and implicit
        # Optional is disallowed by PEP 484 / modern type checkers.
        prev_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode a batch of inputs.

        Args:
            xs_pad: Input tensor (presumably padded batch features — confirm
                shape with concrete encoders).
            ilens: Tensor of input lengths (presumably one per batch item).
            prev_states: Optional previous encoder state; ``None`` by default.

        Returns:
            A 3-tuple ``(tensor, tensor, optional_tensor)``. Subclasses define
            the exact meaning; presumably encoded output, output lengths, and
            an optional next state.
        """
        raise NotImplementedError