| | |
| | | # dataset |
| | | logging.info("Build dataloader") |
| | | dataloader_class = tables.dataloader_classes.get(kwargs["dataset_conf"].get("dataloader", "DataloaderMapStyle")) |
| | | # dataloader = dataloader_class(**kwargs) |
| | | dataloader_tr, dataloader_val = dataloader_class(**kwargs) |
| | | dataloader = dataloader_class(**kwargs) |
| | | # dataloader_tr, dataloader_val = dataloader_class(**kwargs) |
| | | trainer = Trainer(local_rank=local_rank, |
| | | use_ddp=use_ddp, |
| | | use_fsdp=use_fsdp, |
| | |
| | | except: |
| | | writer = None |
| | | |
| | | if use_ddp or use_fsdp: |
| | | context = Join([model]) |
| | | else: |
| | | context = nullcontext() |
| | | |
| | | # if use_ddp or use_fsdp: |
| | | # context = Join([model]) |
| | | # else: |
| | | # context = nullcontext() |
| | | context = nullcontext() |
| | | for epoch in range(trainer.start_epoch, trainer.max_epoch + 1): |
| | | time1 = time.perf_counter() |
| | | with context: |
| | | # dataloader_tr, dataloader_val = dataloader.build_iter(epoch) |
| | | dataloader_tr, dataloader_val = dataloader.build_iter(epoch) |
| | | trainer.train_epoch( |
| | | model=model, |
| | | optim=optim, |
| | |
| | | |
| | | |
| | | if trainer.rank == 0: |
| | | average_checkpoints(trainer.output_dir, trainer.avg_nbest_model, trainer.val_acc_list) |
| | | average_checkpoints(trainer.output_dir, trainer.avg_nbest_model) |
| | | |
| | | trainer.close() |
| | | |