kongdeqiang
5 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/models/data2vec/quant_noise.py
@@ -40,7 +40,7 @@
    # 2D matrix
    if not is_conv:
        assert (
                module.weight.size(1) % block_size == 0
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"
    # 4D matrix
@@ -48,7 +48,7 @@
        # 1x1 convolutions
        if module.kernel_size == (1, 1):
            assert (
                    module.in_channels % block_size == 0
                module.in_channels % block_size == 0
            ), "Input channels must be a multiple of block sizes"
        # regular convolutions
        else:
@@ -65,9 +65,7 @@
                out_features = weight.size(0)
                # split weight matrix into blocks and randomly drop selected blocks
                mask = torch.zeros(
                    in_features // block_size * out_features, device=weight.device
                )
                mask = torch.zeros(in_features // block_size * out_features, device=weight.device)
                mask.bernoulli_(p)
                mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
@@ -86,20 +84,16 @@
                    mask.bernoulli_(p)
                    mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
                else:
                    mask = torch.zeros(
                        weight.size(0), weight.size(1), device=weight.device
                    )
                    mask = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
                    mask.bernoulli_(p)
                    mask = (
                        mask.unsqueeze(2)
                            .unsqueeze(3)
                            .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                        .unsqueeze(3)
                        .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
                    )
            # scale weights and apply mask
            mask = mask.to(
                torch.bool
            )  # x.bool() is not currently supported in TorchScript
            mask = mask.to(torch.bool)  # x.bool() is not currently supported in TorchScript
            s = 1 / (1 - p)
            mod.weight.data = s * weight.masked_fill(mask, 0)