zhifu gao
2023-07-03 edec2fe85eda80ff1e24aef30b36c7bbbb55ec2a
funasr/modules/subsampling.py
@@ -427,6 +427,7 @@
        conv_size: Union[int, Tuple],
        subsampling_factor: int = 4,
        vgg_like: bool = True,
        conv_kernel_size: int = 3,
        output_size: Optional[int] = None,
    ) -> None:
        """Construct a ConvInput object."""
@@ -436,14 +437,14 @@
                conv_size1, conv_size2 = conv_size
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((1, 2)),
                )
@@ -462,14 +463,14 @@
                kernel_1 = int(subsampling_factor / 2)
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((kernel_1, 2)),
                    torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1),
                    torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((2, 2)),
                )
@@ -487,14 +488,14 @@
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]),
                    torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]),
                    torch.nn.ReLU(),
                )
                output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
                self.subsampling_factor = subsampling_factor
                self.kernel_2 = 3
                self.kernel_2 = conv_kernel_size
                self.stride_2 = 1
                self.create_new_mask = self.create_new_conv2d_mask
@@ -506,9 +507,9 @@
                )
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, 2),
                    torch.nn.Conv2d(1, conv_size, 3, 2, [1,0]),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2, [(kernel_2-1)//2, 0]),
                    torch.nn.ReLU(),
                )
@@ -597,7 +598,7 @@
            mask: Mask of output sequences. (B, sub(T))
        """
        if self.subsampling_factor > 1:
            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
            return mask[:, ::2][:, ::self.stride_2]
        else:
            return mask