        self.mask_other = mask_other
        self.no_mask_overlap = no_mask_overlap
        self.mask_min_space = mask_min_space
        self.require_same_masks = (
            require_same_masks  # if True, the collate_fn should clip samples to the same length
        )
        self.mask_dropout = mask_dropout

        # channel masking
        self.mask_channel_length = mask_channel_length

                min_space=self.mask_channel_min_space,
            )
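            # broadcast the sampled channel mask across all time steps: (B, C) -> (B, T, C)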
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
        x[mask_channel_indices] = 0
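        # optionally normalize each target layer's output before the layers
        # are averaged into a single training target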
        if self.batch_norm_target_layer:
            target_layer_results = [
                F.batch_norm(
                    tl.float(), running_mean=None, running_var=None, training=True
                )
                for tl in target_layer_results
            ]

        if self.instance_norm_target_layer:
            target_layer_results = [
                F.instance_norm(tl.float()) for tl in target_layer_results
            ]
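        # if the layer outputs were permuted to channels-first for batch/instance
        # norm, transpose them back afterwards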
        if permuted:
            target_layer_results = [
                tl.transpose(1, 2) for tl in target_layer_results
            ]

        if self.group_norm_target_layer:
            target_layer_results = [
                F.layer_norm(tl.float(), tl.shape[-2:])
                for tl in target_layer_results
            ]

        if self.layer_norm_target_layer:
            target_layer_results = [
                F.layer_norm(tl.float(), tl.shape[-1:])
                for tl in target_layer_results
            ]
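        # average the (optionally normalized) target layers into one target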
        y = sum(target_layer_results) / len(target_layer_results)
| | | f"target var is {result['target_var'].item()} < {self.min_target_var}, exiting" |
| | | ) |
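        # apply the same collapse guard to the prediction variance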
        if self.num_updates > 5000 and result["pred_var"] < self.min_pred_var:
            logging.error(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.min_pred_var}, exiting"
            )
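        # mean of the per-dimension standard deviation; the 1e-6 eps keeps
        # sqrt well-behaved when a dimension has zero variance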
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()
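    # convenience wrapper: runs forward() purely to obtain hidden features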
    def extract_features(self, xs_pad, ilens, mask=False, layer=None):
        res = self.forward(
            xs_pad,
            ilens,