| | |
| | | text: (Batch, Length) |
| | | text_lengths: (Batch,) |
| | | """ |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | batch_size = speech.shape[0] |
| | | |
| | | |
| | | # Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | |
| | | speech = speech.to(device=kwargs["device"]) |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | # Encoder |
| | | if kwargs.get("fp16", False): |
| | | speech = speech.half() |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | if isinstance(encoder_out, tuple): |
| | | encoder_out = encoder_out[0] |