liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/specaug/mask_along_axis.py
@@ -66,6 +66,7 @@
    spec = spec.view(*org_size)
    return spec, spec_lengths
def mask_along_axis_lfr(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
@@ -150,8 +151,7 @@
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{mask_width_range}",
                f"mask_width_range must be a tuple of int and int values: " f"{mask_width_range}",
            )
        assert mask_width_range[1] > mask_width_range[0]
@@ -271,6 +271,7 @@
            )
        return spec, spec_lengths
class MaskAlongAxisLFR(torch.nn.Module):
    def __init__(
        self,
@@ -284,8 +285,7 @@
            mask_width_range = (0, mask_width_range)
        if len(mask_width_range) != 2:
            raise TypeError(
                f"mask_width_range must be a tuple of int and int values: "
                f"{mask_width_range}",
                f"mask_width_range must be a tuple of int and int values: " f"{mask_width_range}",
            )
        assert mask_width_range[1] > mask_width_range[0]
@@ -333,4 +333,4 @@
            num_mask=self.num_mask,
            replace_with_zero=self.replace_with_zero,
            lfr_rate=self.lfr_rate,
        )
        )