from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    # `export` is accepted for interface compatibility but unused here; this
    # simplified variant always returns the standard torch.nn.LayerNorm.
    return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)


class SamePad(nn.Module):
    """Trims trailing frames so a padded Conv1d output matches its input length."""

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            # With padding = kernel_size - 1, dropping the trailing frames makes
            # the convolution causal.
            self.remove = kernel_size - 1
        else:
            # Symmetric padding of kernel_size // 2 adds one extra frame when the
            # kernel size is even.
            self.remove = 1 if kernel_size % 2 == 0 else 0

    def forward(self, x):
        if self.remove > 0:
            x = x[:, :, : -self.remove]
        return x
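
# Example (illustrative sketch, values chosen for this comment only): SamePad is
# typically placed right after a Conv1d that uses padding = kernel_size // 2, so
# sequences keep their length even for even kernel sizes.
#
#   conv = nn.Conv1d(16, 16, kernel_size=4, padding=2)    # (B, 16, T) -> (B, 16, T + 1)
#   pad = SamePad(kernel_size=4)
#   y = pad(conv(torch.randn(2, 16, 50)))                  # y.shape == (2, 16, 50)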


class TransposeLast(nn.Module):
    """Swaps the last two dimensions, optionally after indexing into a tuple/list."""

    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        if self.deconstruct_idx is not None:
            x = x[self.deconstruct_idx]
        return x.transpose(-2, -1)
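
# Example (illustrative): TransposeLast moves between the (B, C, T) layout that
# Conv1d expects and the (B, T, C) layout that LayerNorm expects.
#
#   x = torch.randn(2, 512, 100)       # (B, C, T)
#   TransposeLast()(x).shape            # torch.Size([2, 100, 512])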


class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm that computes in float32 and casts back to the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)


class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm that computes in float32 and casts back to the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        output = F.group_norm(
            input.float(),
            self.num_groups,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(input)
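
# Design note: both Fp32LayerNorm and Fp32GroupNorm up-cast the input and the
# affine parameters to float32 before normalizing, then cast the result back to
# the input dtype. This keeps the mean/variance computation numerically stable
# under fp16/bf16 mixed precision. Illustrative use (values assumed here only):
#
#   gn = Fp32GroupNorm(512, 512, affine=True)
#   y = gn(torch.randn(2, 512, 100, dtype=torch.float16))   # y.dtype == torch.float16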


class ConvFeatureExtractionModel(nn.Module):
    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            if is_layer_norm:
                # LayerNorm normalizes over channels, so transpose to (B, T, C)
                # and back around it. `dim` is captured from the enclosing loop
                # and equals n_out for the layer being built.
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(dim, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(dim, dim, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    # In "default" mode only the first layer is group-normalized.
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
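
# Example (illustrative): with the conv stack commonly used for wav2vec 2.0,
# [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2, the overall stride is
# 320 samples, so one second of 16 kHz audio maps to 49 frames.
#
#   conv_layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
#   extractor = ConvFeatureExtractionModel(conv_layers, mode="default")
#   feats = extractor(torch.randn(2, 16000))   # feats.shape == (2, 512, 49)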


def compute_mask_indices(
    shape: Tuple[int, int],

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        if mask_dropout > 0:
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(
                mask_idc, len(mask_idc) - num_holes, replace=False
            )

        mask[i, mask_idc] = True

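
# Note: the masking logic above samples span start indices without replacement,
# expands each start into a contiguous run of `lengths[j]` frames, drops
# out-of-range and duplicate positions, optionally removes a `mask_dropout`
# fraction of the selected positions, and writes the result into the boolean
# mask for sequence i.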

    @staticmethod
    def backward(ctx, grad):
        # Scale incoming gradients by the factor stored on the autograd context;
        # the scale argument itself receives no gradient.
        return grad * ctx.scale, None


def is_xla_tensor(tensor):
    return torch.is_tensor(tensor) and tensor.device.type == "xla"

        # Blend old and new values instead of doing an in-place indexed assignment.
        tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
    else:
        tensor[indices] = value
    return tensor