zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/models/eend/eend_ola_dataloader.py
@@ -12,9 +12,7 @@
    speech = [torch.from_numpy(np.copy(sph)).to(torch.float32) for sph in speech]
    speaker_labels = [torch.from_numpy(np.copy(spk)).to(torch.float32) for spk in speaker_labels]
    orders = [torch.from_numpy(np.copy(o)).to(torch.int64) for o in orders]
    batch = dict(speech=speech,
                 speaker_labels=speaker_labels,
                 orders=orders)
    batch = dict(speech=speech, speaker_labels=speaker_labels, orders=orders)
    return keys, batch
@@ -44,14 +42,16 @@
        return key, speech, speaker_label, order
class EENDOLADataLoader():
class EENDOLADataLoader:
    def __init__(self, data_file, batch_size, shuffle=True, num_workers=8):
        dataset = EENDOLADataset(data_file)
        self.data_loader = DataLoader(dataset,
        self.data_loader = DataLoader(
            dataset,
                                      batch_size=batch_size,
                                      collate_fn=custom_collate,
                                      shuffle=shuffle,
                                      num_workers=num_workers)
            num_workers=num_workers,
        )
    def build_iter(self, epoch):
        return self.data_loader