From fc68b5ffe453235294a561737d8e84bb6c1689a4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 25 Apr 2024 21:43:47 +0800
Subject: [PATCH] Dev gzf exp (#1661)
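
Add optional Weights & Biases (wandb) logging alongside the existing
TensorBoard writer, save the global step in checkpoints, and restore
start_data_split_i, batch_total and start_step when resuming so a run
can continue from where it stopped. Per-rank scalar tags now use
self.rank instead of self.local_rank, and a duplicated lr add_scalar
call is removed.

A minimal sketch of how the new options might be passed; the Trainer
class name and its constructor signature are assumptions, only the
kwarg names and their defaults come from this patch:

    from funasr.train_utils.trainer import Trainer

    trainer = Trainer(
        output_dir="./exp",             # assumed constructor argument
        use_wandb=True,                 # enable the new wandb reporting
        wandb_token="<wandb-api-key>",  # placeholder credential
        wandb_project="my_project",     # defaults mirror __init__
        wandb_team="my_team",
        wandb_exp_name="my_exp",
    )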
---
funasr/train_utils/trainer.py | 55 ++++++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 46 insertions(+), 9 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 8f20ba4..66f8778 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -15,6 +15,11 @@
from funasr.train_utils.average_nbest_models import average_checkpoints
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
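+# wandb is an optional dependency; when it is missing, wandb stays None
+# and the wandb-specific logging below is skipped.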
+try:
+ import wandb
+except ImportError:
+ wandb = None
+
@contextmanager
def maybe_autocast(enabled):
@@ -107,9 +112,22 @@
self.best_step_or_epoch = ""
self.val_acc_step_or_eoch = {}
self.val_loss_step_or_eoch = {}
-
- self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
+ self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
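+ # resume state; overwritten below if a checkpoint carrying these fields is loaded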
+ self.start_data_split_i = 0
+ self.start_step = 0
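+ # optional Weights & Biases logging; the token and run metadata come from kwargs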
+ self.use_wandb = kwargs.get("use_wandb", False)
+ if self.use_wandb:
+ wandb.login(key=kwargs.get("wandb_token"))
+ wandb.init(
+ config=kwargs,
+ project=kwargs.get("wandb_project", "my_project"),
+ entity=kwargs.get("wandb_team", "my_team"),
+ name=kwargs.get("wandb_exp_name", "my_exp"),
+ dir=output_dir,
+ job_type="training",
+ reinit=True,
+ )
def save_checkpoint(
self,
@@ -142,6 +160,7 @@
"val_loss_step_or_eoch": self.val_loss_step_or_eoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
+ "step": step,
}
if hasattr(model, "module"):
state["state_dict"] = model.module.state_dict()
@@ -268,6 +287,13 @@
self.best_step_or_epoch = (
checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
)
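+ # restore resume state (data split index, total batch count, global step)
+ # written by save_checkpoint; older checkpoints may not contain these keys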
+ self.start_data_split_i = (
+ checkpoint["start_data_split_i"] if "start_data_split_i" in checkpoint else 0
+ )
+ self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
+ self.start_step = checkpoint["step"] if "step" in checkpoint else 0
+ self.start_step = 0 if self.start_step is None else self.start_step
+
model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
@@ -327,7 +353,7 @@
time2 = time.perf_counter()
with maybe_autocast(self.use_fp16):
retval = model(**batch)
-
+
if (
self.reset_gpu_cache
and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
@@ -598,7 +624,7 @@
acc_avg_epoch = getattr(self, f"{tag}_acc_avg")
description = (
f"{tag}, "
- f"rank: {self.local_rank}, "
+ f"rank: {self.rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
f"data_slice: {data_split_i}/{data_split_num}, "
f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
@@ -613,18 +639,29 @@
)
logging.info(description)
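+ # collect the same scalars in a dict so they can also be sent to wandb below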
+ description_dict = {
+ f"rank{self.rank}_loss/{tag}": loss,
+ f"rank{self.rank}_lr/{tag}": lr,
+ }
+
if writer is not None:
- writer.add_scalar(f"rank{self.local_rank}_loss/{tag}", loss, self.batch_total)
- writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
- writer.add_scalar(f"rank{self.local_rank}_lr/{tag}", lr, self.batch_total)
+ writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
+ writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
for key, var in stats.items():
writer.add_scalar(
- f"stats_rank{self.local_rank}_{key}/{tag}", var.item(), self.batch_total
+ f"stats_rank{self.rank}_{key}/{tag}", var.item(), self.batch_total
)
+ description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = var.item()
for key, var in speed_stats.items():
writer.add_scalar(
- f"stats_rank{self.local_rank}_{key}/{tag}", eval(var), self.batch_total
+ f"stats_rank{self.rank}_{key}/{tag}", eval(var), self.batch_total
)
+ description_dict[f"stats_rank{self.rank}_{key}/{tag}"] = eval(var)
+ if self.use_wandb and wandb is not None:
+ wandb.log(
+ description_dict,
+ step=self.batch_total,
+ )
def close(self, writer=None):
--
Gitblit v1.9.1