嘉渊
2023-04-28 675951eab83e66c7654b71eca6c94b788c5876d5
update
1个文件已修改
20 ■■■■ 已修改文件
funasr/models/encoder/mfcca_encoder.py 20 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/mfcca_encoder.py
@@ -34,15 +34,16 @@
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.encoder.abs_encoder import AbsEncoder
import pdb
import math
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
    Args:
        channels (int): The number of channels of conv layers.
        kernel_size (int): Kernerl size of conv layers.
    """
    def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
@@ -81,13 +82,10 @@
    def forward(self, x):
        """Compute convolution module.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, channels).
        Returns:
            torch.Tensor: Output tensor (#batch, time, channels).
        """
        # exchange the temporal dimension and the feature dimension
        x = x.transpose(1, 2)
@@ -105,10 +103,8 @@
        return x.transpose(1, 2)
class MFCCAEncoder(torch.nn.Module):
class MFCCAEncoder(AbsEncoder):
    """Conformer encoder module.
    Args:
        input_size (int): Input dimension.
        output_size (int): Dimention of attention.
@@ -138,7 +134,6 @@
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernerl size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
    """
    def __init__(
@@ -343,17 +338,14 @@
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.
        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
@@ -396,6 +388,7 @@
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
    def forward_hidden(
        self,
        xs_pad: torch.Tensor,
@@ -403,17 +396,14 @@
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.
        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.
        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (