| | |
| | | target_mask = ( |
| | | [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1] |
| | | ) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] |
| | | target_mask_lengths = len(target_mask) |
| | | target_mask = torch.tensor(target_mask, dtype=torch.float32) |
| | | |
| | | target_mask_lengths = torch.tensor([target_mask_lengths], dtype=torch.int32) |
| | | return { |
| | | "speech": speech[0, :, :], |
| | | "speech_lengths": speech_lengths, |
| | |
| | | ) |
| | | |
| | | if self.batch_type != "example": |
| | | b, t, _ = outputs["speech"].shape |
| | | if b * t > self.batch_size: |
| | | beg = torch.randint(0, 2, ()).item() |
| | | logging.info( |
| | | f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 1st, beg:{beg}" |
| | | ) |
| | | for key, data_list in outputs.items(): |
| | | outputs[key] = outputs[key][beg : beg + b : 2] |
| | | for i in range(3): |
| | | outputs = self._filter_badcase(outputs) |
| | | |
| | | b, t, _ = outputs["speech"].shape |
| | | if b * t > self.batch_size: |
| | | beg = torch.randint(0, 2, ()).item() |
| | | logging.info( |
| | | f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 2nd, beg:{beg}" |
| | | ) |
| | | for key, data_list in outputs.items(): |
| | | outputs[key] = outputs[key][beg : beg + b : 2] |
| | | return outputs |
| | | |
| | | b, t, _ = outputs["speech"].shape |
| | | if b * t > self.batch_size: |
| | | beg = torch.randint(0, 2, ()).item() |
| | | logging.info( |
| | | f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 3th, beg:{beg}" |
| | | ) |
| | | for key, data_list in outputs.items(): |
| | | outputs[key] = outputs[key][beg : beg + b : 2] |
def _filter_badcase(self, outputs, i=0):
    """Halve an oversized batch and trim the padding left behind.

    If ``batch * frames`` of ``outputs["speech"]`` exceeds ``self.batch_size``,
    every entry of ``outputs`` keeps every other row, starting at row 0 or 1
    (chosen at random; always row 0 for a single-sample batch). The padded
    second axis of ``speech``, ``text`` and ``target_mask`` is then cut down
    to the longest remaining sample. Otherwise ``outputs`` is returned
    untouched.

    Args:
        outputs: dict of batched tensors. It must contain ``speech`` (3-D; the
            first two dims are read as batch and frames), ``speech_lengths``,
            ``text``, ``text_lengths``, ``target_mask`` and
            ``target_mask_lengths``. All values are row-sliced together, so
            presumably every entry is batch-first. Confirm against the caller.
        i: pass counter, used only in the log message.

    Returns:
        The same dict object, mutated in place.
    """
    b, t, _ = outputs["speech"].shape
    if b * t <= self.batch_size:
        return outputs

    # A single-sample batch cannot be halved. Starting at row 1 would leave
    # every tensor empty, and the ``.max()`` calls below would then raise.
    beg = torch.randint(0, 2, ()).item() if b >= 2 else 0
    logging.info(
        f"Warning, b * t: {b * t} > {self.batch_size}, drop half data {i}th, beg:{beg}"
    )
    # Snapshot the keys so the dict is not iterated while being reassigned.
    for key in list(outputs):
        outputs[key] = outputs[key][beg : beg + b : 2]

    # Dropping rows can leave trailing padding. Trim each padded axis to the
    # longest remaining length. The length entries are keyed like
    # ``text_lengths`` (``speech_lengths``, ``target_mask_lengths``); there are
    # no ``*_max`` keys.
    speech_lengths_max = outputs["speech_lengths"].max().item()
    outputs["speech"] = outputs["speech"][:, :speech_lengths_max, :]
    text_lengths_max = outputs["text_lengths"].max().item()
    outputs["text"] = outputs["text"][:, :text_lengths_max]
    target_mask_lengths_max = outputs["target_mask_lengths"].max().item()
    outputs["target_mask"] = outputs["target_mask"][:, :target_mask_lengths_max]

    return outputs