scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
scaler = ShardedGradScaler(enabled=use_fp16) if use_ddp else scaler
self.scaler = scaler
self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
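
# A minimal usage sketch (an assumption, not taken from this codebase): how the
# scaler configured above is typically driven in an AMP training step, following
# the standard torch.cuda.amp recipe. `model`, `optim`, and `batch` are placeholder
# names, and the model is assumed to return a dict with a "loss" entry.
with torch.cuda.amp.autocast(enabled=use_fp16):
    loss = model(**batch)["loss"]
if scaler is not None:
    scaler.scale(loss).backward()   # scale the loss so fp16 gradients do not underflow
    scaler.step(optim)              # unscale gradients; the step is skipped if inf/nan is found
    scaler.update()                 # adjust the loss-scale factor for the next iteration
else:
    loss.backward()
    optim.step()
optim.zero_grad()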

try:

self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None


def _save_checkpoint(self, epoch, step=None):
| | | """ |
| | | Saves a checkpoint containing the model's state, the optimizer's state, |
| | | and the scheduler's state at the end of the given epoch. This method is |
| | |
| | | state["scaler_state"] = self.scaler.state_dict() |
    # Create output directory if it does not exist
    os.makedirs(self.output_dir, exist_ok=True)
    if step is None:
        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
    else:
        filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')

    torch.save(state, filename)

    print(f'\nCheckpoint saved to {filename}\n')
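
# A minimal sketch of restoring from a checkpoint written by _save_checkpoint.
# Only the "scaler_state" key is confirmed by the code above; "model_state",
# "optimizer_state", and "scheduler_state" follow the docstring and are assumed
# key names. The path and the `model`, `optimizer`, `scheduler`, and `scaler`
# objects are placeholders.
checkpoint = torch.load("exp/model.pt.ep1", map_location="cpu")
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
scheduler.load_state_dict(checkpoint["scheduler_state"])
if scaler is not None and "scaler_state" in checkpoint:
    scaler.load_state_dict(checkpoint["scaler_state"])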
for key, var in speed_stats.items():
    # speed_stats values are stored as strings, so eval() converts them back to numbers
    self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var), self.batch_total)

# Save an intermediate checkpoint every save_checkpoint_interval batches (rank 0 only)
if (batch_idx + 1) % self.save_checkpoint_interval == 0 and self.rank == 0:
    self._save_checkpoint(epoch, step=batch_idx + 1)
pbar.close()

def _validate_epoch(self, epoch):
    """