From 2ac79cd3f312e485f3fc4f0e63313cc8a3e0bfc6 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 12 六月 2024 19:27:35 +0800
Subject: [PATCH] decoding

---
 funasr/train_utils/trainer_ds.py |   47 ++++++++++++++++++++++++++++++++++++++++++++---
 1 files changed, 44 insertions(+), 3 deletions(-)

diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 1a553f8..ba8dd16 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -147,6 +147,16 @@
 
         self.use_deepspeed = use_deepspeed
         self.deepspeed_config = kwargs.get("deepspeed_config", "")
+        excludes = kwargs.get("excludes", None)
+        if excludes is not None:
+            if isinstance(excludes, str):
+                excludes = excludes.split(",")
+        self.excludes = excludes
+        effective_save_name_excludes = kwargs.get("effective_save_name_excludes", None)
+        if effective_save_name_excludes is not None:
+            if isinstance(effective_save_name_excludes, str):
+                effective_save_name_excludes = effective_save_name_excludes.split(",")
+        self.effective_save_name_excludes = effective_save_name_excludes
 
     def save_checkpoint(
         self,
@@ -167,6 +177,8 @@
         Args:
             epoch (int): The epoch number at which the checkpoint is being saved.
         """
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
         step_in_epoch = None if step is None else step_in_epoch
         if self.use_deepspeed:
 
@@ -279,7 +291,6 @@
             # self.step_or_epoch += 1
             state = {
                 "epoch": epoch,
-                "state_dict": model.state_dict(),
                 "optimizer": optim.state_dict(),
                 "scheduler": scheduler.state_dict(),
                 "saved_ckpts": self.saved_ckpts,
@@ -297,7 +308,24 @@
             }
             step = step_in_epoch
             if hasattr(model, "module"):
-                state["state_dict"] = model.module.state_dict()
+                state_dict = model.module.state_dict()
+            else:
+                state_dict = model.state_dict()
+
+            if self.effective_save_name_excludes is not None:
+                logging.info(f"effective_save_name_excludes: {self.effective_save_name_excludes}")
+                dst_state_dict = {}
+                for k in state_dict.keys():
+                    for k_ex in self.effective_save_name_excludes:
+                        k_tmp = k.replace("module.", "")
+                        if k.startswith(k_ex):
+                            logging.info(f"key: {k} matching: {k_ex}, not save it")
+                            break
+                    else:
+                        dst_state_dict[k] = state_dict[k]
+                state["state_dict"] = dst_state_dict
+            else:
+                state["state_dict"] = state_dict
 
             if scaler:
                 state["scaler_state"] = scaler.state_dict()
@@ -438,6 +466,16 @@
                     src_state = checkpoint["state_dict"]
                     dst_state = model.state_dict()
                     for k in dst_state.keys():
+                        excludes_flag = False
+                        if self.excludes is not None:
+                            for k_ex in self.excludes:
+                                k_tmp = k.replace("module.", "")
+                                if k_tmp.startswith(k_ex):
+                                    logging.info(f"key: {{k}} matching: {k_ex}, excluded")
+                                    excludes_flag = True
+                                    break
+                        if excludes_flag:
+                            continue
                         if not k.startswith("module.") and "module." + k in src_state.keys():
                             k_ddp = "module." + k
                         elif k.startswith("module.") and "module." + k not in src_state.keys():
@@ -621,7 +659,6 @@
             self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
 
     def forward_step(self, model, batch, loss_dict={}):
-        dtype = torch.bfloat16
         with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
             retval = model(**batch)
 
@@ -761,6 +798,10 @@
             ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
         self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
         self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+
+        if self.use_ddp or self.use_fsdp or self.use_deepspeed:
+            dist.barrier()
+
         model.train()
 
     def log(

--
Gitblit v1.9.1