import logging
import time
from collections.abc import Sequence
from contextlib import nullcontext

import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf
from torch.cuda.amp import autocast, GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
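
# A minimal, illustrative sketch (not part of the original script) of how the
# imported `autocast` and `GradScaler` utilities are typically combined in a
# single mixed-precision optimization step; `model`, `optim`, and `batch` are
# assumed placeholder names:
#
#     scaler = GradScaler()
#     with autocast():
#         loss = model(**batch)  # forward pass returning a scalar loss
#     scaler.scale(loss).backward()
#     scaler.step(optim)
#     scaler.update()
#     optim.zero_grad(set_to_none=True)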

if freeze_param is not None:
    # A string containing "," is evaluated into a tuple/list of parameter-name
    # prefixes; a single value is wrapped so it can be iterated uniformly.
    if "," in freeze_param:
        freeze_param = eval(freeze_param)
    if not isinstance(freeze_param, (list, tuple)):
        freeze_param = (freeze_param,)
    logging.info("freeze_param is not None: %s", freeze_param)
    for t in freeze_param:
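        # NOTE: the loop body is truncated in this excerpt. An assumed
        # completion (not taken from the original) would freeze every
        # parameter whose name matches one of the prefixes, e.g.:
        #
        #     for k, p in model.named_parameters():
        #         if k == t or k.startswith(t + "."):
        #             p.requires_grad = False
        #             logging.info("Freezing parameter: %s", k)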

try:
    from tensorboardX import SummaryWriter

    writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
except Exception:
    # tensorboardX is optional; fall back to no event logging.
    writer = None
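# Illustrative (assumed) use of the optional writer later in training; the
# metric name and step counter below are placeholders:
#
#     if writer is not None:
#         writer.add_scalar("train/loss", loss_value, global_step)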

# --- per-epoch section (excerpted; surrounding calls are partially truncated) ---

# ... (start of this call is truncated in the excerpt)
    epoch, data_split_i=data_split_i, start_step=trainer.start_step
)
trainer.start_step = 0

trainer.train_epoch(
    model=model,
    optim=optim,
    # ... (remaining arguments truncated in the excerpt)

# ... (start of this call is truncated in the excerpt)
    model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
)
scheduler.step()

# Reset the in-epoch step counter and checkpoint the full training state.
trainer.step_cur_in_epoch = 0
trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)

time2 = time.perf_counter()  # end-of-epoch timestamp (the matching start time is not shown)