From 2196844d1d6e5b8732c95896bb46f0eacdd9cf9d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 25 Sep 2024 15:10:50 +0800
Subject: [PATCH] Dev kws (#2105)
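
Rename the misspelled trainer attributes val_acc_step_or_eoch and
val_loss_step_or_eoch (and the matching checkpoint keys) to
val_acc_step_or_epoch and val_loss_step_or_epoch, record "step" and
"total_step" in the saved checkpoint state, and only fold finite losses
into the running validation averages, so a NaN/Inf batch no longer
corrupts val_loss_avg, val_acc_avg, or the DDP/FSDP all_reduce that
follows them.

Note that checkpoints written before this change store the dicts under
the old misspelled keys, so the resume path falls back to empty dicts
and best-model tracking restarts. A minimal migration sketch
(hypothetical helper, not part of this patch):

    import torch

    def migrate_checkpoint(path):
        """Rename the misspelled keys so resume keeps best-model history."""
        ckpt = torch.load(path, map_location="cpu")
        for old, new in [
            ("val_acc_step_or_eoch", "val_acc_step_or_epoch"),
            ("val_loss_step_or_eoch", "val_loss_step_or_epoch"),
        ]:
            if old in ckpt and new not in ckpt:
                ckpt[new] = ckpt.pop(old)
        torch.save(ckpt, path)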
---
funasr/train_utils/trainer.py | 93 ++++++++++++++++++++++++++--------------------
1 file changed, 53 insertions(+), 40 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 665a7af..5fe34b9 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -115,8 +115,8 @@
self.saved_ckpts = {}
self.step_or_epoch = -1
self.best_step_or_epoch = ""
- self.val_acc_step_or_eoch = {}
- self.val_loss_step_or_eoch = {}
+ self.val_acc_step_or_epoch = {}
+ self.val_loss_step_or_epoch = {}
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0
@@ -161,12 +161,14 @@
# self.step_or_epoch += 1
state = {
"epoch": epoch,
+ "step": step,
+ "total_step": self.batch_total,
"state_dict": model.state_dict(),
"optimizer": optim.state_dict(),
"scheduler": scheduler.state_dict(),
"saved_ckpts": self.saved_ckpts,
- "val_acc_step_or_eoch": self.val_acc_step_or_eoch,
- "val_loss_step_or_eoch": self.val_loss_step_or_eoch,
+ "val_acc_step_or_epoch": self.val_acc_step_or_epoch,
+ "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
@@ -183,6 +185,7 @@
if scaler:
state["scaler_state"] = scaler.state_dict()
+
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
if step is None:
@@ -191,47 +194,48 @@
ckpt_name = f"model.pt.ep{epoch}.{step}"
filename = os.path.join(self.output_dir, ckpt_name)
torch.save(state, filename)
- logging.info(f"\nCheckpoint saved to {filename}\n")
- latest = Path(os.path.join(self.output_dir, f"model.pt"))
+ logging.info(f"Checkpoint saved to {filename}")
+ latest = Path(os.path.join(self.output_dir, "model.pt"))
torch.save(state, latest)
+
if self.best_step_or_epoch == "":
self.best_step_or_epoch = ckpt_name
if self.avg_keep_nbest_models_type == "acc":
if (
- self.val_acc_step_or_eoch[ckpt_name]
- >= self.val_acc_step_or_eoch[self.best_step_or_epoch]
+ self.val_acc_step_or_epoch[ckpt_name]
+ >= self.val_acc_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
- f"Update best acc: {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ f"Update best acc: {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
- f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ f"No improvement in acc: {self.val_acc_step_or_epoch[ckpt_name]:.4f} < {self.val_acc_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
elif self.avg_keep_nbest_models_type == "loss":
if (
- self.val_loss_step_or_eoch[ckpt_name]
- <= self.val_loss_step_or_eoch[self.best_step_or_epoch]
+ self.val_loss_step_or_epoch[ckpt_name]
+ <= self.val_loss_step_or_epoch[self.best_step_or_epoch]
):
self.best_step_or_epoch = ckpt_name
best_ckpt = Path(os.path.join(self.output_dir, f"model.pt.best"))
torch.save(state, best_ckpt)
logging.info(
- f"Update best loss: {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
+ f"Update best loss: {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {best_ckpt}"
)
else:
logging.info(
- f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
+ f"No improvement in loss: {self.val_loss_step_or_epoch[ckpt_name]:.4f} > {self.val_loss_step_or_epoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
else:
print("Undo")
self.saved_ckpts[ckpt_name] = getattr(
- self, f"val_{self.avg_keep_nbest_models_type}_step_or_eoch"
+ self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch"
)[ckpt_name]
if self.keep_nbest_models > 0:
if len(self.saved_ckpts) > self.keep_nbest_models:
@@ -278,6 +282,7 @@
k_ddp = k.replace("module.", "", 1)
else:
k_ddp = k
+
if k_ddp in src_state.keys():
dst_state[k] = src_state[k_ddp]
else:
@@ -290,14 +295,14 @@
scaler.load_state_dict(checkpoint["scaler_state"])
self.saved_ckpts = checkpoint["saved_ckpts"]
- self.val_acc_step_or_eoch = (
- checkpoint["val_acc_step_or_eoch"]
- if "val_acc_step_or_eoch" in checkpoint
+ self.val_acc_step_or_epoch = (
+ checkpoint["val_acc_step_or_epoch"]
+ if "val_acc_step_or_epoch" in checkpoint
else {}
)
- self.val_loss_step_or_eoch = (
- checkpoint["val_loss_step_or_eoch"]
- if "val_loss_step_or_eoch" in checkpoint
+ self.val_loss_step_or_epoch = (
+ checkpoint["val_loss_step_or_epoch"]
+ if "val_loss_step_or_epoch" in checkpoint
else {}
)
self.best_step_or_epoch = (
@@ -327,6 +332,7 @@
if self.use_ddp or self.use_fsdp:
dist.barrier()
+
def train_epoch(
self,
@@ -559,12 +565,14 @@
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1 - time5:0.3f}"
batch = to_device(batch, self.device)
+
time2 = time.perf_counter()
retval = model(**batch)
time3 = time.perf_counter()
speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
+
if self.use_ddp or self.use_fsdp:
# Apply weighted averaging for loss and stats
loss = (loss * weight.type(loss.dtype)).sum()
@@ -577,28 +585,33 @@
# Multiply world_size because DistributedDataParallel
# automatically normalizes the gradient by world_size.
loss *= self.world_size
+
# Scale the loss since we're not updating for every mini-batch
loss = loss
time4 = time.perf_counter()
- self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
- batch_idx + 1
- )
- if "acc" in stats:
- self.val_acc_avg = (
- self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
- ) / (batch_idx + 1)
- if self.use_ddp or self.use_fsdp:
- val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
- self.device
+ if torch.isfinite(loss):
+ self.val_loss_avg = (self.val_loss_avg * batch_idx + loss.detach().cpu().item()) / (
+ batch_idx + 1
)
- val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
- self.device
- )
- dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
- dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
- self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
- self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+
+ if "acc" in stats:
+ self.val_acc_avg = (
+ self.val_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
+ ) / (batch_idx + 1)
+
+ if self.use_ddp or self.use_fsdp:
+ val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(
+ self.device
+ )
+ val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(
+ self.device
+ )
+ dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
+ self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
+ self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+
time5 = time.perf_counter()
batch_num_epoch = 1
if hasattr(dataloader_val, "__len__"):
@@ -624,8 +637,8 @@
ckpt_name = f"model.pt.ep{epoch}"
else:
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
- self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
- self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+ self.val_acc_step_or_epoch[ckpt_name] = self.val_acc_avg
+ self.val_loss_step_or_epoch[ckpt_name] = self.val_loss_avg
model.train()
if self.use_ddp or self.use_fsdp:
--
Gitblit v1.9.1