| | |
| | | logger = logging.getLogger(__name__) |
| | | |
| | | |
| | | |
| | | MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"]) |
| | | MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"]) |
| | | |
| | |
| | | self.alibi_scale = nn.Parameter( |
| | | torch.full( |
| | | ( |
| | | (modality_cfg.prenet_depth + modality_cfg.model_depth) |
| | | if modality_cfg.learned_alibi_scale_per_layer |
| | | else 1, |
| | | ( |
| | | (modality_cfg.prenet_depth + modality_cfg.model_depth) |
| | | if modality_cfg.learned_alibi_scale_per_layer |
| | | else 1 |
| | | ), |
| | | 1, |
| | | self.modality_cfg.num_alibi_heads |
| | | if modality_cfg.learned_alibi_scale_per_head |
| | | else 1, |
| | | ( |
| | | self.modality_cfg.num_alibi_heads |
| | | if modality_cfg.learned_alibi_scale_per_head |
| | | else 1 |
| | | ), |
| | | 1, |
| | | 1, |
| | | ), |
| | |
| | | device="cpu", |
| | | ) |
| | | self.alibi_bias = nn.Parameter(alibi_bias) |
| | | self.get_alibi_bias = partial( |
| | | _learned_alibi_bias, alibi_bias=self.alibi_bias |
| | | ) |
| | | self.get_alibi_bias = partial(_learned_alibi_bias, alibi_bias=self.alibi_bias) |
| | | |
| | | def upgrade_state_dict_named(self, state_dict, name): |
| | | k = f"{name}.alibi_scale" |
| | |
| | | if self.local_grad_mult == 1.0: |
| | | x = self.local_encoder(features) |
| | | else: |
| | | x = GradMultiply.apply( |
| | | self.local_encoder(features), self.local_grad_mult |
| | | ) |
| | | x = GradMultiply.apply(self.local_encoder(features), self.local_grad_mult) |
| | | else: |
| | | with torch.no_grad(): |
| | | x = self.local_encoder(features) |
| | |
| | | x = x.repeat_interleave(clone_batch, 0) |
| | | if mask_seeds is not None: |
| | | clone_hash = [ |
| | | int(hash((mask_seeds.seed, ind)) % 1e10) |
| | | for ind in range(clone_batch - 1) |
| | | int(hash((mask_seeds.seed, ind)) % 1e10) for ind in range(clone_batch - 1) |
| | | ] |
| | | clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1) |
| | | |
| | |
| | | id = id.repeat_interleave(clone_batch, 0) |
| | | id = id.view(-1, clone_batch) + clone_hash.to(id) |
| | | id = id.view(-1) |
| | | mask_seeds = MaskSeed( |
| | | seed=mask_seeds.seed, update=mask_seeds.update, ids=id |
| | | ) |
| | | mask_seeds = MaskSeed(seed=mask_seeds.seed, update=mask_seeds.update, ids=id) |
| | | if padding_mask is not None: |
| | | padding_mask = padding_mask.repeat_interleave(clone_batch, 0) |
| | | |
| | |
| | | x, |
| | | masked_padding_mask, |
| | | alibi_bias, |
| | | alibi_scale[: self.modality_cfg.prenet_depth] |
| | | if alibi_scale is not None |
| | | else None, |
| | | alibi_scale[: self.modality_cfg.prenet_depth] if alibi_scale is not None else None, |
| | | ) |
| | | |
| | | return { |
| | |
| | | "local_features": local_features, |
| | | "padding_mask": masked_padding_mask, |
| | | "alibi_bias": alibi_bias, |
| | | "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :] |
| | | if alibi_scale is not None and alibi_scale.size(0) > 1 |
| | | else alibi_scale, |
| | | "alibi_scale": ( |
| | | alibi_scale[self.modality_cfg.prenet_depth :] |
| | | if alibi_scale is not None and alibi_scale.size(0) > 1 |
| | | else alibi_scale |
| | | ), |
| | | "encoder_mask": mask_info, |
| | | } |
| | | |
| | |
| | | x = x * (1 - mask.type_as(x).unsqueeze(-1)) |
| | | else: |
| | | num_masks = mask.sum().item() |
| | | masks = x.new_empty(num_masks, x.size(-1)).normal_( |
| | | 0, cfg.mask_noise_std |
| | | ) |
| | | masks = x.new_empty(num_masks, x.size(-1)).normal_(0, cfg.mask_noise_std) |
| | | x = index_put(x, mask, masks) |
| | | if cfg.mask_channel_prob > 0: |
| | | mask_channel = compute_mask_indices( |
| | |
| | | cfg.mask_channel_length, |
| | | ) |
| | | mask_channel = ( |
| | | torch.from_numpy(mask_channel) |
| | | .to(x.device) |
| | | .unsqueeze(1) |
| | | .expand(-1, T, -1) |
| | | torch.from_numpy(mask_channel).to(x.device).unsqueeze(1).expand(-1, T, -1) |
| | | ) |
| | | x = index_put(x, mask_channel, 0) |
| | | return x |
| | |
| | | |
| | | generator = None |
| | | if mask_seed is not None: |
| | | seed = int( |
| | | hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6 |
| | | ) |
| | | seed = int(hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6) |
| | | generator = torch.Generator(device=x.device) |
| | | generator.manual_seed(seed) |
| | | |
| | |
| | | |
| | | ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D) |
| | | |
| | | return MaskInfo( |
| | | x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep |
| | | ) |
| | | return MaskInfo(x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep) |
| | | |
| | | |
| | | def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor: |
| | |
| | | # autoregressive model so we want a symmetric mask with 0 on the |
| | | # diagonal and otherwise linearly decreasing values |
| | | pos_bias = ( |
| | | torch.abs( |
| | | torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1) |
| | | ) |
| | | * -1 |
| | | torch.abs(torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)) * -1 |
| | | ) |
| | | elif dims == 2: |
| | | if distance == "manhattan": |
| | |
| | | else: |
| | | raise Exception(f"unsupported number of alibi dims: {dims}") |
| | | |
| | | alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand( |
| | | attn_heads, -1, -1 |
| | | ) |
| | | alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(attn_heads, -1, -1) |
| | | |
| | | return alibi_bias |
| | | |