import torch.distributed as dist  # module-level import required for dist.barrier()

for epoch in range(self.start_epoch, self.max_epoch + 1):
    self._train_epoch(epoch)

    # Make sure every rank has finished training before validation starts.
    if self.use_ddp or self.use_fsdp:
        dist.barrier()

    self._validate_epoch(epoch)

    # Synchronize again so checkpointing sees a consistent model state.
    if self.use_ddp or self.use_fsdp:
        dist.barrier()

    # Only rank 0 writes the checkpoint, avoiding concurrent writes to disk.
    if self.rank == 0:
        self._save_checkpoint(epoch)

# Final synchronization after the last epoch, then release the summary writer.
if self.use_ddp or self.use_fsdp:
    dist.barrier()

if self.writer:
    self.writer.close()
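# For context, a minimal sketch of what _save_checkpoint might look like.
# The attribute names self.model and self.output_dir, and the file layout,
# are assumptions for illustration (assumes `import os` and `import torch`);
# the excerpt above only confirms self.optim, self.scheduler, and self.rank.
def _save_checkpoint(self, epoch):
    state = {
        "epoch": epoch,
        "model": self.model.state_dict(),
        "optimizer": self.optim.state_dict(),
        "scheduler": self.scheduler.state_dict(),
    }
    torch.save(state, os.path.join(self.output_dir, f"model.pt.ep{epoch}"))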

def _train_epoch(self, epoch):
    ...
            # Not yet at a gradient-accumulation boundary: skip the optimizer step.
            continue

        # Execute an optimization step (update model parameters).
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        self.optim.step()
        self.scheduler.step()
        # Clear gradients for the next accumulation stage.
        self.optim.zero_grad()

        pbar.update(1)
        if self.local_rank == 0:
            description = (
                f"Train epoch: {epoch}/{self.max_epoch}, "
                f"step {batch_idx}/{len(self.dataloader_train)}, "
                f"{speed_stats}, "
                f"(loss: {loss.detach().cpu().item():.3f}), "
            )
            pbar.set_description(description)

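# A compact sketch of the gradient-accumulation pattern the loop above implies:
# losses are scaled and accumulated over several batches, and the optimizer only
# steps at accumulation boundaries. self.accum_grad and the forward interface
# self.model(**batch)["loss"] are assumed names, not confirmed by the excerpt.
for batch_idx, batch in enumerate(self.dataloader_train):
    loss = self.model(**batch)["loss"] / self.accum_grad  # scale so gradients average
    loss.backward()  # gradients accumulate across the scaled losses
    if (batch_idx + 1) % self.accum_grad != 0:
        continue  # keep accumulating; no parameter update yet
    self.optim.step()
    self.scheduler.step()
    self.optim.zero_grad()
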
        # Inside _validate_epoch: progress reporting mirrors the training loop.
        pbar.update(1)
        if self.local_rank == 0:
            description = (
                f"Validation epoch: {epoch}/{self.max_epoch}, "
                # self.dataloader_val is an assumed name for the validation loader.
                f"step {batch_idx}/{len(self.dataloader_val)}, "
                f"{speed_stats}, "
                f"(loss: {loss.detach().cpu().item():.3f}), "
            )
            pbar.set_description(description)