| | |
| | | tokenizer=tokenizer, |
| | | ) |
| | | |
| | | if len(kwargs.get("data_type", [])) > 1: |
| | | if ( |
| | | isinstance(kwargs.get("data_type", None), (list, tuple)) |
| | | and len(kwargs.get("data_type", [])) > 1 |
| | | ): |
| | | audio_sample_list, text_token_int_list = audio_sample_list |
| | | text_token_int = text_token_int_list[0] |
| | | else: |
| | |
| | | ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[ |
| | | None, : |
| | | ] |
| | | ys_pad_lens = torch.tensor([len(text_token_int)], dtype=torch.int64).to( |
| | | ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to( |
| | | kwargs["device"] |
| | | )[None, :] |
| | | decoder_out = self.model.decoder( |