游雁
2023-12-19 0e622e694e6cb4459955f1e5942a7c53349ce640
funasr/train_utils/trainer.py
@@ -55,7 +55,7 @@
      self.dataloader_val = dataloader_val
      self.output_dir = kwargs.get('output_dir', './')
      self.resume = kwargs.get('resume', True)
      self.start_epoch = 1
      self.start_epoch = 0
      self.max_epoch = kwargs.get('max_epoch', 100)
      self.local_rank = local_rank
      self.use_ddp = use_ddp
@@ -123,7 +123,7 @@
      for epoch in range(self.start_epoch, self.max_epoch + 1):
         self._train_epoch(epoch)
         # self._validate_epoch(epoch)
         if dist.get_rank() == 0:
         if self.rank == 0:
            self._save_checkpoint(epoch)
         self.scheduler.step()
         break
@@ -201,21 +201,22 @@
            speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
   
            speed_stats["total_time"] = total_time
         
         # import pdb;
         # pdb.set_trace()
         pbar.update(1)
         if self.local_rank == 0:
            description = (
               f"Epoch: {epoch + 1}/{self.max_epoch}, "
               f"step {batch_idx}/{len(self.dataloader_train)}, "
               f"{speed_stats}, "
               f"(loss: {loss.detach().float():.3f}), "
               f"(loss: {loss.detach().cpu().item():.3f}), "
               f"{[(k, round(v.cpu().item(), 3)) for k, v in stats.items()]}"
            )
            pbar.set_description(description)
         
         if batch_idx == 2:
            break
         # if batch_idx == 2:
         #    break
      pbar.close()
   def _validate_epoch(self, epoch):