| | |
| | | |
| | | import torch |
| | | from torch import nn |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.attention import ( |
| | | MultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttentionChunk, |
| | | LegacyRelPositionMultiHeadedAttention, # noqa: H301 |
| | | ) |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.embedding import ( |
| | | PositionalEncoding, # noqa: H301 |
| | | ScaledPositionalEncoding, # noqa: H301 |
| | |
| | | StreamingRelPositionalEncoding, |
| | | ) |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.normalization import get_normalization |
| | | from funasr.modules.multi_layer_conv import Conv1dLinear |
| | | from funasr.modules.multi_layer_conv import MultiLayeredConv1d |
| | | from funasr.modules.nets_utils import get_activation |
| | |
| | | feed_forward: torch.nn.Module, |
| | | feed_forward_macaron: torch.nn.Module, |
| | | conv_mod: torch.nn.Module, |
| | | norm_class: torch.nn.Module = torch.nn.LayerNorm, |
| | | norm_class: torch.nn.Module = LayerNorm, |
| | | norm_args: Dict = {}, |
| | | dropout_rate: float = 0.0, |
| | | ) -> None: |
| | |
| | | interctc_use_conditioning: bool = False, |
| | | stochastic_depth_rate: Union[float, List[float]] = 0.0, |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | self._output_size = output_size |
| | | |
| | |
| | | |
| | | return x, cache |
| | | |
| | | class ConformerChunkEncoder(torch.nn.Module): |
| | | class ConformerChunkEncoder(AbsEncoder): |
| | | """Encoder module definition. |
| | | Args: |
| | | input_size: Input size. |
| | |
| | | default_chunk_size: int = 16, |
| | | jitter_range: int = 4, |
| | | subsampling_factor: int = 1, |
| | | **activation_parameters, |
| | | ) -> None: |
| | | """Construct an Encoder object.""" |
| | | super().__init__() |
| | | |
| | | assert check_argument_types() |
| | | |
| | | self.embed = StreamingConvInput( |
| | | input_size, |
| | |
| | | ) |
| | | |
| | | activation = get_activation( |
| | | activation_type, **activation_parameters |
| | | activation_type |
| | | ) |
| | | |
| | | pos_wise_args = ( |
| | |
| | | simplified_att_score, |
| | | ) |
| | | |
| | | norm_class, norm_args = get_normalization( |
| | | norm_type, |
| | | ) |
| | | |
| | | fn_modules = [] |
| | | for _ in range(num_blocks): |
| | |
| | | PositionwiseFeedForward(*pos_wise_args), |
| | | PositionwiseFeedForward(*pos_wise_args), |
| | | CausalConvolution(*conv_mod_args), |
| | | norm_class=norm_class, |
| | | norm_args=norm_args, |
| | | dropout_rate=dropout_rate, |
| | | ) |
| | | fn_modules.append(module) |
| | |
| | | self.encoders = MultiBlocks( |
| | | [fn() for fn in fn_modules], |
| | | output_size, |
| | | norm_class=norm_class, |
| | | norm_args=norm_args, |
| | | ) |
| | | |
| | | self.output_size = output_size |
| | | self._output_size = output_size |
| | | |
| | | self.dynamic_chunk_training = dynamic_chunk_training |
| | | self.short_chunk_threshold = short_chunk_threshold |
| | |
| | | self.jitter_range = jitter_range |
| | | |
| | | self.time_reduction_factor = time_reduction_factor |
| | | |
| | | def output_size(self) -> int: |
| | | return self._output_size |
| | | |
| | | def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: |
| | | """Return the corresponding number of sample for a given chunk size, in frames. |
| | |
| | | limit_size, |
| | | ) |
| | | |
| | | mask = make_source_mask(x_len) |
| | | mask = make_source_mask(x_len).to(x.device) |
| | | |
| | | if self.unified_model_training: |
| | | chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() |
| | | if self.training: |
| | | chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() |
| | | else: |
| | | chunk_size = self.default_chunk_size |
| | | x, mask = self.embed(x, mask, chunk_size) |
| | | pos_enc = self.pos_enc(x) |
| | | chunk_mask = make_chunk_mask( |
| | |
| | | |
| | | elif self.dynamic_chunk_training: |
| | | max_len = x.size(1) |
| | | chunk_size = torch.randint(1, max_len, (1,)).item() |
| | | if self.training: |
| | | chunk_size = torch.randint(1, max_len, (1,)).item() |
| | | |
| | | if chunk_size > (max_len * self.short_chunk_threshold): |
| | | chunk_size = max_len |
| | | if chunk_size > (max_len * self.short_chunk_threshold): |
| | | chunk_size = max_len |
| | | else: |
| | | chunk_size = (chunk_size % self.short_chunk_size) + 1 |
| | | else: |
| | | chunk_size = (chunk_size % self.short_chunk_size) + 1 |
| | | chunk_size = self.default_chunk_size |
| | | |
| | | x, mask = self.embed(x, mask, chunk_size) |
| | | pos_enc = self.pos_enc(x) |
| | |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 |
| | | |
| | | return x, olens |
| | | return x, olens, None |
| | | |
| | | def full_utt_forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | x_len: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Encode input sequences. |
| | | Args: |
| | | x: Encoder input features. (B, T_in, F) |
| | | x_len: Encoder input features lengths. (B,) |
| | | Returns: |
| | | x: Encoder outputs. (B, T_out, D_enc) |
| | | x_len: Encoder outputs lenghts. (B,) |
| | | """ |
| | | short_status, limit_size = check_short_utt( |
| | | self.embed.subsampling_factor, x.size(1) |
| | | ) |
| | | |
| | | if short_status: |
| | | raise TooShortUttError( |
| | | f"has {x.size(1)} frames and is too short for subsampling " |
| | | + f"(it needs more than {limit_size} frames), return empty results", |
| | | x.size(1), |
| | | limit_size, |
| | | ) |
| | | |
| | | mask = make_source_mask(x_len).to(x.device) |
| | | x, mask = self.embed(x, mask, None) |
| | | pos_enc = self.pos_enc(x) |
| | | x_utt = self.encoders( |
| | | x, |
| | | pos_enc, |
| | | mask, |
| | | chunk_mask=None, |
| | | ) |
| | | |
| | | if self.time_reduction_factor > 1: |
| | | x_utt = x_utt[:,::self.time_reduction_factor,:] |
| | | return x_utt |
| | | |
| | | def simu_chunk_forward( |
| | | self, |