| | |
| | | self.dataloader_val = dataloader_val |
| | | self.output_dir = kwargs.get('output_dir', './') |
| | | self.resume = kwargs.get('resume', True) |
| | | self.start_epoch = 1 |
| | | self.start_epoch = 0 |
| | | self.max_epoch = kwargs.get('max_epoch', 100) |
| | | self.local_rank = local_rank |
| | | self.use_ddp = use_ddp |
| | |
| | | for epoch in range(self.start_epoch, self.max_epoch + 1): |
| | | self._train_epoch(epoch) |
| | | # self._validate_epoch(epoch) |
| | | if dist.get_rank() == 0: |
| | | if self.rank == 0: |
| | | self._save_checkpoint(epoch) |
| | | self.scheduler.step() |
| | | break |
| | |
| | | speed_stats["optim_time"] = f"{time5 - time4:0.3f}" |
| | | |
| | | speed_stats["total_time"] = total_time |
| | | |
| | | |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | pbar.update(1) |
| | | if self.local_rank == 0: |
| | | description = ( |
| | | f"Epoch: {epoch + 1}/{self.max_epoch}, " |
| | | f"step {batch_idx}/{len(self.dataloader_train)}, " |
| | | f"{speed_stats}, " |
| | | f"(loss: {loss.detach().float():.3f}), " |
| | | f"(loss: {loss.detach().cpu().item():.3f}), " |
| | | f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}" |
| | | ) |
| | | pbar.set_description(description) |
| | | |
| | | if batch_idx == 2: |
| | | break |
| | | # if batch_idx == 2: |
| | | # break |
| | | pbar.close() |
| | | |
| | | def _validate_epoch(self, epoch): |