游雁
2024-04-29 f57b68121a526baea43b2e93f4540d8a2995f633
funasr/models/eend/eend_ola_dataloader.py
@@ -12,17 +12,15 @@
    speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech]
    speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels]
    orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders]
    batch = dict(speech=speech,
                 speaker_labels=speaker_labels,
                 orders=orders)
    batch = dict(speech=speech, speaker_labels=speaker_labels, orders=orders)
    return keys, batch
class EENDOLADataset(Dataset):
    def __init__(
            self,
            data_file,
        self,
        data_file,
    ):
        self.data_file = data_file
        with open(data_file) as f:
@@ -44,14 +42,16 @@
        return key, speech, speaker_label, order
class EENDOLADataLoader:
    """Thin wrapper that builds a ``DataLoader`` over an ``EENDOLADataset``.

    The wrapped loader is stored on ``self.data_loader`` and handed out
    unchanged by :meth:`build_iter`.
    """

    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
        """Create the dataset and the loader that batches it.

        Args:
            data_file: Passed straight through to ``EENDOLADataset``.
            batch_size: Forwarded to ``DataLoader`` as ``batch_size``.
            shuffle: Forwarded to ``DataLoader`` as ``shuffle``.
            num_workers: Forwarded to ``DataLoader`` as ``num_workers``.
        """
        dataset = EENDOLADataset(data_file)
        # Batches are assembled by ``custom_collate`` rather than the default
        # collate function.
        self.data_loader = DataLoader(
            dataset,
            batch_size=batch_size,
            collate_fn=custom_collate,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    def build_iter(self, epoch):
        """Return the underlying ``DataLoader``.

        Args:
            epoch: Accepted for interface compatibility but not used; the
                same loader object is returned for every epoch.
        """
        return self.data_loader