torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
# enable TF32 for CUDA matmuls
torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True)
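# Note: cuDNN convolutions have a separate TF32 switch
# (torch.backends.cudnn.allow_tf32); this excerpt leaves it at its default.
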
local_rank = int(os.environ.get("LOCAL_RANK", 0))
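# LOCAL_RANK is injected by the distributed launcher, e.g. with an
# (assumed) invocation like: torchrun --nproc_per_node=4 train.py <config args>
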
if use_ddp:
    model = model.cuda(local_rank)
    model = DDP(model, device_ids=[local_rank],
                find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False))
elif use_fsdp:
    model = FSDP(model).cuda(local_rank)
else:
    model = model.to(device=kwargs.get("device", "cuda"))
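# DDP/FSDP are assumed to be the standard torch wrappers:
#   from torch.nn.parallel import DistributedDataParallel as DDP
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP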

if local_rank == 0:
    logging.info(f"{model}")
kwargs["device"] = next(model.parameters()).device

# optim
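# The optimizer/scheduler/scaler construction was elided from this excerpt.
# A minimal stand-in using plain torch APIs so the loop below is runnable;
# the config keys ("optim_conf", "use_fp16") and choices here are assumptions:
optim = torch.optim.AdamW(model.parameters(), lr=kwargs.get("optim_conf", {}).get("lr", 1e-3))
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.99)
scaler = torch.cuda.amp.GradScaler(enabled=kwargs.get("use_fp16", False))
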
# dataset
logging.info("Build dataloader")
dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle"))
dataloader = dataloader_class(**kwargs)
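# tables.dataloader_classes is assumed to be the project's registry mapping
# the class name from dataset_conf to a dataloader factory class, with
# "DataloaderMapStyle" as the default.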
trainer = Trainer(local_rank=local_rank,
                  use_ddp=use_ddp,
                  use_fsdp=use_fsdp,
                  # further Trainer arguments were elided in this excerpt;
                  # passing the train config through is an assumption:
                  **kwargs.get("train_conf", {}))

# if use_ddp or use_fsdp:
#     context = Join([model])
# else:
#     context = nullcontext()
context = nullcontext()

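# If the commented-out Join context is re-enabled, Join is the uneven-inputs
# helper from torch.distributed.algorithms.join, which keeps ranks that run
# out of batches early in step with the others; nullcontext is from contextlib.
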
for epoch in range(trainer.start_epoch, trainer.max_epoch + 1):
    time1 = time.perf_counter()  # epoch wall-clock start
    with context:
        dataloader_tr, dataloader_val = dataloader.build_iter(epoch)
        trainer.train_epoch(
            model=model,
            optim=optim,
            # an argument line was elided here; the freshly built train split
            # must be passed in (parameter name assumed):
            dataloader_train=dataloader_tr,
            epoch=epoch,
            writer=writer
        )
    with context:
        trainer.validate_epoch(
            model=model,
            dataloader_val=dataloader_val,
            epoch=epoch,
            writer=writer
        )
    scheduler.step()

    trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)

if trainer.rank == 0:
    average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)

trainer.close()