From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/train_utils/trainer_ds.py |  120 +++++++++++++++++++++++++++++++-----------------------------
 1 files changed, 62 insertions(+), 58 deletions(-)

diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 8a0679c..ce8809c 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -121,8 +121,8 @@
         self.saved_ckpts = {}
         self.step_or_epoch = -1
         self.best_step_or_epoch = ""
-        self.val_acc_step_or_eoch = {}
-        self.val_loss_step_or_eoch = {}
+        self.val_acc_step_or_epoch = {}
+        self.val_loss_step_or_epoch = {}
 
         self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
         self.start_data_split_i = 0
@@ -194,8 +194,8 @@
                 # "optimizer": optim.state_dict(),
                 # "scheduler": scheduler.state_dict(),
                 "saved_ckpts": self.saved_ckpts,
-                "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
-                "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+                "val_acc_step_or_epoch": self.val_acc_step_or_epoch,
+                "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
                 "best_step_or_epoch": self.best_step_or_epoch,
                 "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
                 "step": step,
@@ -233,8 +233,8 @@
 
             if self.avg_keep_nbest_models_type == "acc":
                 if (
-                    self.val_acc_step_or_eoch[ckpt_name]
-                    >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
+                    self.val_acc_step_or_epoch[ckpt_name]
+                    >= self.val_acc_step_or_epoch[self.best_step_or_epoch]
                 ):
                     self.best_step_or_epoch = ckpt_name
                     best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
@@ -244,16 +244,16 @@
                             save_dir=self.output_dir, tag=f"model.pt.best", client_state=state
                         )
                     logging.info(
-                        f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+                        f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
                     )
                 else:
                     logging.info(
-                        f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+                        f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
                     )
             elif self.avg_keep_nbest_models_type == "loss":
                 if (
-                    self.val_loss_step_or_eoch[ckpt_name]
-                    <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
+                    self.val_loss_step_or_epoch[ckpt_name]
+                    <= self.val_loss_step_or_epoch[self.best_step_or_epoch]
                 ):
                     self.best_step_or_epoch = ckpt_name
                     best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
@@ -263,30 +263,31 @@
                             save_dir=self.output_dir, tag=f"model.pt.best", client_state=state
                         )
                     logging.info(
-                        f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+                        f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
                     )
                 else:
                     logging.info(
-                        f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+                        f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
                     )
             else:
                 print("Undo")
-            self.saved_ckpts[ckpt_name] = getattr(
-                self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
-            )[ckpt_name]
-            if self.keep_nbest_models > 0:
-                if len(self.saved_ckpts) > self.keep_nbest_models:
-                    if self.avg_keep_nbest_models_type == "acc":
-                        key = min(self.saved_ckpts, key=self.saved_ckpts.get)
-                    else:
-                        key = max(self.saved_ckpts, key=self.saved_ckpts.get)
-                    if key in self.saved_ckpts:
-                        del self.saved_ckpts[key]
-                    filename = os.path.join(self.output_dir, key)
-                    logging.info(f"Delete: {filename}")
-                    if os.path.exists(filename):
-                        # os.remove(filename)
-                        misc_utils.smart_remove(filename)
+            if self.rank == 0:
+                self.saved_ckpts[ckpt_name] = getattr(
+                    self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
+                )[ckpt_name]
+                if self.keep_nbest_models > 0:
+                    if len(self.saved_ckpts) > self.keep_nbest_models:
+                        if self.avg_keep_nbest_models_type == "acc":
+                            key = min(self.saved_ckpts, key=self.saved_ckpts.get)
+                        else:
+                            key = max(self.saved_ckpts, key=self.saved_ckpts.get)
+                        if key in self.saved_ckpts:
+                            del self.saved_ckpts[key]
+                        filename = os.path.join(self.output_dir, key)
+                        logging.info(f"Delete: {filename}")
+                        if os.path.exists(filename):
+                            # os.remove(filename)
+                            misc_utils.smart_remove(filename)
 
         elif self.use_fsdp:
             pass
@@ -300,8 +301,8 @@
                 "optimizer": optim.state_dict(),
                 "scheduler": scheduler.state_dict(),
                 "saved_ckpts": self.saved_ckpts,
-                "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
-                "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+                "val_acc_step_or_epoch": self.val_acc_step_or_epoch,
+                "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
                 "best_step_or_epoch": self.best_step_or_epoch,
                 "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
                 "step": step,
