| | |
| | | speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech] |
| | | speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels] |
| | | orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders] |
| | | batch = dict(speech=speech, |
| | | speaker_labels=speaker_labels, |
| | | orders=orders) |
| | | batch = dict(speech=speech, speaker_labels=speaker_labels, orders=orders) |
| | | |
| | | return keys, batch |
| | | |
| | | |
| | | class EENDOLADataset(Dataset): |
| | | def __init__( |
| | | self, |
| | | data_file, |
| | | self, |
| | | data_file, |
| | | ): |
| | | self.data_file = data_file |
| | | with open(data_file) as f: |
| | |
| | | return key, speech, speaker_label, order |
| | | |
| | | |
class EENDOLADataLoader:
    """Thin wrapper that builds a torch ``DataLoader`` over an ``EENDOLADataset``.

    Fixes versus the original block: the class statement, the ``DataLoader``
    construction, and the ``return`` in ``build_iter`` were each duplicated
    (merge/extraction artifact). The duplicate ``DataLoader`` was the real
    defect — it constructed a second loader (and, with ``num_workers`` > 0,
    a second worker pool) and silently discarded the first.
    """

    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
        """Build the underlying DataLoader.

        Args:
            data_file: path passed straight to ``EENDOLADataset``.
            batch_size: batch size forwarded to ``DataLoader``.
            shuffle: whether to shuffle each epoch (default ``True``).
            num_workers: DataLoader worker processes (default ``8``).
        """
        dataset = EENDOLADataset(data_file)
        # custom_collate handles the variable-length speech/label sequences,
        # so no default tensor stacking is attempted.
        self.data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=custom_collate,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    def build_iter(self, epoch):
        """Return the iterable loader for one epoch.

        ``epoch`` is accepted for interface compatibility with other loader
        classes but is not used here (no per-epoch sampler re-seeding).
        """
        return self.data_loader