import pdb
import math


class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.

    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernel size of conv layers.

    """

    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):

    def forward(self, x):
        """Compute convolution module.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).

        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).

        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)

        return x.transpose(1, 2)
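
# Usage sketch for ConvolutionModule (assumption: the elided __init__ / forward
# bodies follow the standard Conformer convolution module, so input and output
# keep the (#batch, time, channels) shape documented above):
#
#   conv = ConvolutionModule(channels=256, kernel_size=31)
#   y = conv(torch.randn(4, 100, 256))   # -> (4, 100, 256)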


class MFCCAEncoder(AbsEncoder):
    """Conformer encoder module.

    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.

        zero_triu (bool): Whether to zero the upper triangular part of the attention matrix.
        cnn_module_kernel (int): Kernel size of the convolution module.
        padding_idx (int): Padding idx for input_layer=embed.

    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
    ):
        assert check_argument_types()
        super().__init__()

            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)

        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),

            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )

        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
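
        # Channel-fusion stack: four 2D convolutions that mix information across
        # microphone channels. Each layer keeps the (time, feature) resolution
        # unchanged (kernel (5, 7), stride 1, padding (2, 3)) while the channel
        # dimension goes 8 -> 16 -> 32 -> 16 -> 1, so up to 8 input channels are
        # fused into a single representation.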
        self.conv1 = torch.nn.Conv2d(8, 16, [5, 7], stride=[1, 1], padding=(2, 3))
        self.conv2 = torch.nn.Conv2d(16, 32, [5, 7], stride=[1, 1], padding=(2, 3))
        self.conv3 = torch.nn.Conv2d(32, 16, [5, 7], stride=[1, 1], padding=(2, 3))
        self.conv4 = torch.nn.Conv2d(16, 1, [5, 7], stride=[1, 1], padding=(2, 3))
    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        channel_size: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            channel_size (torch.Tensor): Number of input channels per utterance.
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:

        # Fold the flattened (batch * channel) dimension back into an explicit
        # channel axis so the 2D convolutions can mix information across channels.
        t_leng = xs_pad.size(1)
        d_dim = xs_pad.size(2)
        xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
        # pdb.set_trace()
        if channel_size < 8:
            # The conv stack expects exactly 8 input channels: tile the existing
            # channels and keep the first 8.
            repeat_num = math.ceil(8 / channel_size)
            xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
        # Fuse the channels down to a single stream: 8 -> 16 -> 32 -> 16 -> 1.
        xs_pad = self.conv1(xs_pad)
        xs_pad = self.conv2(xs_pad)
        xs_pad = self.conv3(xs_pad)
        xs_pad = self.conv4(xs_pad)
        xs_pad = xs_pad.squeeze().reshape(-1, t_leng, d_dim)
        # Keep only the first channel's mask for each utterance.
        mask_tmp = masks.size(1)
        masks = masks.reshape(-1, channel_size, mask_tmp, t_leng)[:, 0, :, :]

        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None

    def forward_hidden(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.

        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:

        # Cache the normalized hidden features on the module.
        self.hidden_feature = self.after_norm(hidden_feature)

        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
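

# ---------------------------------------------------------------------------
# Usage sketch (assumption: the elided parts of __init__ and forward build the
# usual ESPnet Conformer embed/encoder stack, so the call below is illustrative
# rather than taken verbatim from the original code).
#
#   encoder = MFCCAEncoder(input_size=80, output_size=256)
#   # 4 utterances, each recorded by 2 microphones, flattened to batch * channel
#   xs_pad = torch.randn(4 * 2, 100, 80)
#   ilens = torch.full((4 * 2,), 100, dtype=torch.long)
#   xs, olens, _ = encoder(xs_pad, ilens, channel_size=2)
#   # xs: (4, L', 256) with one fused sequence per utterance; L' depends on the
#   # subsampling input layer. olens: (4,).
# ---------------------------------------------------------------------------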