| | |
| | | """SpecAugment module.""" |
| | | |
| | | from typing import Optional |
| | | from typing import Sequence |
| | | from typing import Union |
| | |
| | | from funasr.register import tables |
| | | |
| | | import torch.nn as nn |
| | | |
| | | |
| | | @tables.register("specaug_classes", "SpecAug") |
| | | class SpecAug(nn.Module): |
| | |
| | | num_time_mask: int = 2, |
| | | ): |
| | | if not apply_time_warp and not apply_time_mask and not apply_freq_mask: |
| | | raise ValueError( |
| | | "Either one of time_warp, time_mask, or freq_mask should be applied" |
| | | ) |
| | | raise ValueError("Either one of time_warp, time_mask, or freq_mask should be applied") |
| | | if ( |
| | | apply_time_mask |
| | | and (time_mask_width_range is not None) |
| | |
| | | x, x_lengths = self.time_mask(x, x_lengths) |
| | | return x, x_lengths |
| | | |
| | | |
| | | @tables.register("specaug_classes", "SpecAugLFR") |
| | | class SpecAugLFR(nn.Module): |
| | | """Implementation of SpecAug. |
| | |
| | | num_time_mask: int = 2, |
| | | ): |
| | | if not apply_time_warp and not apply_time_mask and not apply_freq_mask: |
| | | raise ValueError( |
| | | "Either one of time_warp, time_mask, or freq_mask should be applied" |
| | | ) |
| | | raise ValueError("Either one of time_warp, time_mask, or freq_mask should be applied") |
| | | if ( |
| | | apply_time_mask |
| | | and (time_mask_width_range is not None) |
| | |
| | | dim="freq", |
| | | mask_width_range=freq_mask_width_range, |
| | | num_mask=num_freq_mask, |
| | | lfr_rate=lfr_rate+1, |
| | | lfr_rate=lfr_rate + 1, |
| | | ) |
| | | |
| | | else: |
| | |
| | | x, x_lengths = self.freq_mask(x, x_lengths) |
| | | if self.time_mask is not None: |
| | | x, x_lengths = self.time_mask(x, x_lengths) |
| | | return x, x_lengths |
| | | return x, x_lengths |