| | |
| | | max_seq_len=512, |
| | | **kwargs, |
| | | ): |
| | | |
| | | self.device = kwargs.get("device") |
| | | is_onnx = kwargs.get("type", "onnx") == "onnx" |
| | | encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export") |
| | | self.encoder = encoder_class(self.encoder, onnx=is_onnx) |
| | |
| | | |
| | | return encoder_model, decoder_model |
| | | |
| | | def _export_encoder_forward( |
| | | def export_encoder_forward( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | |
| | | def export_encoder_name(self): |
| | | return "model.onnx" |
| | | |
| | | def _export_decoder_forward( |
| | | def export_decoder_forward( |
| | | self, |
| | | enc: torch.Tensor, |
| | | enc_len: torch.Tensor, |