| | |
| | | # 2D matrix |
| | | if not is_conv: |
| | | assert ( |
| | | module.weight.size(1) % block_size == 0 |
| | | module.weight.size(1) % block_size == 0 |
| | | ), "Input features must be a multiple of block sizes" |
| | | |
| | | # 4D matrix |
| | |
| | | # 1x1 convolutions |
| | | if module.kernel_size == (1, 1): |
| | | assert ( |
| | | module.in_channels % block_size == 0 |
| | | module.in_channels % block_size == 0 |
| | | ), "Input channels must be a multiple of block sizes" |
| | | # regular convolutions |
| | | else: |
| | |
| | | out_features = weight.size(0) |
| | | |
| | | # split weight matrix into blocks and randomly drop selected blocks |
| | | mask = torch.zeros( |
| | | in_features // block_size * out_features, device=weight.device |
| | | ) |
| | | mask = torch.zeros(in_features // block_size * out_features, device=weight.device) |
| | | mask.bernoulli_(p) |
| | | mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) |
| | | |
| | |
| | | mask.bernoulli_(p) |
| | | mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) |
| | | else: |
| | | mask = torch.zeros( |
| | | weight.size(0), weight.size(1), device=weight.device |
| | | ) |
| | | mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device) |
| | | mask.bernoulli_(p) |
| | | mask = ( |
| | | mask.unsqueeze(2) |
| | | .unsqueeze(3) |
| | | .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) |
| | | .unsqueeze(3) |
| | | .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) |
| | | ) |
| | | |
| | | # scale weights and apply mask |
| | | mask = mask.to( |
| | | torch.bool |
| | | ) # x.bool() is not currently supported in TorchScript |
| | | mask = mask.to(torch.bool) # x.bool() is not currently supported in TorchScript |
| | | s = 1 / (1 - p) |
| | | mod.weight.data = s * weight.masked_fill(mask, 0) |
| | | |