| | |
| | | def forward(self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | vad_mask: torch.Tensor, |
| | | vad_masks: torch.Tensor, |
| | | sub_masks: torch.Tensor, |
| | | ): |
| | | speech = speech * self._output_size ** 0.5 |
| | | mask = self.make_pad_mask(speech_lengths) |
| | | vad_masks = self.prepare_mask(mask, vad_masks) |
| | | mask = self.prepare_mask(mask, sub_masks) |
| | | vad_mask = self.prepare_mask(mask, vad_mask) |
| | | |
| | | if self.embed is None: |
| | | xs_pad = speech |
| | | else: |
| | |
| | | # encoder_outs = self.model.encoders(xs_pad, mask) |
| | | for layer_idx, encoder_layer in enumerate(self.model.encoders): |
| | | if layer_idx == len(self.model.encoders) - 1: |
| | | mask = vad_mask |
| | | mask = vad_masks |
| | | encoder_outs = encoder_layer(xs_pad, mask) |
| | | xs_pad, masks = encoder_outs[0], encoder_outs[1] |
| | | |
| | |
| | | def get_output_size(self): |
| | | return self.model.encoders[0].size |
| | | |
| | | def get_dummy_inputs(self): |
| | | feats = torch.randn(1, 100, self.feats_dim) |
| | | return (feats) |
| | | |
| | | def get_input_names(self): |
| | | return ['feats'] |
| | | |
| | | def get_output_names(self): |
| | | return ['encoder_out', 'encoder_out_lens', 'predictor_weight'] |
| | | |
| | | def get_dynamic_axes(self): |
| | | return { |
| | | 'feats': { |
| | | 1: 'feats_length' |
| | | }, |
| | | 'encoder_out': { |
| | | 1: 'enc_out_length' |
| | | }, |
| | | 'predictor_weight': { |
| | | 1: 'pre_out_length' |
| | | } |
| | | |
| | | } |
| | | # def get_dummy_inputs(self): |
| | | # feats = torch.randn(1, 100, self.feats_dim) |
| | | # return (feats) |
| | | # |
| | | # def get_input_names(self): |
| | | # return ['feats'] |
| | | # |
| | | # def get_output_names(self): |
| | | # return ['encoder_out', 'encoder_out_lens', 'predictor_weight'] |
| | | # |
| | | # def get_dynamic_axes(self): |
| | | # return { |
| | | # 'feats': { |
| | | # 1: 'feats_length' |
| | | # }, |
| | | # 'encoder_out': { |
| | | # 1: 'enc_out_length' |
| | | # }, |
| | | # 'predictor_weight': { |
| | | # 1: 'pre_out_length' |
| | | # } |
| | | # |
| | | # } |
| | |
| | | length = 10 |
| | | text_indexes = torch.tensor([[266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757, 266757]], dtype=torch.int32) |
| | | text_lengths = torch.tensor([length], dtype=torch.int32) |
| | | vad_mask = vad_mask(10, 3, dtype=torch.float32)[None, None, :, :] |
| | | vad_masks = vad_mask(10, 3, dtype=torch.float32)[None, None, :, :] |
| | | sub_masks = torch.ones(length, length, dtype=torch.float32) |
| | | sub_masks = torch.tril(sub_masks).type(torch.float32) |
| | | return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :]) |
| | | return (text_indexes, text_lengths, vad_masks, sub_masks[None, None, :, :]) |
| | | |
| | | def get_input_names(self): |
| | | return ['input', 'text_lengths', 'vad_mask', 'sub_masks'] |
| | | return ['input', 'text_lengths', 'vad_masks', 'sub_masks'] |
| | | |
| | | def get_output_names(self): |
| | | return ['logits'] |
| | |
| | | 'input': { |
| | | 1: 'feats_length' |
| | | }, |
| | | 'vad_mask': { |
| | | 'vad_masks': { |
| | | 2: 'feats_length1', |
| | | 3: 'feats_length2' |
| | | }, |