aky15
2023-05-24 8c24f52fd25fce2df46dc4d9dffb45619dc38f9f
funasr/modules/subsampling.py
@@ -506,9 +506,9 @@
                )
                self.conv = torch.nn.Sequential(
                    torch.nn.Conv2d(1, conv_size, 3, 2),
                    torch.nn.Conv2d(1, conv_size, 3, 2, [1,0]),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2),
                    torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2, [(kernel_2-1)//2, 0]),
                    torch.nn.ReLU(),
                )
@@ -597,7 +597,7 @@
            mask: Mask of output sequences. (B, sub(T))
        """
        if self.subsampling_factor > 1:
            return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2]
            return mask[:, ::2][:, ::self.stride_2]
        else:
            return mask