| | |
def set_epoch(self, epoch):
    """Store ``epoch`` on the instance as ``self.epoch``.

    Args:
        epoch: Value to record; it is stored as-is, without validation.

    NOTE(review): presumably read elsewhere in the class to vary the
    per-epoch ordering -- that code is not visible here, confirm.
    """
    self.epoch = epoch
| | | |
| | | |
def CustomDistributedBatchSampler_fn(dataset, **kwargs):
    """Build a dict of loader arguments around a CustomDistributedBatchSampler.

    Args:
        dataset: Object passed to the sampler and returned under ``"dataset"``.
        **kwargs: Forwarded unchanged to ``CustomDistributedBatchSampler``.
            ``num_workers`` (default 4) and ``pin_memory`` (default True)
            are additionally read from here for the returned dict.

    Returns:
        dict with the keys ``"dataset"``, ``"batch_sampler"``,
        ``"num_workers"`` and ``"pin_memory"``, in that order.

    NOTE(review): the key names suggest the result is unpacked into a
    ``torch.utils.data.DataLoader`` call -- confirm against the caller.
    """
    # The sampler receives the full kwargs, including num_workers and
    # pin_memory, exactly as the caller supplied them.
    batch_sampler = CustomDistributedBatchSampler(dataset, **kwargs)
    return {
        "dataset": dataset,
        "batch_sampler": batch_sampler,
        "num_workers": kwargs.get("num_workers", 4),
        "pin_memory": kwargs.get("pin_memory", True),
    }
| | | |
| | | @tables.register("batch_sampler_classes", "CustomDistributedBatchSampler") |
| | | class CustomDistributedBatchSampler(Sampler): |
| | | def __init__(self, dataset, |