| | |
| | | |
| | | |
| | | def compute_mask_indices( |
| | | shape: Tuple[int, int], |
| | | padding_mask: Optional[torch.Tensor], |
| | | mask_prob: float, |
| | | mask_length: int, |
| | | mask_type: str = "static", |
| | | mask_other: float = 0.0, |
| | | min_masks: int = 0, |
| | | no_overlap: bool = False, |
| | | min_space: int = 0, |
| | | require_same_masks: bool = True, |
| | | mask_dropout: float = 0.0, |
| | | shape: Tuple[int, int], |
| | | padding_mask: Optional[torch.Tensor], |
| | | mask_prob: float, |
| | | mask_length: int, |
| | | mask_type: str = "static", |
| | | mask_other: float = 0.0, |
| | | min_masks: int = 0, |
| | | no_overlap: bool = False, |
| | | min_space: int = 0, |
| | | require_same_masks: bool = True, |
| | | mask_dropout: float = 0.0, |
| | | ) -> np.ndarray: |
| | | """ |
| | | Computes random mask spans for a given shape |
| | |
| | | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) |
| | | |
| | | mask_idc = np.asarray( |
| | | [ |
| | | mask_idc[j] + offset |
| | | for j in range(len(mask_idc)) |
| | | for offset in range(lengths[j]) |
| | | ] |
| | | [mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])] |
| | | ) |
| | | |
| | | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) |
| | |
| | | mask_idc = np.random.choice(mask_idc, min_len, replace=False) |
| | | if mask_dropout > 0: |
| | | num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int) |
| | | mask_idc = np.random.choice( |
| | | mask_idc, len(mask_idc) - num_holes, replace=False |
| | | ) |
| | | mask_idc = np.random.choice(mask_idc, len(mask_idc) - num_holes, replace=False) |
| | | |
| | | mask[i, mask_idc] = True |
| | | |