funasr/modules/subsampling.py
@@ -506,9 +506,9 @@ ) self.conv = torch.nn.Sequential( torch.nn.Conv2d(1, conv_size, 3, 2), torch.nn.Conv2d(1, conv_size, 3, 2, [1,0]), torch.nn.ReLU(), torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2), torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2, [(kernel_2-1)//2, 0]), torch.nn.ReLU(), ) @@ -597,7 +597,7 @@ mask: Mask of output sequences. (B, sub(T)) """ if self.subsampling_factor > 1: return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2] return mask[:, ::2][:, ::self.stride_2] else: return mask