| | |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.attention import ( |
| | | MultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttentionChunk, |
| | | LegacyRelPositionMultiHeadedAttention, # noqa: H301 |
| | | ) |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.embedding import ( |
| | | PositionalEncoding, # noqa: H301 |
| | | ScaledPositionalEncoding, # noqa: H301 |
| | |
| | | feed_forward: torch.nn.Module, |
| | | feed_forward_macaron: torch.nn.Module, |
| | | conv_mod: torch.nn.Module, |
| | | norm_class: torch.nn.Module = torch.nn.LayerNorm, |
| | | norm_class: torch.nn.Module = LayerNorm, |
| | | norm_args: Dict = {}, |
| | | dropout_rate: float = 0.0, |
| | | ) -> None: |
| | |
| | | limit_size, |
| | | ) |
| | | |
| | | mask = make_source_mask(x_len) |
| | | mask = make_source_mask(x_len).to(x.device) |
| | | |
| | | if self.unified_model_training: |
| | | chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() |
| | |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 |
| | | |
| | | return x, olens |
| | | return x, olens, None |
| | | |
| | | def simu_chunk_forward( |
| | | self, |