| | |
| | | self.scaler.load_state_dict(checkpoint['scaler_state']) |
| | | print(f"Checkpoint loaded successfully from '{ckpt}'") |
| | | else: |
| | | print(f"No checkpoint found at '{ckpt}', starting from scratch") |
| | | print(f"No checkpoint found at '{ckpt}', does not resume status!") |
| | | |
| | | if self.use_ddp or self.use_fsdp: |
| | | dist.barrier() |
| | |
| | | epoch * len(self.dataloader_val) + batch_idx) |
| | | for key, var in speed_stats.items(): |
| | | self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var), |
| | | epoch * len(self.dataloader_val) + batch_idx) |
| | | epoch * len(self.dataloader_val) + batch_idx) |
| | | |
| | | self.model.train() |