def main(**kwargs):
    print(kwargs)

    # set random seed
    set_all_random_seed(kwargs.get("seed", 0))
    torch.backends.cudnn.enabled = kwargs.get("cudnn_enabled", torch.backends.cudnn.enabled)
    torch.backends.cudnn.benchmark = kwargs.get("cudnn_benchmark", torch.backends.cudnn.benchmark)
    torch.backends.cudnn.deterministic = kwargs.get("cudnn_deterministic", True)
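    # cudnn.benchmark lets cuDNN auto-tune convolution algorithms (faster for fixed
    # input shapes, but the selected algorithm can vary between runs), while
    # cudnn.deterministic forces deterministic algorithms so results are reproducible
    # at some cost in speed.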

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        tables.print()

    # Check if we are using DDP or FSDP
    use_ddp = "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1
    use_fsdp = kwargs.get("use_fsdp", None)
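    # use_ddp is True when the job is launched with more than one process (e.g. via
    # torchrun, which sets WORLD_SIZE); use_fsdp is an optional config switch for
    # fully sharded data parallel training.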

    # init_param
    init_param = kwargs.get("init_param", None)
    if init_param is not None:
        if isinstance(init_param, str):
            init_param = (init_param,)
        logging.info("init_param is not None: %s", init_param)
        for p in init_param:
            if os.path.exists(p):
                logging.info(f"Loading pretrained params from {p}")
                load_pretrained_model(
                    model=model,
                    path=p,
                    ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
                    oss_bucket=kwargs.get("oss_bucket", None),
                    scope_map=kwargs.get("scope_map", None),
                    excludes=kwargs.get("excludes", None),
                )
            else:
                logging.info(f"Checkpoint does not exist, init randomly: {p}")
    elif kwargs.get("init", None):
        initialize(model, kwargs.get("init", "kaiming_normal"))
    else:
        print("No initialize method")


    # freeze_param