| | |
| | | spec = spec.view(*org_size) |
| | | return spec, spec_lengths |
| | | |
| | | |
| | | def mask_along_axis_lfr( |
| | | spec: torch.Tensor, |
| | | spec_lengths: torch.Tensor, |
| | |
| | | mask_width_range = (0, mask_width_range) |
| | | if len(mask_width_range) != 2: |
| | | raise TypeError( |
| | | f"mask_width_range must be a tuple of int and int values: " |
| | | f"{mask_width_range}", |
| | | f"mask_width_range must be a tuple of int and int values: " f"{mask_width_range}", |
| | | ) |
| | | |
| | | assert mask_width_range[1] > mask_width_range[0] |
| | |
| | | ) |
| | | return spec, spec_lengths |
| | | |
| | | |
| | | class MaskAlongAxisLFR(torch.nn.Module): |
| | | def __init__( |
| | | self, |
| | |
| | | mask_width_range = (0, mask_width_range) |
| | | if len(mask_width_range) != 2: |
| | | raise TypeError( |
| | | f"mask_width_range must be a tuple of int and int values: " |
| | | f"{mask_width_range}", |
| | | f"mask_width_range must be a tuple of int and int values: " f"{mask_width_range}", |
| | | ) |
| | | |
| | | assert mask_width_range[1] > mask_width_range[0] |
| | |
| | | num_mask=self.num_mask, |
| | | replace_with_zero=self.replace_with_zero, |
| | | lfr_rate=self.lfr_rate, |
| | | ) |
| | | ) |