From 557b913c5d78945e49cb2ac2bf254a2de40b6cd5 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 14 三月 2024 11:48:31 +0800
Subject: [PATCH] v1.0.16
---
funasr/train_utils/trainer.py | 19 ++++++++++++++-----
1 files changed, 14 insertions(+), 5 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 3b20596..a00b3de 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -88,6 +88,7 @@
scaler = GradScaler(enabled=use_fp16) if use_fp16 else None
scaler = ShardedGradScaler(enabled=use_fp16) if use_ddp else scaler
self.scaler = scaler
+ self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000)
try:
@@ -104,7 +105,7 @@
self.writer = SummaryWriter(os.path.join(self.output_dir, "tensorboard")) if rank == 0 else None
- def _save_checkpoint(self, epoch):
+ def _save_checkpoint(self, epoch, step=None):
"""
Saves a checkpoint containing the model's state, the optimizer's state,
and the scheduler's state at the end of the given epoch. This method is
@@ -123,7 +124,11 @@
state["scaler_state"] = self.scaler.state_dict()
# Create output directory if it does not exist
os.makedirs(self.output_dir, exist_ok=True)
- filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+ if step is None:
+ filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}')
+ else:
+ filename = os.path.join(self.output_dir, f'model.pt.ep{epoch}.{step}')
+
torch.save(state, filename)
print(f'\nCheckpoint saved to {filename}\n')
@@ -163,7 +168,7 @@
self.scaler.load_state_dict(checkpoint['scaler_state'])
print(f"Checkpoint loaded successfully from '{ckpt}'")
else:
- print(f"No checkpoint found at '{ckpt}', starting from scratch")
+ print(f"No checkpoint found at '{ckpt}', does not resume status!")
if self.use_ddp or self.use_fsdp:
dist.barrier()
@@ -337,8 +342,10 @@
for key, var in speed_stats.items():
self.writer.add_scalar(f'rank{self.local_rank}_{key}/train', eval(var), self.batch_total)
-
+ if (batch_idx+1) % self.save_checkpoint_interval == 0 and self.rank == 0:
+ self._save_checkpoint(epoch, step=batch_idx+1)
pbar.close()
+
def _validate_epoch(self, epoch):
"""
@@ -401,4 +408,6 @@
epoch * len(self.dataloader_val) + batch_idx)
for key, var in speed_stats.items():
self.writer.add_scalar(f'rank{self.local_rank}_{key}/val', eval(var),
- epoch * len(self.dataloader_val) + batch_idx)
\ No newline at end of file
+ epoch * len(self.dataloader_val) + batch_idx)
+
+ self.model.train()
\ No newline at end of file
--
Gitblit v1.9.1