From b7ae3d52681ef4f5611b059762788af7d6a37190 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 28 四月 2024 17:42:33 +0800
Subject: [PATCH] Dev gzf exp (#1672)
---
funasr/train_utils/trainer.py | 47 ++++++++++++++++++++++-------------------------
funasr/bin/train.py | 8 +++++---
2 files changed, 27 insertions(+), 28 deletions(-)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 2af6a59..97516eb 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -223,11 +223,13 @@
torch.cuda.empty_cache()
trainer.validate_epoch(
- model=model, dataloader_val=dataloader_val, epoch=epoch, writer=writer
+ model=model, dataloader_val=dataloader_val, epoch=epoch + 1, writer=writer
)
scheduler.step()
- trainer.step_cur_in_epoch = 0
- trainer.save_checkpoint(epoch, model=model, optim=optim, scheduler=scheduler, scaler=scaler)
+ trainer.step_in_epoch = 0
+ trainer.save_checkpoint(
+ epoch + 1, model=model, optim=optim, scheduler=scheduler, scaler=scaler
+ )
time2 = time.perf_counter()
time_escaped = (time2 - time1) / 3600.0
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 5685b8f..e86420c 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -116,7 +116,7 @@
self.reset_gpu_cache = kwargs.get("reset_gpu_cache", False)
self.start_data_split_i = 0
self.start_step = 0
- self.step_cur_in_epoch = 0
+ self.step_in_epoch = 0
self.use_wandb = kwargs.get("use_wandb", False)
if self.use_wandb:
wandb.login(key=kwargs.get("wandb_token"))
@@ -138,7 +138,7 @@
optim=None,
scheduler=None,
scaler=None,
- step_cur_in_epoch=None,
+ step_in_epoch=None,
**kwargs,
):
"""
@@ -150,7 +150,7 @@
epoch (int): The epoch number at which the checkpoint is being saved.
"""
- step_cur_in_epoch = None if step is None else step_cur_in_epoch
+ step_in_epoch = None if step is None else step_in_epoch
if self.rank == 0:
logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
# self.step_or_epoch += 1
@@ -165,12 +165,12 @@
"best_step_or_epoch": self.best_step_or_epoch,
"avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
"step": step,
- "step_cur_in_epoch": step_cur_in_epoch,
+ "step_in_epoch": step_in_epoch,
"data_split_i": kwargs.get("data_split_i", 0),
"data_split_num": kwargs.get("data_split_num", 1),
"batch_total": self.batch_total,
}
- step = step_cur_in_epoch
+ step = step_in_epoch
if hasattr(model, "module"):
state["state_dict"] = model.module.state_dict()
@@ -204,7 +204,7 @@
)
else:
logging.info(
- f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}"
+ f"No improvement in acc: {self.val_acc_step_or_eoch[ckpt_name]:.4f} < {self.val_acc_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
elif self.avg_keep_nbest_models_type == "loss":
if (
@@ -219,7 +219,7 @@
)
else:
logging.info(
- f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}"
+ f"No improvement in loss: {self.val_loss_step_or_eoch[ckpt_name]:.4f} > {self.val_loss_step_or_eoch[self.best_step_or_epoch]:.4f}, {os.path.join(self.output_dir, self.best_step_or_epoch)}"
)
else:
print("Undo")
@@ -260,7 +260,7 @@
ckpt = os.path.join(self.output_dir, "model.pt")
if os.path.isfile(ckpt):
checkpoint = torch.load(ckpt, map_location="cpu")
- self.start_epoch = checkpoint["epoch"] + 1
+ self.start_epoch = checkpoint["epoch"]
# self.model.load_state_dict(checkpoint['state_dict'])
src_state = checkpoint["state_dict"]
dst_state = model.state_dict()
@@ -297,17 +297,15 @@
checkpoint["best_step_or_epoch"] if "best_step_or_epoch" in checkpoint else ""
)
self.start_data_split_i = (
- checkpoint["start_data_split_i"] if "start_data_split_i" in checkpoint else 0
+ checkpoint["data_split_i"] if "data_split_i" in checkpoint else 0
)
self.batch_total = checkpoint["batch_total"] if "batch_total" in checkpoint else 0
self.start_step = checkpoint["step"] if "step" in checkpoint else 0
self.start_step = 0 if self.start_step is None else self.start_step
- self.step_cur_in_epoch = (
- checkpoint["step_cur_in_epoch"] if "step_cur_in_epoch" in checkpoint else 0
+ self.step_in_epoch = (
+ checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
)
- self.step_cur_in_epoch = (
- 0 if self.step_cur_in_epoch is None else self.step_cur_in_epoch
- )
+ self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
model.to(self.device)
print(f"Checkpoint loaded successfully from '{ckpt}'")
@@ -356,7 +354,7 @@
if iterator_stop > 0:
break
self.batch_total += 1
- self.step_cur_in_epoch += 1
+ self.step_in_epoch += 1
time1 = time.perf_counter()
speed_stats["data_load"] = f"{time1-time_beg:0.3f}"
@@ -459,7 +457,7 @@
self.log(
epoch,
batch_idx,
- step_cur_in_epoch=self.step_cur_in_epoch,
+ step_in_epoch=self.step_in_epoch,
batch_num_epoch=batch_num_epoch,
lr=lr,
loss=loss.detach().cpu().item(),
@@ -471,17 +469,17 @@
data_split_num=kwargs.get("data_split_num", 1),
)
- if (batch_idx + 1) % self.validate_interval == 0:
+ if self.step_in_epoch % self.validate_interval == 0:
self.validate_epoch(
model=model,
dataloader_val=dataloader_val,
epoch=epoch,
writer=writer,
step=batch_idx + 1,
- step_cur_in_epoch=self.step_cur_in_epoch,
+ step_in_epoch=self.step_in_epoch,
)
- if (batch_idx + 1) % self.save_checkpoint_interval == 0:
+ if self.step_in_epoch % self.save_checkpoint_interval == 0:
self.save_checkpoint(
epoch,
model=model,
@@ -489,7 +487,7 @@
scheduler=scheduler,
scaler=scaler,
step=batch_idx + 1,
- step_cur_in_epoch=self.step_cur_in_epoch,
+ step_in_epoch=self.step_in_epoch,
data_split_i=kwargs.get("data_split_i", 0),
data_split_num=kwargs.get("data_split_num", 1),
)
@@ -599,10 +597,10 @@
iterator_stop.fill_(1)
dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
- if kwargs.get("step_cur_in_epoch", None) is None:
+ if kwargs.get("step_in_epoch", None) is None:
ckpt_name = f"model.pt.ep{epoch}"
else:
- ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_cur_in_epoch")}'
+ ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
model.train()
@@ -615,7 +613,7 @@
self,
epoch=0,
batch_idx=0,
- step_cur_in_epoch=0,
+ step_in_epoch=0,
batch_num_epoch=-1,
lr=0.0,
loss=0.0,
@@ -648,9 +646,8 @@
f"{tag}, "
f"rank: {self.rank}, "
f"epoch: {epoch}/{self.max_epoch}, "
- f"step_cur_in_epoch: {step_cur_in_epoch}, "
f"data_slice: {data_split_i}/{data_split_num}, "
- f"step: {batch_idx + 1}/{batch_num_epoch}, total step: {self.batch_total}, "
+ f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
f"(loss_avg_rank: {loss:.3f}), "
f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), "
--
Gitblit v1.9.1