| | |
| | | except: |
| | | writer = None |
| | | |
| | | # if use_ddp or use_fsdp: |
| | | # context = Join([model]) |
| | | # else: |
| | | # context = nullcontext() |
| | | context = nullcontext() |
| | | |
| | | for epoch in range(trainer.start_epoch, trainer.max_epoch + 1): |
| | | time1 = time.perf_counter() |
| | | with context: |
| | | dataloader_tr, dataloader_val = dataloader.build_iter(epoch) |
| | | |
| | | for data_split_i in range(dataloader.data_split_num): |
| | | dataloader_tr, dataloader_val = dataloader.build_iter(epoch, data_split_i=data_split_i) |
| | | trainer.train_epoch( |
| | | model=model, |
| | | optim=optim, |
| | |
| | | dataloader_train=dataloader_tr, |
| | | dataloader_val=dataloader_val, |
| | | epoch=epoch, |
| | | writer=writer |
| | | writer=writer, |
| | | data_split_i=data_split_i, |
| | | data_split_num=dataloader.data_split_num, |
| | | ) |
| | | with context: |
| | | |
| | | trainer.validate_epoch( |
| | | model=model, |
| | | dataloader_val=dataloader_val, |