| | |
| | | |
| | | |
| | | def main(**kwargs): |
| | | |
| | | print(kwargs) |
| | | # set random seed |
| | | tables.print() |
| | | set_all_random_seed(kwargs.get("seed", 0)) |
| | |
| | | if batch_sampler is not None: |
| | | batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) |
| | | batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf")) |
| | | batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf")) |
| | | batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf")) |
| | | dataloader_tr = torch.utils.data.DataLoader(dataset_tr, |
| | | collate_fn=dataset_tr.collator, |
| | | batch_sampler=batch_sampler, |