| | |
| | | # freeze_param |
| | | freeze_param = kwargs.get("freeze_param", None) |
| | | if freeze_param is not None: |
| | | freeze_param = eval(freeze_param) |
| | | if "," in freeze_param: |
| | | freeze_param = eval(freeze_param) |
| | | if isinstance(freeze_param, Sequence): |
| | | freeze_param = (freeze_param,) |
| | | logging.info("freeze_param is not None: %s", freeze_param) |
| | |
| | | if use_ddp: |
| | | model = model.cuda(local_rank) |
| | | model = DDP(model, device_ids=[local_rank], |
| | | find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", True)) |
| | | find_unused_parameters=kwargs.get("train_conf", {}).get("find_unused_parameters", False)) |
| | | elif use_fsdp: |
| | | # model = FSDP(model).cuda(local_rank) |
| | | |
| | |
| | | except: |
| | | writer = None |
| | | |
| | | # if use_ddp or use_fsdp: |
| | | # context = Join([model]) |
| | | # else: |
| | | # context = nullcontext() |
| | | context = nullcontext() |
| | | |
| | | for epoch in range(trainer.start_epoch, trainer.max_epoch + 1): |
| | | time1 = time.perf_counter() |
| | | with context: |
| | | dataloader_tr, dataloader_val = dataloader.build_iter(epoch) |
| | | |
| | | for data_split_i in range(dataloader.data_split_num): |
| | | dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i) |
| | | trainer.train_epoch( |
| | | model=model, |
| | | optim=optim, |
| | |
| | | dataloader_train=dataloader_tr, |
| | | dataloader_val=dataloader_val, |
| | | epoch=epoch, |
| | | writer=writer |
| | | writer=writer, |
| | | data_split_i=data_split_i, |
| | | data_split_num=dataloader.data_split_num, |
| | | ) |
| | | with context: |
| | | trainer.validate_epoch( |
| | | model=model, |
| | | dataloader_val=dataloader_val, |
| | | epoch=epoch, |
| | | writer=writer |
| | | ) |
| | | |
| | | trainer.validate_epoch( |
| | | model=model, |
| | | dataloader_val=dataloader_val, |
| | | epoch=epoch, |
| | | writer=writer |
| | | ) |
| | | scheduler.step() |
| | | |
| | | |