| | |
| | | conv_size: Union[int, Tuple], |
| | | subsampling_factor: int = 4, |
| | | vgg_like: bool = True, |
| | | conv_kernel_size: int = 3, |
| | | output_size: Optional[int] = None, |
| | | ) -> None: |
| | | """Construct a ConvInput object.""" |
| | |
| | | conv_size1, conv_size2 = conv_size |
| | | |
| | | self.conv = torch.nn.Sequential( |
| | | torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), |
| | | torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), |
| | | torch.nn.ReLU(), |
| | | torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), |
| | | torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), |
| | | torch.nn.ReLU(), |
| | | torch.nn.MaxPool2d((1, 2)), |
| | | torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), |
| | | torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), |
| | | torch.nn.ReLU(), |
| | | torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), |
| | | torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), |
| | | torch.nn.ReLU(), |
| | | torch.nn.MaxPool2d((1, 2)), |
| | | ) |
| | |
| | | kernel_1 = int(subsampling_factor / 2) |
| | | |
| | | self.conv = torch.nn.Sequential( |
| | | torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), |
| | | torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), |
| | | torch.nn.ReLU(), |
| | | torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), |
| | | torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), |
| | | torch.nn.ReLU(), |
| | | torch.nn.MaxPool2d((kernel_1, 2)), |
| | | torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), |
| | | torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), |
| | | torch.nn.ReLU(), |
| | | torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), |
| | | torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2), |
| | | torch.nn.ReLU(), |
| | | torch.nn.MaxPool2d((2, 2)), |
| | | ) |
| | |
| | | self.conv = torch.nn.Sequential( |
| | | torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]), |
| | | torch.nn.ReLU(), |
| | | torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]), |
| | | torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]), |
| | | torch.nn.ReLU(), |
| | | ) |
| | | |
| | | output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2) |
| | | |
| | | self.subsampling_factor = subsampling_factor |
| | | self.kernel_2 = 3 |
| | | self.kernel_2 = conv_kernel_size |
| | | self.stride_2 = 1 |
| | | |
| | | self.create_new_mask = self.create_new_conv2d_mask |