| | |
| | | |
| | | idx_map = self.shuffle_idx[idx] |
| | | # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] |
| | | sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \ |
| | | self.dataset.indexed_dataset[idx_map]["target_len"] |
| | | sample_len_cur = self.dataset.indexed_dataset.get_source_len(self.dataset.indexed_dataset[idx_map]) + \ |
| | | self.dataset.indexed_dataset.get_target_len(self.dataset.indexed_dataset[idx_map]) |
| | | |
| | | datalen_with_index.append([idx, sample_len_cur]) |
| | | |
| | |
| | | collate_fn=dataset.collator, |
| | | batch_sampler=batch_sampler, |
| | | shuffle=False, |
| | | num_workers=8, |
| | | num_workers=0, |
| | | pin_memory=True) |
| | | |
| | | print(len(dataset)) |
| | |
| | | |
def __getitem__(self, index):
    """Return the entry of ``self.contents`` selected by ``index``.

    The lookup is delegated unchanged to ``self.contents``, so whatever
    that container accepts as an index (and whatever it raises for an
    invalid one) applies here as well.
    """
    record = self.contents[index]
    return record
| | | |
def get_source_len(self, data_dict):
    """Return the value stored under ``"source_len"`` in ``data_dict``.

    The key is required: no fallback is applied, so a plain ``dict``
    without it raises ``KeyError``.
    """
    source_len = data_dict["source_len"]
    return source_len
| | | |
def get_target_len(self, data_dict):
    """Return ``data_dict["target_len"]``, or ``0`` when the key is absent.

    Presence is decided with a membership test on ``data_dict``; a stored
    value is returned as-is (including ``None`` or ``0``).
    """
    # Guard clause: records without a target length count as zero.
    if "target_len" not in data_dict:
        return 0
    return data_dict["target_len"]
| | | |
| | | |
| | | class AudioDataset(torch.utils.data.Dataset): |