jmwang66
2023-05-16 6f7e27eb7c2d0a7649ec8f14d167c8da8e29f906
funasr/models/encoder/mfcca_encoder.py
@@ -38,13 +38,12 @@
import pdb
import math
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernerl size of conv layers.
    """
    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
@@ -83,13 +82,10 @@
    def forward(self, x):
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)
@@ -107,10 +103,8 @@
        return x.transpose(1, 2)
class MFCCAEncoder(AbsEncoder):
    """Conformer encoder module.
    Args:
        input_size (int): Input dimension.
        output_size (int): Dimention of attention.
@@ -140,7 +134,6 @@
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernerl size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
    """
    def __init__(
@@ -345,17 +338,14 @@
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.
        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
@@ -398,6 +388,7 @@
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
    def forward_hidden(
        self,
        xs_pad: torch.Tensor,
@@ -405,17 +396,14 @@
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.
        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (