| | |
| | | |
| | | model = trainer.warp_model(model) |
| | | |
| | | kwargs["device"] = next(model.parameters()).device |
| | | trainer.device = kwargs["device"] |
| | | kwargs["device"] = int(os.environ.get("LOCAL_RANK", 0)) |
| | | trainer.device = int(os.environ.get("LOCAL_RANK", 0)) |
| | | |
| | | model, optim, scheduler = trainer.warp_optim_scheduler(model, kwargs) |
| | | model, optim, scheduler = trainer.warp_optim_scheduler(model, **kwargs) |
| | | |
| | | # dataset |
| | | logging.info("Build dataloader") |
| | |
| | | trainer.train_loss_avg = 0.0 |
| | | |
| | | if trainer.rank == 0: |
| | | average_checkpoints(trainer.output_dir, trainer.avg_nbest_model) |
| | | average_checkpoints( |
| | | trainer.output_dir, trainer.avg_nbest_model, use_deepspeed=trainer.use_deepspeed |
| | | ) |
| | | |
| | | trainer.close() |
| | | |