| | |
| | | return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) |
| | | |
| | | |
class FsmnFeedForward(torch.nn.Module):
    """Convolutional position-wise feed-forward layer for FSMN blocks.

    Two stacked 1-D convolutions stand in for the usual position-wise
    feed-forward network of an FSMN block. The data flow is::

        conv (w_1) -> ReLU -> dropout -> LayerNorm -> conv (w_2, no bias)

    Both convolutions use ``stride=1`` and ``padding=(kernel_size - 1) // 2``,
    so the time dimension is preserved only when ``kernel_size`` is odd;
    an even ``kernel_size`` shortens the sequence by one frame per
    convolution.
    """

    def __init__(self, in_chans, hidden_chans, out_chans, kernel_size, dropout_rate):
        """Build the two convolutions, the layer norm and the dropout.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of channels between the two
                convolutions; also the size normalized by the LayerNorm.
            out_chans (int): Number of output channels.
            kernel_size (int): Kernel size shared by both convolutions.
            dropout_rate (float): Dropout probability applied to the
                activated hidden representation.

        """
        super().__init__()
        # Shared padding for both convolutions (length-preserving for odd kernels).
        same_pad = (kernel_size - 1) // 2

        # Submodules are created in a fixed order (w_1, w_2, norm, dropout) so
        # that parameter registration and seeded initialisation stay stable.
        self.w_1 = torch.nn.Conv1d(
            in_channels=in_chans,
            out_channels=hidden_chans,
            kernel_size=kernel_size,
            stride=1,
            padding=same_pad,
        )
        # The output projection carries no bias term.
        self.w_2 = torch.nn.Conv1d(
            in_channels=hidden_chans,
            out_channels=out_chans,
            kernel_size=kernel_size,
            stride=1,
            padding=same_pad,
            bias=False,
        )
        self.norm = torch.nn.LayerNorm(hidden_chans)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x, ilens=None):
        """Run the feed-forward stack on a batch of sequences.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).
            ilens: Optional value handed back to the caller untouched;
                it is not read by this module.

        Returns:
            tuple: ``(output, ilens)`` where ``output`` is a torch.Tensor of
            shape (B, T, out_chans) for odd ``kernel_size`` and ``ilens`` is
            the argument exactly as it was passed in.

        """
        # Conv1d convolves over the last axis, so swap to channels-first and back.
        hidden = self.w_1(x.transpose(-1, 1))
        hidden = torch.relu(hidden).transpose(-1, 1)

        # Dropout is applied before normalising over the channel axis.
        hidden = self.norm(self.dropout(hidden))

        projected = self.w_2(hidden.transpose(-1, 1)).transpose(-1, 1)
        return projected, ilens
| | | |
| | | |
| | | class Conv1dLinear(torch.nn.Module): |
| | | """Conv1D + Linear for Transformer block. |
| | | |