| | |
| | | self.best_step_or_epoch = ckpt_name |
| | | best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best')) |
| | | torch.save(state, best_ckpt) |
| | | logging.info(f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]}, {best_ckpt}") |
| | | logging.info(f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}") |
| | | else: |
| | | logging.info(f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]}") |
| | | logging.info(f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}") |
| | | elif self.avg_keep_nbest_models_type == "loss": |
| | | if self.val_loss_step_or_eoch[ckpt_name] <= self.val_loss_step_or_eoch[self.best_step_or_epoch]: |
| | | self.best_step_or_epoch = ckpt_name |
| | | best_ckpt = Path(os.path.join(self.output_dir, f'model.pt.best')) |
| | | torch.save(state, best_ckpt) |
| | | logging.info(f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]}, {best_ckpt}") |
| | | logging.info(f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}") |
| | | else: |
| | | logging.info(f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]}") |
| | | logging.info(f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}") |
| | | else: |
| | | print("Undo") |
| | | self.saved_ckpts[ckpt_name] = getattr(self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch")[ckpt_name] |
| | |
| | | model=model, |
| | | dataloader_val=dataloader_val, |
| | | epoch=epoch, |
| | | writer=writer |
| | | writer=writer, |
| | | step=batch_idx+1, |
| | | ) |
| | | |
| | | if (batch_idx+1) % self.save_checkpoint_interval == 0: |
| | |
| | | f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, " |
| | | f"(loss_avg_rank: {loss:.3f}), " |
| | | f"(loss_avg_epoch: {loss_avg_epoch:.3f}), " |
| | | f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3f}), " |
| | | f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), " |
| | | f"(acc_avg_epoch: {acc_avg_epoch:.3f}), " |
| | | f"(lr: {lr:.3e}), " |
| | | f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, " |