From e6fe602db3eb1209543e55f1aafa2932dfda3310 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 10 一月 2025 10:14:30 +0800
Subject: [PATCH] step_or_epoch bugfix
---
funasr/train_utils/trainer_ds.py | 76 +++++++++++++++++++-------------------
1 files changed, 38 insertions(+), 38 deletions(-)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 85513a5..0b104da 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -122,8 +122,8 @@
self.saved_ckpts = {}
self.step_or_epoch = -1
self.best_step_or_epoch = ""
- self.val_acc_step_or_eoch = {}
- self.val_loss_step_or_eoch = {}
+ self.val_acc_step_or_epoch = {}
+ self.val_loss_step_or_epoch = {}
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0
@@ -195,8 +195,8 @@
# "optimizer": optim.state_dict(),
# "scheduler": scheduler.state_dict(),
"saved_ckpts": self.saved_ckpts,
- "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
- "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+ "val_acc_step_or_epoch": self.val_acc_step_or_epoch,
+ "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
@@ -234,8 +234,8 @@
if self.avg_keep_nbest_models_type == "acc":
if (
- self.val_acc_step_or_eoch[ckpt_name]
- >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
+ self.val_acc_step_or_epoch[ckpt_name]
+ >= self.val_acc_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
@@ -245,16 +245,16 @@
save_dir=self.output_dir, tag=f"model.pt.best", client_state=state
)
logging.info(
- f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
- f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
elif self.avg_keep_nbest_models_type == "loss":
if (
- self.val_loss_step_or_eoch[ckpt_name]
- <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
+ self.val_loss_step_or_epoch[ckpt_name]
+ <= self.val_loss_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
@@ -264,16 +264,16 @@
save_dir=self.output_dir, tag=f"model.pt.best", client_state=state
)
logging.info(
- f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
- f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
else:
print("Undo")
self.saved_ckpts[ckpt_name] = getattr(
- self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
+ self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
)[ckpt_name]
if self.keep_nbest_models > 0:
if len(self.saved_ckpts) > self.keep_nbest_models:
@@ -301,8 +301,8 @@
"optimizer": optim.state_dict(),
"scheduler": scheduler.state_dict(),
"saved_ckpts": self.saved_ckpts,
- "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
- "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+ "val_acc_step_or_epoch": self.val_acc_step_or_epoch,
+ "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
@@ -353,38 +353,38 @@
if self.avg_keep_nbest_models_type == "acc":
if (
- self.val_acc_step_or_eoch[ckpt_name]
- >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
+ self.val_acc_step_or_epoch[ckpt_name]
+ >= self.val_acc_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
- f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
- f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
elif self.avg_keep_nbest_models_type == "loss":
if (
- self.val_loss_step_or_eoch[ckpt_name]
- <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
+ self.val_loss_step_or_epoch[ckpt_name]
+ <= self.val_loss_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
- f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
- f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
else:
print("Undo")
self.saved_ckpts[ckpt_name] = getattr(
- self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
+ self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
)[ckpt_name]
if self.keep_nbest_models > 0:
if len(self.saved_ckpts) > self.keep_nbest_models:
@@ -425,14 +425,14 @@
_, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
self.start_epoch = checkpoint["epoch"]
self.saved_ckpts = checkpoint["saved_ckpts"]
- self.val_acc_step_or_eoch = (
- checkpoint["val_acc_step_or_eoch"]
- if "val_acc_step_or_eoch" in checkpoint
+ self.val_acc_step_or_epoch = (
+ checkpoint["val_acc_step_or_epoch"]
+ if "val_acc_step_or_epoch" in checkpoint
else {}
)
- self.val_loss_step_or_eoch = (
- checkpoint["val_loss_step_or_eoch"]
- if "val_loss_step_or_eoch" in checkpoint
+ self.val_loss_step_or_epoch = (
+ checkpoint["val_loss_step_or_epoch"]
+ if "val_loss_step_or_epoch" in checkpoint
else {}
)
self.best_step_or_epoch = (
@@ -501,14 +501,14 @@
scaler.load_state_dict(checkpoint["scaler_state"])
self.saved_ckpts = checkpoint["saved_ckpts"]
- self.val_acc_step_or_eoch = (
- checkpoint["val_acc_step_or_eoch"]
- if "val_acc_step_or_eoch" in checkpoint
+ self.val_acc_step_or_epoch = (
+ checkpoint["val_acc_step_or_epoch"]
+ if "val_acc_step_or_epoch" in checkpoint
else {}
)
- self.val_loss_step_or_eoch = (
- checkpoint["val_loss_step_or_eoch"]
- if "val_loss_step_or_eoch" in checkpoint
+ self.val_loss_step_or_epoch = (
+ checkpoint["val_loss_step_or_epoch"]
+ if "val_loss_step_or_epoch" in checkpoint
else {}
)
self.best_step_or_epoch = (
@@ -803,8 +803,8 @@
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
- self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
- self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+ self.val_acc_step_or_epoch[ckpt_name] = self.val_acc_avg
+ self.val_loss_step_or_epoch[ckpt_name] = self.val_loss_avg
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
dist.barrier()
--
Gitblit v1.9.1