From 93ef505e2d426b6aa1e58c0b4721999de789ff8e Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Sun, 28 Apr 2024 15:14:57 +0800
Subject: [PATCH] Dev gzf exp (#1670)

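Track the training step within the current epoch (step_cur_in_epoch) so
that checkpoints saved mid-epoch can be resumed from the batch they were
written at. The counter is initialized in the trainer, incremented once
per batch, saved in the checkpoint state together with data_split_i,
data_split_num, and batch_total, restored (with a fallback to 0) when a
checkpoint is loaded, reported in per-step training logs, and used to
name intermediate validation checkpoints
(model.pt.ep{epoch}.{step_cur_in_epoch}). Inside save_checkpoint, step
is reassigned to step_cur_in_epoch after the state dict is built, so
later uses of step in that method refer to the in-epoch step. The
epoch-level train and validate log lines now report the global rank
instead of local_rank.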
---
 funasr/train_utils/trainer.py |   31 +++++++++++++++++++++++++++----
 1 file changed, 27 insertions(+), 4 deletions(-)

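Note (illustrative, not part of the commit): a minimal sketch of how a
per-epoch step counter enables mid-epoch resume. The function and
variable names below are assumptions for illustration, not the FunASR
API; only the step_cur_in_epoch checkpoint field and its
missing-or-None -> 0 fallback mirror this patch.

    import torch

    def save_checkpoint(path, model, optim, epoch, step_cur_in_epoch):
        # Persist the in-epoch step next to the usual training state,
        # as the patch does for the trainer's checkpoint `state` dict.
        torch.save(
            {
                "state_dict": model.state_dict(),
                "optimizer": optim.state_dict(),
                "epoch": epoch,
                "step_cur_in_epoch": step_cur_in_epoch,
            },
            path,
        )

    def resume(path, model, optim):
        ckpt = torch.load(path, map_location="cpu")
        model.load_state_dict(ckpt["state_dict"])
        optim.load_state_dict(ckpt["optimizer"])
        # Fall back to 0 when the field is absent or None, matching the patch.
        return ckpt["epoch"], ckpt.get("step_cur_in_epoch") or 0

    def train_one_epoch(model, optim, dataloader, start_step_in_epoch=0):
        for step_in_epoch, batch in enumerate(dataloader, start=1):
            if step_in_epoch <= start_step_in_epoch:
                # Skip batches already consumed before the restart; this
                # is what the per-epoch counter makes possible.
                continue
            ...  # forward / backward / optimizer step
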
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 66f8778..5685b8f 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -116,6 +116,7 @@
         self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
         self.start_data_split_i = 0
         self.start_step = 0
+        self.step_cur_in_epoch = 0
         self.use_wandb = kwargs.get("use_wandb", False)
         if self.use_wandb:
             wandb.login(key=kwargs.get("wandb_token"))
@@ -137,6 +138,8 @@
         optim=None,
         scheduler=None,
         scaler=None,
+        step_cur_in_epoch=None,
+        **kwargs,
     ):
         """
         Saves a checkpoint containing the model's state, the optimizer's state,
@@ -147,6 +150,7 @@
             epoch (int): The epoch number at which the checkpoint is being saved.
         """
 
+        step_cur_in_epoch = None if step is None else step_cur_in_epoch
         if self.rank == 0:
             logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
             # self.step_or_epoch += 1
@@ -161,7 +165,12 @@
                 "best_step_or_epoch": self.best_step_or_epoch,
                 "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
                 "step": step,
+                "step_cur_in_epoch": step_cur_in_epoch,
+                "data_split_i": kwargs.get("data_split_i", 0),
+                "data_split_num": kwargs.get("data_split_num", 1),
+                "batch_total": self.batch_total,
             }
+            step = step_cur_in_epoch
             if hasattr(model, "module"):
                 state["state_dict"] = model.module.state_dict()
 
@@ -293,6 +302,12 @@
                 self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
                 self.start_step = checkpoint["step"] if "step" in checkpoint else 0
                 self.start_step = 0 if self.start_step is None else self.start_step
+                self.step_cur_in_epoch = (
+                    checkpoint["step_cur_in_epoch"] if "step_cur_in_epoch" in checkpoint else 0
+                )
+                self.step_cur_in_epoch = (
+                    0 if self.step_cur_in_epoch is None else self.step_cur_in_epoch
+                )
 
                 model.to(self.device)
                 print(f"Checkpoint loaded successfully from '{ckpt}'")
@@ -321,7 +336,7 @@
         """
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        logging.info(f"Train epoch: {epoch}, rank: {self.local_rank}\n")
+        logging.info(f"Train epoch: {epoch}, rank: {self.rank}\n")
         model.train()
 
         # Set the number of steps for gradient accumulation
@@ -341,6 +356,7 @@
                 if iterator_stop > 0:
                     break
             self.batch_total += 1
+            self.step_cur_in_epoch += 1
             time1 = time.perf_counter()
             speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
 
@@ -443,6 +459,7 @@
                 self.log(
                     epoch,
                     batch_idx,
+                    step_cur_in_epoch=self.step_cur_in_epoch,
                     batch_num_epoch=batch_num_epoch,
                     lr=lr,
                     loss=loss.detach().cpu().item(),
@@ -461,6 +478,7 @@
                     epoch=epoch,
                     writer=writer,
                     step=batch_idx + 1,
+                    step_cur_in_epoch=self.step_cur_in_epoch,
                 )
 
             if (batch_idx + 1) % self.save_checkpoint_interval == 0:
@@ -471,6 +489,9 @@
                     scheduler=scheduler,
                     scaler=scaler,
                     step=batch_idx + 1,
+                    step_cur_in_epoch=self.step_cur_in_epoch,
+                    data_split_i=kwargs.get("data_split_i", 0),
+                    data_split_num=kwargs.get("data_split_num", 1),
                 )
 
             time_beg = time.perf_counter()
@@ -500,7 +521,7 @@
         """
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        logging.info(f"Validate epoch: {epoch}, rank: {self.local_rank}\n")
+        logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")
         model.eval()
 
         with torch.no_grad():
@@ -578,10 +599,10 @@
                     iterator_stop.fill_(1)
                     dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
 
-        if kwargs.get("step", None) is None:
+        if kwargs.get("step_cur_in_epoch", None) is None:
             ckpt_name = f"model.pt.ep{epoch}"
         else:
-            ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step")}'
+            ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_cur_in_epoch")}'
         self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
         self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
         model.train()
@@ -594,6 +615,7 @@
         self,
         epoch=0,
         batch_idx=0,
+        step_cur_in_epoch=0,
         batch_num_epoch=-1,
         lr=0.0,
         loss=0.0,
@@ -626,6 +648,7 @@
                 f"{tag}, "
                 f"rank: {self.rank}, "
                 f"epoch: {epoch}/{self.max_epoch}, "
+                f"step_cur_in_epoch: {step_cur_in_epoch}, "
                 f"data_slice: {data_split_i}/{data_split_num}, "
                 f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
                 f"(loss_avg_rank: {loss:.3f}), "

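Note on the local_rank -> rank change in the two epoch-level log lines:
in multi-node DDP, rank is the process index across all nodes while
local_rank is the GPU index within one node, so logging rank keeps log
lines unambiguous across machines. A minimal illustration, assuming
torchrun conventions (torchrun sets the LOCAL_RANK environment
variable):

    import os
    import torch.distributed as dist

    dist.init_process_group("nccl")
    rank = dist.get_rank()                      # global: 0 .. world_size - 1
    local_rank = int(os.environ["LOCAL_RANK"])  # per node: 0 .. gpus_per_node - 1
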
--
Gitblit v1.9.1