From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/train_utils/trainer.py | 93 ++++++++++++++++++++++++++--------------------
1 files changed, 52 insertions(+), 41 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 665a7af..3e69985 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -115,8 +115,8 @@
self.saved_ckpts = {}
self.step_or_epoch = -1
self.best_step_or_epoch = ""
- self.val_acc_step_or_eoch = {}
- self.val_loss_step_or_eoch = {}
+ self.val_acc_step_or_epoch = {}
+ self.val_loss_step_or_epoch = {}
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0
@@ -161,15 +161,16 @@
# self.step_or_epoch += 1
state = {
"epoch": epoch,
+ "step": step,
+ "total_step": self.batch_total,
"state_dict": model.state_dict(),
"optimizer": optim.state_dict(),
"scheduler": scheduler.state_dict(),
"saved_ckpts": self.saved_ckpts,
- "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
- "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+ "val_acc_step_or_epoch": self.val_acc_step_or_epoch,
+ "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
- "step": step,
"step_in_epoch": step_in_epoch,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
@@ -183,6 +184,7 @@
if scaler:
state["scaler_state"] = scaler.state_dict()
+
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
if step is None:
@@ -191,47 +193,48 @@
ckpt_name = f"model.pt.ep{epoch}.{step}"
filename = os.path.join(self.output_dir, ckpt_name)
torch.save(state, filename)
+ logging.info(f"Checkpoint saved to {filename}")
- logging.info(f"\nCheckpoint saved to {filename}\n")
latest = Path(os.path.join(self.output_dir, f"model.pt"))
torch.save(state, latest)
+
if self.best_step_or_epoch == "":
self.best_step_or_epoch = ckpt_name
if self.avg_keep_nbest_models_type == "acc":
if (
- self.val_acc_step_or_eoch[ckpt_name]
- >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
+ self.val_acc_step_or_epoch[ckpt_name]
+ >= self.val_acc_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
- f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
- f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
elif self.avg_keep_nbest_models_type == "loss":
if (
- self.val_loss_step_or_eoch[ckpt_name]
- <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
+ self.val_loss_step_or_epoch[ckpt_name]
+ <= self.val_loss_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
- f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
- f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
else:
print("Undo")
self.saved_ckpts[ckpt_name] = getattr(
- self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
+ self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
)[ckpt_name]
if self.keep_nbest_models > 0:
if len(self.saved_ckpts) > self.keep_nbest_models:
@@ -278,6 +281,7 @@
k_ddp = k.replace("module.", "", 1)
else:
k_ddp = k
+
if k_ddp in src_state.keys():
dst_state[k] = src_state[k_ddp]
else:
@@ -290,14 +294,14 @@
scaler.load_state_dict(checkpoint["scaler_state"])
self.saved_ckpts = checkpoint["saved_ckpts"]
- self.val_acc_step_or_eoch = (
- checkpoint["val_acc_step_or_eoch"]
- if "val_acc_step_or_eoch" in checkpoint
+ self.val_acc_step_or_epoch = (
+ checkpoint["val_acc_step_or_epoch"]
+ if "val_acc_step_or_epoch" in checkpoint
else {}
)
- self.val_loss_step_or_eoch = (
- checkpoint["val_loss_step_or_eoch"]
- if "val_loss_step_or_eoch" in checkpoint
+ self.val_loss_step_or_epoch = (
+ checkpoint["val_loss_step_or_epoch"]
+ if "val_loss_step_or_epoch" in checkpoint
else {}
)
self.best_step_or_epoch = (
@@ -559,12 +563,14 @@
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1 - time5:0.3f}"
batch = to_device(batch, self.device)
+
time2 = time.perf_counter()
retval = model(**batch)
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
+
if self.use_ddp or self.use_fsdp:
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
@@ -577,28 +583,33 @@
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
+
# Scale the loss since we're not updating for every mini-batch
loss = loss
time4 = time.perf_counter()
- self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
- batch_idx + 1
- )
- if "acc" in stats:
- self.val_acc_avg = (
- self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
+ if torch.isfinite(loss):
+ self.val_loss_avg = (
+ self.val_loss_avg * batch_idx + loss.detach().cpu().item()
) / (batch_idx + 1)
- if self.use_ddp or self.use_fsdp:
- val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
- self.device
- )
- val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
- self.device
- )
- dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
- dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
- self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
- self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+
+ if "acc" in stats:
+ self.val_acc_avg = (
+ self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
+ ) / (batch_idx + 1)
+
+ if self.use_ddp or self.use_fsdp:
+ val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
+ self.device
+ )
+ val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
+ self.device
+ )
+ dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+ self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+ self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+
time5 = time.perf_counter()
batch_num_epoch = 1
if hasattr(dataloader_val, "__len__"):
@@ -624,8 +635,8 @@
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
- self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
- self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+ self.val_acc_step_or_epoch[ckpt_name] = self.val_acc_avg
+ self.val_loss_step_or_epoch[ckpt_name] = self.val_loss_avg
model.train()
if self.use_ddp or self.use_fsdp:
@@ -704,7 +715,7 @@
if self.use_wandb and wandb is not None:
wandb.log(
description_dict,
- setp=self.batch_total,
+ step=self.batch_total,
)
def close(self, writer=None):
--
Gitblit v1.9.1