From 443bc09c11f3cf89ffc573aab2021f0c933aa5b3 Mon Sep 17 00:00:00 2001
From: kmn1024 <kienman@gmail.com>
Date: 星期三, 25 六月 2025 16:34:30 +0800
Subject: [PATCH] Bugfix: Only allow rank==0 to clean up old checkpoints (#2558)

---
 funasr/train_utils/trainer_ds.py |   33 +++++++++++++++++----------------
 1 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 0b104da..a1430db 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -272,22 +272,23 @@
                     )
             else:
                 print("Undo")
-            self.saved_ckpts[ckpt_name] = getattr(
-                self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
-            )[ckpt_name]
-            if self.keep_nbest_models > 0:
-                if len(self.saved_ckpts) > self.keep_nbest_models:
-                    if self.avg_keep_nbest_models_type == "acc":
-                        key = min(self.saved_ckpts, key=self.saved_ckpts.get)
-                    else:
-                        key = max(self.saved_ckpts, key=self.saved_ckpts.get)
-                    if key in self.saved_ckpts:
-                        del self.saved_ckpts[key]
-                    filename = os.path.join(self.output_dir, key)
-                    logging.info(f"Delete: {filename}")
-                    if os.path.exists(filename):
-                        # os.remove(filename)
-                        misc_utils.smart_remove(filename)
+            if self.rank == 0:
+                self.saved_ckpts[ckpt_name] = getattr(
+                    self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
+                )[ckpt_name]
+                if self.keep_nbest_models > 0:
+                    if len(self.saved_ckpts) > self.keep_nbest_models:
+                        if self.avg_keep_nbest_models_type == "acc":
+                            key = min(self.saved_ckpts, key=self.saved_ckpts.get)
+                        else:
+                            key = max(self.saved_ckpts, key=self.saved_ckpts.get)
+                        if key in self.saved_ckpts:
+                            del self.saved_ckpts[key]
+                        filename = os.path.join(self.output_dir, key)
+                        logging.info(f"Delete: {filename}")
+                        if os.path.exists(filename):
+                            # os.remove(filename)
+                            misc_utils.smart_remove(filename)
 
         elif self.use_fsdp:
             pass

--
Gitblit v1.9.1