from funasr.models.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.encoder.abs_encoder import AbsEncoder
import pdb
import math


    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
    ):
        super().__init__()
        self._output_size = output_size

        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
| | | logging.warning("Using legacy_rel_pos and it will be deprecated in the future.") |
| | | else: |
| | | raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) |
| | | |
| | |
| | | pos_enc_class(output_size, positional_dropout_rate), |
| | | ) |
| | | elif input_layer is None: |
            self.embed = torch.nn.Sequential(pos_enc_class(output_size, positional_dropout_rate))
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before

                output_size,
                attention_dropout_rate,
            )
| | | logging.warning("Using legacy_rel_selfattn and it will be deprecated in the future.") |
| | | elif selfattention_layer_type == "rel_selfattn": |
| | | assert pos_enc_layer_type == "rel_pos" |
| | | encoder_selfattn_layer = RelPositionMultiHeadedAttention |
    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        channel_size: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.
        Args:

        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
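        # Guard: check_short_utt flags utterances with too few frames for the
        # conv2d subsampling front-end (see TooShortUttError imported above).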
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:

        t_leng = xs_pad.size(1)
        d_dim = xs_pad.size(2)
        xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
        # pdb.set_trace()
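        # With fewer than 8 input channels, tile the channel dimension and then
        # truncate to exactly 8 so conv1 always receives a fixed channel count.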
| | | if channel_size < 8: |
| | | repeat_num = math.ceil(8 / channel_size) |
| | | xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :] |
| | | xs_pad = self.conv1(xs_pad) |
| | |
| | | return xs_pad, olens, None |
| | | |
| | | def forward_hidden( |
| | | self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | prev_states: torch.Tensor = None, |
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.
        Args:

        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
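        # As in forward(), reject utterances that are too short for the subsampling front-end.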
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:

        self.hidden_feature = self.after_norm(hidden_feature)

        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None