| | |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.attention import ( |
| | | MultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttentionChunk, |
| | | LegacyRelPositionMultiHeadedAttention, # noqa: H301 |
| | | ) |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.embedding import ( |
| | | PositionalEncoding, # noqa: H301 |
| | | ScaledPositionalEncoding, # noqa: H301 |
| | |
| | | feed_forward: torch.nn.Module, |
| | | feed_forward_macaron: torch.nn.Module, |
| | | conv_mod: torch.nn.Module, |
| | | norm_class: torch.nn.Module = torch.nn.LayerNorm, |
| | | norm_class: torch.nn.Module = LayerNorm, |
| | | norm_args: Dict = {}, |
| | | dropout_rate: float = 0.0, |
| | | ) -> None: |
| | |
| | | |
| | | return x, cache |
| | | |
| | | class ConformerChunkEncoder(torch.nn.Module): |
| | | class ConformerChunkEncoder(AbsEncoder): |
| | | """Encoder module definition. |
| | | Args: |
| | | input_size: Input size. |
| | |
| | | output_size, |
| | | ) |
| | | |
| | | self.output_size = output_size |
| | | self._output_size = output_size |
| | | |
| | | self.dynamic_chunk_training = dynamic_chunk_training |
| | | self.short_chunk_threshold = short_chunk_threshold |
| | |
| | | self.jitter_range = jitter_range |
| | | |
| | | self.time_reduction_factor = time_reduction_factor |
| | | |
| | | def output_size(self) -> int: |
| | | return self._output_size |
| | | |
| | | def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: |
| | | """Return the corresponding number of sample for a given chunk size, in frames. |
| | |
| | | limit_size, |
| | | ) |
| | | |
| | | mask = make_source_mask(x_len) |
| | | mask = make_source_mask(x_len).to(x.device) |
| | | |
| | | if self.unified_model_training: |
| | | chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() |
| | |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 |
| | | |
| | | return x, olens |
| | | return x, olens, None |
| | | |
| | | def simu_chunk_forward( |
| | | self, |