| | |
| | | "data_split_i": kwargs.get("data_split_i", 0), |
| | | "data_split_num": kwargs.get("data_split_num", 1), |
| | | "batch_total": self.batch_total, |
| | | "train_loss_avg": kwargs.get("train_loss_avg", 0), |
| | | "train_acc_avg": kwargs.get("train_acc_avg", 0), |
| | | } |
| | | step = step_in_epoch |
| | | if hasattr(model, "module"): |
| | |
| | | checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0 |
| | | ) |
| | | self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch |
| | | |
| | | self.train_acc_avg = ( |
| | | checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0 |
| | | ) |
| | | self.train_loss_avg = ( |
| | | checkpoint["train_loss_avg"] if "train_loss_avg" in checkpoint else 0 |
| | | ) |
| | | model.to(self.device) |
| | | print(f"Checkpoint loaded successfully from '{ckpt}'") |
| | | else: |
| | |
| | | speed_stats["backward_time"] = f"{time4 - time3:0.3f}" |
| | | |
| | | self.train_loss_avg = ( |
| | | self.train_loss_avg * batch_idx + loss.detach().cpu().item() |
| | | ) / (batch_idx + 1) |
| | | self.train_loss_avg * (self.step_in_epoch - 1) + loss.detach().cpu().item() |
| | | ) / self.step_in_epoch |
| | | if "acc" in stats: |
| | | self.train_acc_avg = ( |
| | | self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item() |
| | | ) / (batch_idx + 1) |
| | | self.train_acc_avg * (self.step_in_epoch - 1) |
| | | + stats["acc"].detach().cpu().item() |
| | | ) / self.step_in_epoch |
| | | if self.use_ddp or self.use_fsdp: |
| | | train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to( |
| | | self.device |
| | |
| | | step_in_epoch=self.step_in_epoch, |
| | | data_split_i=kwargs.get("data_split_i", 0), |
| | | data_split_num=kwargs.get("data_split_num", 1), |
| | | train_loss_avg=self.train_loss_avg, |
| | | train_acc_avg=self.train_acc_avg, |
| | | ) |
| | | |
| | | time_beg = time.perf_counter() |