游雁
2023-12-06 e98e10639d90c55a4b7e498d0d87837ad9c4173d
funasr/cli/trainer.py
@@ -131,7 +131,7 @@
      for batch_idx, batch in enumerate(self.dataloader_train):
         batch = to_device(batch, self.device)
         
-         my_context = model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
+         my_context = self.model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
         with my_context():
            retval = self.model(**batch)
            loss, stats, weight = retval
@@ -163,9 +163,10 @@
            self.optim.zero_grad()
         
         pbar.update(1)
-         pbar.set_description(
-            f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
+         if self.local_rank == 0:
+            pbar.set_description(
+               f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")
      pbar.close()
   
   # def _train_epoch(self, epoch):