From 6a74eb706c1f26023608b06013f6ae38bc569d6f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 12 Jun 2024 15:45:03 +0800
Subject: [PATCH] decoding

---
 funasr/train_utils/trainer_ds.py            |   43 ++++++++++++++++++++++++++++++++++---------
 funasr/train_utils/load_pretrained_model.py |    6 +++++-
 2 files changed, 39 insertions(+), 10 deletions(-)

diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index d633d81..cb2d87e 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -56,11 +56,15 @@
     logging.info(f"excludes: {excludes}")
 
     for k in dst_state.keys():
+        excludes_flag = False
         if excludes is not None:
             for k_ex in excludes:
                 if k.startswith(k_ex):
                     logging.info(f"key: {{k}} matching: {k_ex}, excluded")
-                    continue
+                    excludes_flag = True
+                    break
+        if excludes_flag:
+            continue
 
         k_src = k
 
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index c973728..b2d7b95 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -147,10 +147,16 @@
 
         self.use_deepspeed = use_deepspeed
         self.deepspeed_config = kwargs.get("deepspeed_config", "")
-        self.excludes = kwargs.get("excludes", None)
-        if self.excludes is not None:
-            if isinstance(self.excludes, str):
-                self.excludes = self.excludes.split(",")
+        excludes = kwargs.get("excludes", None)
+        if excludes is not None:
+            if isinstance(excludes, str):
+                excludes = excludes.split(",")
+        self.excludes = excludes
+        effective_save_name_excludes = kwargs.get("effective_save_name_excludes", None)
+        if effective_save_name_excludes is not None:
+            if isinstance(effective_save_name_excludes, str):
+                effective_save_name_excludes = effective_save_name_excludes.split(",")
+        self.effective_save_name_excludes = effective_save_name_excludes
 
     def save_checkpoint(
         self,
@@ -285,7 +291,6 @@
             # self.step_or_epoch += 1
             state = {
                 "epoch": epoch,
-                "state_dict": model.state_dict(),
                 "optimizer": optim.state_dict(),
                 "scheduler": scheduler.state_dict(),
                 "saved_ckpts": self.saved_ckpts,
@@ -303,7 +308,23 @@
             }
             step = step_in_epoch
             if hasattr(model, "module"):
-                state["state_dict"] = model.module.state_dict()
+                state_dict = model.module.state_dict()
+            else:
+                state_dict = model.state_dict()
+
+            if self.effective_save_name_excludes is not None:  # save only the entries whose name matches no exclude prefix
+                dst_state_dict = {}
+                for k in state_dict.keys():
+                    k_tmp = k.replace("module.", "")  # compare with "module." removed, same as the exclude check used when loading
+                    for k_ex in self.effective_save_name_excludes:
+                        if k_tmp.startswith(k_ex):
+                            logging.info(f"key: {k} matching: {k_ex}, not save it")
+                            break
+                    else:  # for/else: reached only when no exclude prefix matched this key
+                        dst_state_dict[k] = state_dict[k]
+                state["state_dict"] = dst_state_dict
+            else:
+                state["state_dict"] = state_dict
 
             if scaler:
                 state["scaler_state"] = scaler.state_dict()
@@ -444,12 +465,16 @@
                     src_state = checkpoint["state_dict"]
                     dst_state = model.state_dict()
                     for k in dst_state.keys():
-                        if excludes is not None:
-                            for k_ex in excludes:
+                        excludes_flag = False
+                        if self.excludes is not None:
+                            for k_ex in self.excludes:
                                 k_tmp = k.replace("module.", "")
                                 if k_tmp.startswith(k_ex):
                                     logging.info(f"key: {{k}} matching: {k_ex}, excluded")
-                                    continue
+                                    excludes_flag = True
+                                    break
+                        if excludes_flag:
+                            continue
                         if not k.startswith("module.") and "module." + k in src_state.keys():
                             k_ddp = "module." + k
                         elif k.startswith("module.") and "module." + k not in src_state.keys():

--
Gitblit v1.9.1