```diff
         self.num_heads = model.encoders[0].self_attn.h
         self.hidden_size = model.encoders[0].self_attn.linear_out.out_features

-    def prepare_mask(self, mask):
+    def prepare_mask(self, mask, sub_masks):
         mask_3d_btd = mask[:, :, None]
-        sub_masks = subsequent_mask(mask.size(-1)).type(torch.float32)
+        # sub_masks = subsequent_mask(mask.size(-1)).type(torch.float32)
         if len(mask.shape) == 2:
             mask_4d_bhlt = 1 - sub_masks[:, None, None, :]
         elif len(mask.shape) == 3:
```
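For intuition, here is a minimal, self-contained sketch of the modified `prepare_mask`; the shapes, the example values, and the `subsequent_mask` construction (a lower-triangular causal mask) are assumptions based on this excerpt, not the project's actual code:

```python
import torch

def prepare_mask(mask: torch.Tensor, sub_masks: torch.Tensor):
    # mask: (B, T) padding mask, 1.0 for valid frames, 0.0 for padding.
    # sub_masks: (T, T) causal mask, now supplied by the caller instead of
    # being built inside the method with subsequent_mask().
    mask_3d_btd = mask[:, :, None]                   # (B, T, 1)
    mask_4d_bhlt = 1 - sub_masks[:, None, None, :]   # 1.0 marks blocked slots
    return mask_3d_btd, mask_4d_bhlt

T = 6
sub_masks = torch.tril(torch.ones(T, T, dtype=torch.float32))  # what the removed line built
pad_mask = torch.ones(2, T)                                    # two fully valid utterances
mask_3d, mask_4d = prepare_mask(pad_mask, sub_masks)
print(mask_3d.shape, mask_4d.shape)  # torch.Size([2, 6, 1]) torch.Size([6, 1, 1, 6])
```

Moving the mask construction out of the method turns the causal mask into a plain tensor input, so the function body no longer has to build it from `mask.size(-1)` at run time.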
The forward signature is extended the same way, taking `vad_mask` and `sub_masks` as explicit tensor inputs:
```python
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        vad_mask: torch.Tensor,
        sub_masks: torch.Tensor,
    ):
        speech = speech * self._output_size ** 0.5
        mask = self.make_pad_mask(speech_lengths)
```
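Passing the masks in as plain tensors is the usual prerequisite for tracing or ONNX export, where they then surface as named graph inputs; the excerpt does not state that goal, so the following is a sketch under that assumption, with a stand-in module instead of the real model:

```python
import torch

class StubEncoder(torch.nn.Module):
    """Stand-in with the new four-input signature; not the real encoder."""
    def forward(self, speech, speech_lengths, vad_mask, sub_masks):
        # Placeholder math that touches every input, so all four survive
        # tracing as graph inputs; shapes below are illustrative guesses.
        x = speech * (float(speech.size(-1)) ** 0.5)
        gate = (vad_mask * sub_masks).sum(dim=-1, keepdim=True)  # (T, 1)
        return x * gate + speech_lengths.to(x.dtype).mean()

B, T, D = 1, 100, 560                          # assumed batch/time/feature sizes
torch.onnx.export(
    StubEncoder(),
    (
        torch.randn(B, T, D),                  # speech
        torch.tensor([T], dtype=torch.int32),  # speech_lengths
        torch.ones(T, T),                      # vad_mask (assumed shape)
        torch.tril(torch.ones(T, T)),          # sub_masks, built outside the model
    ),
    "encoder.onnx",
    input_names=["speech", "speech_lengths", "vad_mask", "sub_masks"],
    opset_version=13,
)
```

Because `vad_mask` and `sub_masks` are graph inputs rather than tensors created inside `forward`, the runtime can feed per-chunk masks without re-exporting the model.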