From 85c08383831ea2b7cdf4c6f863f71b20b95b6782 Mon Sep 17 00:00:00 2001
From: 夜雨飘零 <yeyupiaoling@foxmail.com>
Date: Fri, 02 Feb 2024 16:56:16 +0800
Subject: [PATCH] support funasr 1.0 (#1346)

---
 funasr/train_utils/trainer.py |   22 ++++++++++++++++++----
 1 file changed, 18 insertions(+), 4 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 91b30b0..414c0d7 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -147,9 +147,18 @@
         for epoch in range(self.start_epoch, self.max_epoch + 1):
             
             self._train_epoch(epoch)
+
+
             
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+                
             self._validate_epoch(epoch)
-            
+
+            if self.use_ddp or self.use_fsdp:
+                dist.barrier()
+           
+           
             if self.rank == 0:
                 self._save_checkpoint(epoch)
             
@@ -164,7 +173,10 @@
             
         if self.use_ddp or self.use_fsdp:
             dist.barrier()
-        self.writer.close()
+
+
+        if self.writer:
+            self.writer.close()
         
     
     def _train_epoch(self, epoch):
@@ -230,6 +242,8 @@
                         continue
                 
                 # Execute an optimization step (update model parameters)
+                if self.use_ddp or self.use_fsdp:
+                    dist.barrier()
                 self.optim.step()
                 self.scheduler.step()
                 # Clear gradients for the next accumulation stage
@@ -244,7 +258,7 @@
             pbar.update(1)
             if self.local_rank == 0:
                 description = (
-                    f"Epoch: {epoch}/{self.max_epoch}, "
+                    f"Train epoch: {epoch}/{self.max_epoch}, "
                     f"step {batch_idx}/{len(self.dataloader_train)}, "
                     f"{speed_stats}, "
                     f"(loss: {loss.detach().cpu().item():.3f}), "
@@ -306,7 +320,7 @@
                 pbar.update(1)
                 if self.local_rank == 0:
                     description = (
-                        f"validation: \nEpoch: {epoch}/{self.max_epoch}, "
+                        f"validation epoch: {epoch}/{self.max_epoch}, "
                         f"step {batch_idx}/{len(self.dataloader_train)}, "
                         f"{speed_stats}, "
                         f"(loss: {loss.detach().cpu().item():.3f}), "

--
Gitblit v1.9.1