@@ -352,38 +353,38 @@
 
             if self.avg_keep_nbest_models_type == "acc":
                 if (
-                    self.val_acc_step_or_eoch[ckpt_name]
-                    >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
+                    self.val_acc_step_or_epoch[ckpt_name]
+                    >= self.val_acc_step_or_epoch[self.best_step_or_epoch]
                 ):
                     self.best_step_or_epoch = ckpt_name
                     best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
                     torch.save(state, best_ckpt)
                     logging.info(
-                        f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+                        f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
                     )
                 else:
                     logging.info(
-                        f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+                        f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
                     )
             elif self.avg_keep_nbest_models_type == "loss":
                 if (
-                    self.val_loss_step_or_eoch[ckpt_name]
-                    <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
+                    self.val_loss_step_or_epoch[ckpt_name]
+                    <= self.val_loss_step_or_epoch[self.best_step_or_epoch]
                 ):
                     self.best_step_or_epoch = ckpt_name
                     best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
                     torch.save(state, best_ckpt)
                     logging.info(
-                        f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+                        f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
                     )
                 else:
                     logging.info(
-                        f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+                        f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
                     )
             else:
                 print("Undo")
             self.saved_ckpts[ckpt_name] = getattr(
-                self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
+                self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
             )[ckpt_name]
             if self.keep_nbest_models > 0:
                 if len(self.saved_ckpts) > self.keep_nbest_models:
@@ -424,14 +425,14 @@
                     _, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
                     self.start_epoch = checkpoint["epoch"]
                     self.saved_ckpts = checkpoint["saved_ckpts"]
-                    self.val_acc_step_or_eoch = (
-                        checkpoint["val_acc_step_or_eoch"]
-                        if "val_acc_step_or_eoch" in checkpoint
+                    self.val_acc_step_or_epoch = (
+                        checkpoint["val_acc_step_or_epoch"]
+                        if "val_acc_step_or_epoch" in checkpoint
                         else {}
                     )
-                    self.val_loss_step_or_eoch = (
-                        checkpoint["val_loss_step_or_eoch"]
-                        if "val_loss_step_or_eoch" in checkpoint
+                    self.val_loss_step_or_epoch = (
+                        checkpoint["val_loss_step_or_epoch"]
+                        if "val_loss_step_or_epoch" in checkpoint
                         else {}
                     )
                     self.best_step_or_epoch = (
@@ -477,7 +478,7 @@
                             for k_ex in self.excludes:
                                 k_tmp = k.replace("module.", "")
                                 if k_tmp.startswith(k_ex):
-                                    logging.info(f"key: {{k}} matching: {k_ex}, excluded")
+                                    logging.info(f"key: {k} matching: {k_ex}, excluded")
                                     excludes_flag = True
                                     break
                         if excludes_flag:
@@ -500,14 +501,14 @@
                         scaler.load_state_dict(checkpoint["scaler_state"])
 
                     self.saved_ckpts = checkpoint["saved_ckpts"]
-                    self.val_acc_step_or_eoch = (
-                        checkpoint["val_acc_step_or_eoch"]
-                        if "val_acc_step_or_eoch" in checkpoint
+                    self.val_acc_step_or_epoch = (
+                        checkpoint["val_acc_step_or_epoch"]
+                        if "val_acc_step_or_epoch" in checkpoint
                         else {}
                     )
-                    self.val_loss_step_or_eoch = (
-                        checkpoint["val_loss_step_or_eoch"]
-                        if "val_loss_step_or_eoch" in checkpoint
+                    self.val_loss_step_or_epoch = (
+                        checkpoint["val_loss_step_or_epoch"]
+                        if "val_loss_step_or_epoch" in checkpoint
                         else {}
                     )
                     self.best_step_or_epoch = (
@@ -682,7 +683,7 @@
             scaled_loss = model.backward(loss)
         else:
             loss = loss / self.accum_grad
-            if self.use_fp16 or self.use_bf16:
+            if scaler:
                 scaler.scale(loss).backward()
             else:
                 loss.backward()
@@ -710,7 +711,7 @@
                 # Execute an optimization step (update model parameters)
                 if self.use_ddp or self.use_fsdp:
                     dist.barrier()
-                if self.use_fp16 or self.use_bf16:
+                if scaler:
                     scaler.step(optim)
                     scaler.update()
                 else:
@@ -734,6 +735,9 @@
         Args:
             epoch (int): The current epoch number.
         """
+        self.val_loss_avg = 0.0
+        self.val_acc_avg  = 0.0
+
         if self.use_ddp or self.use_fsdp or self.use_deepspeed:
             dist.barrier()
         logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
@@ -755,7 +759,7 @@
                     "data_split_i": kwargs.get("data_split_i", 0),
                     "data_split_num": kwargs.get("data_split_num", 1),
                     "log_step": batch_idx + kwargs.get("start_step", 0),
-                    "batch_total": batch_idx + 1,
+                    "batch_total": self.batch_total,
                     "step_in_epoch": batch_idx + 1,
                     "lr": 0.0,
                 }
@@ -802,8 +806,8 @@
             ckpt_name = f"model.pt.ep{epoch}"
         else:
             ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
-        self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
-        self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+        self.val_acc_step_or_epoch[ckpt_name] = self.val_acc_avg
+        self.val_loss_step_or_epoch[ckpt_name] = self.val_loss_avg
 
         if self.use_ddp or self.use_fsdp or self.use_deepspeed:
             dist.barrier()
@@ -881,7 +885,7 @@
             if self.use_wandb and wandb is not None:
                 wandb.log(
                     description_dict,
-                    setp=batch_total,
+                    step=batch_total,
                 )
 
     def close(self, writer=None):

--
Gitblit v1.9.1