| | |
| | | ): |
| | | target_mask = kwargs.get("target_mask", None) |
| | | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | |
| | | ): |
| | | target_mask = kwargs.get("target_mask", None) |
| | | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | |
| | | ): |
| | | target_mask = kwargs.get("target_mask", None) |
| | | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | |
| | | ): |
| | | target_mask = kwargs.get("target_mask", None) |
| | | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | |
| | | |
| | | language = kwargs.get("language", None) |
| | | if language is not None: |
| | | language_query = self.embed(torch.LongTensor([[self.lid_dict[language] if language in self.lid_dict else 0]]).to(speech.device)).repeat(speech.size(0), 1, 1) |
| | | language_query = self.embed( |
| | | torch.LongTensor( |
| | | [[self.lid_dict[language] if language in self.lid_dict else 0]] |
| | | ).to(speech.device) |
| | | ).repeat(speech.size(0), 1, 1) |
| | | else: |
| | | language_query = self.embed(torch.LongTensor([[0]]).to(speech.device)).repeat(speech.size(0), 1, 1) |
| | | language_query = self.embed(torch.LongTensor([[0]]).to(speech.device)).repeat( |
| | | speech.size(0), 1, 1 |
| | | ) |
| | | textnorm = kwargs.get("text_norm", "wotextnorm") |
| | | textnorm_query = self.embed(torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)).repeat(speech.size(0), 1, 1) |
| | | textnorm_query = self.embed( |
| | | torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device) |
| | | ).repeat(speech.size(0), 1, 1) |
| | | speech = torch.cat((textnorm_query, speech), dim=1) |
| | | speech_lengths += 1 |
| | | |
| | | event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1) |
| | | event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat( |
| | | speech.size(0), 1, 1 |
| | | ) |
| | | input_query = torch.cat((language_query, event_emo_query), dim=1) |
| | | speech = torch.cat((input_query, speech), dim=1) |
| | | speech_lengths += 3 |