| | |
| | | import logging |
| | | |
| | | import torch |
| | | import random |
| | | |
| | |
| | | self.float_pad_value = float_pad_value |
| | | self.sos = kwargs.get("sos", "<|startoftranscript|>") |
| | | self.eos = kwargs.get("eos", "<|endoftext|>") |
| | | self.batch_size = kwargs.get("batch_size") |
| | | self.batch_type = kwargs.get("batch_type") |
| | | |
| | | def get_source_len(self, index): |
| | | item = self.index_ds[index] |
| | |
| | | outputs[key] = torch.nn.utils.rnn.pad_sequence( |
| | | data_list, batch_first=True, padding_value=pad_value |
| | | ) |
| | | |
| | | if self.batch_type != "example": |
| | | b, t, _ = outputs["speech"].shape |
| | | if b * t > self.batch_size: |
| | | beg = torch.randint(0, 2, ()).item() |
| | | logging.info( |
| | | f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 1st, beg:{beg}" |
| | | ) |
| | | for key, data_list in outputs.items(): |
| | | outputs[key] = outputs[key][beg : beg + b : 2] |
| | | |
| | | b, t, _ = outputs["speech"].shape |
| | | if b * t > self.batch_size: |
| | | beg = torch.randint(0, 2, ()).item() |
| | | logging.info( |
| | | f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 2nd, beg:{beg}" |
| | | ) |
| | | for key, data_list in outputs.items(): |
| | | outputs[key] = outputs[key][beg : beg + b : 2] |
| | | |
| | | b, t, _ = outputs["speech"].shape |
| | | if b * t > self.batch_size: |
| | | beg = torch.randint(0, 2, ()).item() |
| | | logging.info( |
| | | f"Warning, b * t: {b * t} > {self.batch_size}, drop half data 3th, beg:{beg}" |
| | | ) |
| | | for key, data_list in outputs.items(): |
| | | outputs[key] = outputs[key][beg : beg + b : 2] |
| | | return outputs |