| | |
| | | speed_stats = {} |
| | | time5 = time.perf_counter() |
| | | iterator_stop = torch.tensor(0).to(self.device) |
| | | dist.barrier() |
| | | print(f"before iter, iterator_stop: {iterator_stop}\n") |
| | | dataloader_train.batch_sampler.set_epoch(epoch) |
| | | for batch_idx, batch in enumerate(dataloader_train): |
| | | if self.use_ddp or self.use_fsdp: |
| | | dist.all_reduce(iterator_stop, dist.ReduceOp.SUM) |
| | |
| | | speed_stats = {} |
| | | time5 = time.perf_counter() |
| | | iterator_stop = torch.tensor(0).to(self.device) |
| | | dist.barrier() |
| | | print(f"before iter, iterator_stop: {iterator_stop}\n") |
| | | for batch_idx, batch in enumerate(dataloader_val): |
| | | if self.use_ddp or self.use_fsdp: |
| | | dist.all_reduce(iterator_stop, dist.ReduceOp.SUM) |
| | | if epoch >= 1: |
| | | print(f"iterator_stop: {iterator_stop}\n") |
| | | if iterator_stop > 0: |
| | | break |
| | | time1 = time.perf_counter() |