From 2cca8104d26b454112f39b8405dcb0e70d365990 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 19 Jan 2024 17:05:08 +0800
Subject: [PATCH] Funasr1.0 (#1275)
---
funasr/train_utils/trainer.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 91b30b0..62d6be8 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -147,9 +147,17 @@
for epoch in range(self.start_epoch, self.max_epoch + 1):
self._train_epoch(epoch)
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
self._validate_epoch(epoch)
-
+
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
+
+
if self.rank == 0:
self._save_checkpoint(epoch)
@@ -164,7 +172,9 @@
if self.use_ddp or self.use_fsdp:
dist.barrier()
- self.writer.close()
+
+ if self.writer:
+ self.writer.close()
def _train_epoch(self, epoch):
@@ -230,6 +240,8 @@
continue
# Execute an optimization step (update model parameters)
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
self.optim.step()
self.scheduler.step()
# Clear gradients for the next accumulation stage
@@ -244,7 +256,7 @@
pbar.update(1)
if self.local_rank == 0:
description = (
- f"Epoch: {epoch}/{self.max_epoch}, "
+ f"Train epoch: {epoch}/{self.max_epoch}, "
f"step {batch_idx}/{len(self.dataloader_train)}, "
f"{speed_stats}, "
f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -306,7 +318,7 @@
pbar.update(1)
if self.local_rank == 0:
description = (
- f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+ f"validation epoch: {epoch}/{self.max_epoch}, "
f"step {batch_idx}/{len(self.dataloader_train)}, "
f"{speed_stats}, "
f"(loss: {loss.detach().cpu().item():.3f}), "
--
Gitblit v1.9.1