| | |
| | | else: |
| | | model = model.to(device=kwargs.get("device", "cuda")) |
| | | |
| | | logging.info(f"{model}") |
| | | if local_rank == 0: |
| | | logging.info(f"{model}") |
| | | kwargs["device"] = next(model.parameters()).device |
| | | |
| | | # optim |
| | |
| | | # if use_ddp or use_fsdp: |
| | | # context = Join([model]) |
| | | # else: |
| | | # context = nullcontext() |
| | | context = nullcontext() |
| | | |
| | | for epoch in range(trainer.start_epoch, trainer.max_epoch + 1): |
| | | time1 = time.perf_counter() |
| | | with context: |
| | |
| | | epoch=epoch, |
| | | writer=writer |
| | | ) |
| | | with context: |
| | | trainer.validate_epoch( |
| | | model=model, |
| | | dataloader_val=dataloader_val, |
| | | epoch=epoch, |
| | | writer=writer |
| | | ) |
| | | scheduler.step() |
| | | trainer.validate_epoch( |
| | | model=model, |
| | | dataloader_val=dataloader_val, |
| | | epoch=epoch, |
| | | writer=writer |
| | | ) |
| | | |
| | | |
| | | trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler) |