| | |
| | | except: |
| | | writer = None |
| | | |
# Select the epoch-level context manager once, before the training loop.
# Under DDP/FSDP, Join([model]) lets ranks that run out of batches early
# participate in (shadow) collective ops so other ranks don't deadlock;
# otherwise a no-op nullcontext() is used.
# (Removed: a commented-out duplicate of this conditional and a redundant
# `context = nullcontext()` that was immediately overwritten by the if/else.)
if use_ddp or use_fsdp:
    context = Join([model])
else:
    context = nullcontext()
| | | |
# NOTE(review): extraction-garbled fragment of the outer training loop —
# table-pipe residue with flattened indentation; underlying code is Python.
# Per-epoch flow appears to be: train -> validate under `context` ->
# scheduler.step() -> validate again (?) -> save checkpoint.
| | | for epoch in range(trainer.start_epoch, trainer.max_epoch + 1): |
| | | time1 = time.perf_counter() |
| | |
# NOTE(review): the three lines below are the tail of a call whose head is
# missing from this view (presumably trainer.train_epoch(...)) — the chunk
# is truncated between the previous line and here.
| | | epoch=epoch, |
| | | writer=writer |
| | | ) |
# Validation runs inside `context` — presumably so that, under DDP/FSDP,
# Join keeps ranks with uneven batch counts from deadlocking on
# collectives. TODO confirm against the context selection above.
| | | with context: |
| | | trainer.validate_epoch( |
| | | model=model, |
| | | dataloader_val=dataloader_val, |
| | | epoch=epoch, |
| | | writer=writer |
| | | ) |
| | | scheduler.step() |
# NOTE(review): validate_epoch is invoked a second time with identical
# arguments (and outside `context`) — looks like a copy/paste leftover;
# confirm whether validating twice per epoch is intended.
| | | trainer.validate_epoch( |
| | | model=model, |
| | | dataloader_val=dataloader_val, |
| | | epoch=epoch, |
| | | writer=writer |
| | | ) |
| | | |
| | | |
# Persist full training state (model, optimizer, scheduler, AMP scaler)
# so the run can resume from this epoch.
| | | trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler) |
| | |
# NOTE(review): this is a different region of the file — the start of a
# per-batch loop inside an (unseen) method; `self` in the commented-out
# lines suggests it belongs to the trainer class. TODO confirm the
# enclosing method (likely validate_epoch).
| | | speed_stats = {} |
| | | time5 = time.perf_counter() |
# Dead code below: distributed early-stop flag, currently disabled.
| | | # iterator_stop = torch.tensor(0).to(self.device) |
| | | |
# Re-seed the sampler each epoch so shuffling differs per epoch under
# distributed sampling — assumes batch_sampler implements set_epoch;
# TODO confirm.
| | | dataloader_val.batch_sampler.set_epoch(epoch) |
| | | for batch_idx, batch in enumerate(dataloader_val): |
# Dead code below: per-batch all_reduce of the early-stop flag, disabled.
| | | # if self.use_ddp or self.use_fsdp: |
| | | # dist.all_reduce(iterator_stop, dist.ReduceOp.SUM) |