From fc68b5ffe453235294a561737d8e84bb6c1689a4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 25 四月 2024 21:43:47 +0800
Subject: [PATCH] Dev gzf exp (#1661)

---
 funasr/train_utils/trainer.py |   55 ++++++++++++++++++++++++++++++++++++++++++++++---------
 1 files changed, 46 insertions(+), 9 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 8f20ba4..66f8778 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -15,6 +15,11 @@
 from funasr.train_utils.average_nbest_models import average_checkpoints
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 
+try:
+    import wandb
+except:
+    wandb = None
+
 
 @contextmanager
 def maybe_autocast(enabled):
@@ -107,9 +112,22 @@
         self.best_step_or_epoch = ""
         self.val_acc_step_or_eoch = {}
         self.val_loss_step_or_eoch = {}
-        
-        self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
 
+        self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
+        self.start_data_split_i = 0
+        self.start_step = 0
+        self.use_wandb = kwargs.get("use_wandb", False)
+        if self.use_wandb:
+            wandb.login(key=kwargs.get("wandb_token"))
+            wandb.init(
+                config=kwargs,
+                project=kwargs.get("wandb_project", "my_project"),
+                entity=kwargs.get("wandb_team", "my_team"),
+                name=kwargs.get("wandb_exp_name", "my_exp"),
+                dir=output_dir,
+                job_type="training",
+                reinit=True,
+            )
 
     def save_checkpoint(
         self,
@@ -142,6 +160,7 @@
                 "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
                 "best_step_or_epoch": self.best_step_or_epoch,
                 "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
+                "step": step,
             }
             if hasattr(model, "module"):
                 state["state_dict"] = model.module.state_dict()
@@ -268,6 +287,13 @@
                 self.best_step_or_epoch = (
                     checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
                 )
+                self.start_data_split_i = (
+                    checkpoint["start_data_split_i"] if "start_data_split_i" in checkpoint else 0
+                )
+                self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
+                self.start_step = checkpoint["step"] if "step" in checkpoint else 0
+                self.start_step = 0 if self.start_step is None else self.start_step
+
                 model.to(self.device)
                 print(f"Checkpoint loaded successfully from '{ckpt}'")
             else:
@@ -327,7 +353,7 @@
                 time2 = time.perf_counter()
                 with maybe_autocast(self.use_fp16):
                     retval = model(**batch)
-                    
+
                     if (
                         self.reset_gpu_cache
                         and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
@@ -598,7 +624,7 @@
             acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
             description = (
                 f"{tag}, "
-                f"rank: {self.local_rank}, "
+                f"rank: {self.rank}, "
                 f"epoch: {epoch}/{self.max_epoch}, "
                 f"data_slice: {data_split_i}/{data_split_num}, "
                 f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
@@ -613,18 +639,29 @@
             )
             logging.info(description)
 
+            description_dict = {
+                f"rank{self.rank}_loss/{tag}": loss,
+                f"rank{self.rank}_lr/{tag}": lr,
+            }
+
             if writer is not None:
-                writer.add_scalar(f"rank{self.local_rank}_loss/{tag}", loss, self.batch_total)
-                writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
-                writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
+                writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
+                writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
                 for key, var in stats.items():
                     writer.add_scalar(
-                        f"stats_rank{self.local_rank}_{key}/{tag}", var.item(), self.batch_total
+                        f"stats_rank{self.rank}_{key}/{tag}", var.item(), self.batch_total
                     )
+                    description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
                 for key, var in speed_stats.items():
                     writer.add_scalar(
-                        f"stats_rank{self.local_rank}_{key}/{tag}", eval(var), self.batch_total
+                        f"stats_rank{self.rank}_{key}/{tag}", eval(var), self.batch_total
                     )
+                    description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
+            if self.use_wandb and wandb is not None:
+                wandb.log(
+                    description_dict,
+                    setp=self.batch_total,
+                )
 
     def close(self, writer=None):
 

--
Gitblit v1.9.1