| | |
| | | torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled) |
| | | torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark) |
| | | torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True) |
| | | # open tf32 |
| | | torch.backends.cuda.matmul.allow_tf32 = kwargs.get("enable_tf32", True) |
| | | |
| | | local_rank = int(os.environ.get('LOCAL_RANK', 0)) |
| | | if local_rank == 0: |
| | |
| | | if use_ddp: |
| | | model = model.cuda(local_rank) |
| | | model = DDP(model, device_ids=[local_rank], |
| | | find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False)) |
| | | find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True)) |
| | | elif use_fsdp: |
| | | # model = FSDP(model).cuda(local_rank) |
| | | |