From 6a74eb706c1f26023608b06013f6ae38bc569d6f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 12 六月 2024 15:45:03 +0800
Subject: [PATCH] decoding
---
funasr/train_utils/trainer_ds.py | 43 ++++++++++++++++++++++++++++++++++---------
1 file changed, 34 insertions(+), 9 deletions(-)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index c973728..b2d7b95 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -147,10 +147,16 @@
self.use_deepspeed = use_deepspeed
self.deepspeed_config = kwargs.get("deepspeed_config", "")
- self.excludes = kwargs.get("excludes", None)
- if self.excludes is not None:
- if isinstance(self.excludes, str):
- self.excludes = self.excludes.split(",")
+ excludes = kwargs.get("excludes", None)
+ if excludes is not None:
+ if isinstance(excludes, str):
+ excludes = excludes.split(",")
+ self.excludes = excludes
+ effective_save_name_excludes = kwargs.get("effective_save_name_excludes", None)
+ if effective_save_name_excludes is not None:
+ if isinstance(effective_save_name_excludes, str):
+ effective_save_name_excludes = effective_save_name_excludes.split(",")
+ self.effective_save_name_excludes = effective_save_name_excludes
def save_checkpoint(
self,
@@ -285,7 +291,6 @@
# self.step_or_epoch += 1
state = {
"epoch": epoch,
- "state_dict": model.state_dict(),
"optimizer": optim.state_dict(),
"scheduler": scheduler.state_dict(),
"saved_ckpts": self.saved_ckpts,
@@ -303,7 +308,23 @@
}
step = step_in_epoch
if hasattr(model, "module"):
- state["state_dict"] = model.module.state_dict()
+ state_dict = model.module.state_dict()
+ else:
+ state_dict = model.state_dict()
+
+ if self.effective_save_name_excludes is not None:
+ dst_state_dict = {}  # filtered copy: only keys that match no exclude prefix are saved
+ for k in state_dict.keys():
+ for k_ex in self.effective_save_name_excludes:
+ k_tmp = k.replace("module.", "")  # same "module." stripping the checkpoint-loading filter applies
+ if k.startswith(k_ex) or k_tmp.startswith(k_ex):  # match the raw key or its stripped form
+ logging.info(f"key: {k} matching: {k_ex}, not save it")
+ break
+ else:
+ dst_state_dict[k] = state_dict[k]
+ state["state_dict"] = dst_state_dict
+ else:
+ state["state_dict"] = state_dict
if scaler:
state["scaler_state"] = scaler.state_dict()
@@ -444,12 +465,16 @@
src_state = checkpoint["state_dict"]
dst_state = model.state_dict()
for k in dst_state.keys():
- if excludes is not None:
- for k_ex in excludes:
+ excludes_flag = False
+ if self.excludes is not None:
+ for k_ex in self.excludes:
k_tmp = k.replace("module.", "")
if k_tmp.startswith(k_ex):
logging.info(f"key: {{k}} matching: {k_ex}, excluded")
- continue
+ excludes_flag = True
+ break
+ if excludes_flag:
+ continue
if not k.startswith("module.") and "module." + k in src_state.keys():
k_ddp = "module." + k
elif k.startswith("module.") and "module." + k not in src_state.keys():
--
Gitblit v1.9.1