From 2cca8104d26b454112f39b8405dcb0e70d365990 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 19 Jan 2024 17:05:08 +0800
Subject: [PATCH] Funasr1.0 (#1275)

---
 funasr/train_utils/trainer.py |   20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 91b30b0..62d6be8 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -147,9 +147,17 @@
         for epoch in range(self.start_epoch, self.max_epoch + 1):
             
             self._train_epoch(epoch)
+
             
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+                
             self._validate_epoch(epoch)
-            
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+                
+
             if self.rank == 0:
                 self._save_checkpoint(epoch)
             
@@ -164,7 +172,9 @@
             
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        self.writer.close()
+      
+        if self.writer:
+            self.writer.close()
         
     
     def _train_epoch(self, epoch):
@@ -230,6 +240,8 @@
                         continue
                 
                 # Execute an optimization step (update model parameters)
+                if self.use_ddp or self.use_fsdp:
+                    dist.barrier()
                 self.optim.step()
                 self.scheduler.step()
                 # Clear gradients for the next accumulation stage
@@ -244,7 +256,7 @@
             pbar.update(1)
             if self.local_rank == 0:
                 description = (
-                    f"Epoch: {epoch}/{self.max_epoch}, "
+                    f"Train epoch: {epoch}/{self.max_epoch}, "
                     f"step {batch_idx}/{len(self.dataloader_train)}, "
                     f"{speed_stats}, "
                     f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -306,7 +318,7 @@
                 pbar.update(1)
                 if self.local_rank == 0:
                     description = (
-                        f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                        f"validation epoch: {epoch}/{self.max_epoch}, "
                         f"step {batch_idx}/{len(self.dataloader_train)}, "
                         f"{speed_stats}, "
                         f"(loss: {loss.detach().cpu().item():.3f}), "

--
Gitblit v1.9.1