From 2ac79cd3f312e485f3fc4f0e63313cc8a3e0bfc6 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 12 Jun 2024 19:27:35 +0800
Subject: [PATCH] decoding
---
funasr/train_utils/trainer_ds.py | 46 ++++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 44 insertions(+), 2 deletions(-)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index ec887cc..ba8dd16 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -147,6 +147,16 @@
self.use_deepspeed = use_deepspeed
self.deepspeed_config = kwargs.get("deepspeed_config", "")
+ excludes = kwargs.get("excludes", None)
+ if excludes is not None:
+ if isinstance(excludes, str):
+ excludes = excludes.split(",")
+ self.excludes = excludes
+ effective_save_name_excludes = kwargs.get("effective_save_name_excludes", None)
+ if effective_save_name_excludes is not None:
+ if isinstance(effective_save_name_excludes, str):
+ effective_save_name_excludes = effective_save_name_excludes.split(",")
+ self.effective_save_name_excludes = effective_save_name_excludes
def save_checkpoint(
self,
@@ -167,6 +177,8 @@
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
step_in_epoch = None if step is None else step_in_epoch
if self.use_deepspeed:
@@ -279,7 +291,6 @@
# self.step_or_epoch += 1
state = {
"epoch": epoch,
- "state_dict": model.state_dict(),
"optimizer": optim.state_dict(),
"scheduler": scheduler.state_dict(),
"saved_ckpts": self.saved_ckpts,
@@ -297,7 +308,24 @@
}
step = step_in_epoch
if hasattr(model, "module"):
- state["state_dict"] = model.module.state_dict()
+ state_dict = model.module.state_dict()
+ else:
+ state_dict = model.state_dict()
+
+ if self.effective_save_name_excludes is not None:
+ logging.info(f"effective_save_name_excludes: {self.effective_save_name_excludes}")
+ dst_state_dict = {}
+ for k in state_dict.keys():
+ for k_ex in self.effective_save_name_excludes:
k_tmp = k.replace("module.", "")
if k_tmp.startswith(k_ex):
+ logging.info(f"key: {k} matching: {k_ex}, not save it")
+ break
+ else:
+ dst_state_dict[k] = state_dict[k]
+ state["state_dict"] = dst_state_dict
+ else:
+ state["state_dict"] = state_dict
if scaler:
state["scaler_state"] = scaler.state_dict()
@@ -438,6 +466,16 @@
src_state = checkpoint["state_dict"]
dst_state = model.state_dict()
for k in dst_state.keys():
+ excludes_flag = False
+ if self.excludes is not None:
+ for k_ex in self.excludes:
+ k_tmp = k.replace("module.", "")
+ if k_tmp.startswith(k_ex):
logging.info(f"key: {k} matching: {k_ex}, excluded")
+ excludes_flag = True
+ break
+ if excludes_flag:
+ continue
if not k.startswith("module.") and "module." + k in src_state.keys():
k_ddp = "module." + k
elif k.startswith("module.") and "module." + k not in src_state.keys():
@@ -760,6 +798,10 @@
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+
+ if self.use_ddp or self.use_fsdp or self.use_deepspeed:
+ dist.barrier()
+
model.train()
def log(
--
Gitblit v1.9.1