dist.init_process_group(
    backend=kwargs.get("backend", "nccl"),
    init_method="env://",
)
torch.cuda.set_device(local_rank)
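# NOTE (illustrative sketch, not part of the original script):
# find_unused_parameters is not an init_process_group argument; it is an option
# of torch.nn.parallel.DistributedDataParallel, so it was dropped from the call
# above. The intended pattern, assuming a torchrun-style launch, looks roughly
# like this (the helper name setup_ddp is hypothetical):
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def setup_ddp(model: torch.nn.Module, find_unused_parameters: bool = False) -> DDP:
    # With init_method="env://", RANK and WORLD_SIZE are read from the environment.
    dist.init_process_group(backend="nccl", init_method="env://")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)
    # find_unused_parameters lets the gradient reducer tolerate parameters
    # that receive no gradient in a given forward pass.
    return DDP(
        model.cuda(local_rank),
        device_ids=[local_rank],
        find_unused_parameters=find_unused_parameters,
    )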

    # Default to an empty dict so ** unpacking does not fail when
    # "train_conf" is absent from kwargs.
    **kwargs.get("train_conf", {}),
)

# Wrap the model for training (typically DDP), forwarding the run configuration.
model = trainer.warp_model(model, **kwargs)

# Compute the local rank once and use it for both the run configuration
# and the trainer's device.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
kwargs["device"] = local_rank
trainer.device = local_rank
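# NOTE (illustrative, assuming a torchrun launch such as
# `torchrun --nproc_per_node=N train.py`): the launcher sets RANK, WORLD_SIZE,
# and LOCAL_RANK per worker; defaults like the ones below keep the script
# usable when it is started directly as a single process.
rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))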

)
trainer.start_step = 0

# Release cached GPU memory on the device the model actually lives on,
# rather than blindly clearing the current device's cache.
device = next(model.parameters()).device
if device.type == "cuda":
    with torch.cuda.device(device):
        torch.cuda.empty_cache()
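# NOTE (illustrative helper, not part of the original script): to see what the
# scoped empty_cache() frees, log the allocator counters before and after.
def log_cuda_memory(device: torch.device) -> None:
    # memory_allocated counts live tensors; memory_reserved counts memory the
    # caching allocator still holds, which empty_cache() returns to the driver.
    alloc_mib = torch.cuda.memory_allocated(device) / 2**20
    reserved_mib = torch.cuda.memory_reserved(device) / 2**20
    logging.info(
        "%s: allocated=%.1f MiB, reserved=%.1f MiB", device, alloc_mib, reserved_mib
    )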

# Elapsed wall-clock time for this slice, converted from seconds to hours.
time_escaped = (time.perf_counter() - time_slice_i) / 3600.0
logging.info(