游雁
2023-12-06 e98e10639d90c55a4b7e498d0d87837ad9c4173d
funasr/cli/trainer.py
@@ -131,7 +131,7 @@
      for batch_idx, batch in enumerate(self.dataloader_train):
         batch = to_device(batch, self.device)
         
         my_context = model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
         my_context = self.model.no_sync if batch_idx % accumulation_steps != 0 else nullcontext
         with my_context():
            retval = self.model(**batch)
            loss, stats, weight = retval
@@ -163,6 +163,7 @@
            self.optim.zero_grad()
         
         pbar.update(1)
         if self.local_rank == 0:
         pbar.set_description(
            f"Training Epoch: {epoch + 1}/{self.max_epoch}, step {batch_idx}/{len(self.dataloader_train)}  (loss: {loss.detach().float()})")