From d4f13c2e444f972b272273bce76b05f52f5164aa Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 10 一月 2025 10:16:11 +0800
Subject: [PATCH] step_or_epoch bugfix

---
 funasr/train_utils/trainer.py |   16 +++++++---------
 1 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 5fe34b9..d0be9c8 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -161,8 +161,8 @@
             # self.step_or_epoch += 1
             state = {
                 "epoch": epoch,
-                'step': step,
-                'total_step': self.batch_total,
+                "step": step,
+                "total_step": self.batch_total,
                 "state_dict": model.state_dict(),
                 "optimizer": optim.state_dict(),
                 "scheduler": scheduler.state_dict(),
@@ -171,7 +171,6 @@
                 "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
                 "best_step_or_epoch": self.best_step_or_epoch,
                 "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
-                "step": step,
                 "step_in_epoch": step_in_epoch,
                 "data_split_i": kwargs.get("data_split_i", 0),
                 "data_split_num": kwargs.get("data_split_num", 1),
@@ -194,9 +193,9 @@
                 ckpt_name = f"model.pt.ep{epoch}.{step}"
             filename = os.path.join(self.output_dir, ckpt_name)
             torch.save(state, filename)
-            logging.info(f'Checkpoint saved to {filename}')
+            logging.info(f"Checkpoint saved to {filename}")
 
-            latest = Path(os.path.join(self.output_dir, f'model.pt'))
+            latest = Path(os.path.join(self.output_dir, f"model.pt"))
             torch.save(state, latest)
 
             if self.best_step_or_epoch == "":
@@ -332,7 +331,6 @@
 
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-
 
     def train_epoch(
         self,
@@ -591,9 +589,9 @@
                 time4 = time.perf_counter()
 
                 if torch.isfinite(loss):
-                    self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
-                        batch_idx + 1
-                    )
+                    self.val_loss_avg = (
+                        self.val_loss_avg * batch_idx + loss.detach().cpu().item()
+                    ) / (batch_idx + 1)
 
                     if "acc" in stats:
                         self.val_acc_avg = (

--
Gitblit v1.9.1