| | |
| | | # self.step_or_epoch += 1 |
| | | state = { |
| | | "epoch": epoch, |
| | | 'step': step, |
| | | 'total_step': self.batch_total, |
| | | "step": step, |
| | | "total_step": self.batch_total, |
| | | "state_dict": model.state_dict(), |
| | | "optimizer": optim.state_dict(), |
| | | "scheduler": scheduler.state_dict(), |
| | |
| | | "val_loss_step_or_epoch": self.val_loss_step_or_epoch, |
| | | "best_step_or_epoch": self.best_step_or_epoch, |
| | | "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type, |
| | | "step": step, |
| | | "step_in_epoch": step_in_epoch, |
| | | "data_split_i": kwargs.get("data_split_i", 0), |
| | | "data_split_num": kwargs.get("data_split_num", 1), |
| | |
| | | ckpt_name = f"model.pt.ep{epoch}.{step}" |
| | | filename = os.path.join(self.output_dir, ckpt_name) |
| | | torch.save(state, filename) |
| | | logging.info(f'Checkpoint saved to {filename}') |
| | | logging.info(f"Checkpoint saved to {filename}") |
| | | |
| | | latest = Path(os.path.join(self.output_dir, f'model.pt')) |
| | | latest = Path(os.path.join(self.output_dir, f"model.pt")) |
| | | torch.save(state, latest) |
| | | |
| | | if self.best_step_or_epoch == "": |
| | |
| | | |
| | | if self.use_ddp or self.use_fsdp: |
| | | dist.barrier() |
| | | |
| | | |
| | | def train_epoch( |
| | | self, |
| | |
| | | time4 = time.perf_counter() |
| | | |
| | | if torch.isfinite(loss): |
| | | self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / ( |
| | | batch_idx + 1 |
| | | ) |
| | | self.val_loss_avg = ( |
| | | self.val_loss_avg * batch_idx + loss.detach().cpu().item() |
| | | ) / (batch_idx + 1) |
| | | |
| | | if "acc" in stats: |
| | | self.val_acc_avg = ( |