From 1c8b46a233ac4a782d7170e20533f536761e25c4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 09 六月 2024 00:21:44 +0800
Subject: [PATCH] fix bug

---
 funasr/train_utils/trainer_ds.py |    6 ++++++
 1 files changed, 6 insertions(+), 0 deletions(-)

diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index ec887cc..ec76531 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -167,6 +167,8 @@
         Args:
             epoch (int): The epoch number at which the checkpoint is being saved.
         """
+        if self.use_ddp or self.use_fsdp:
+            dist.barrier()
         step_in_epoch = None if step is None else step_in_epoch
         if self.use_deepspeed:
 
@@ -760,6 +762,10 @@
             ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
         self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
         self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+
+        if self.use_ddp or self.use_fsdp or self.use_deepspeed:
+            dist.barrier()
+
         model.train()
 
     def log(

--
Gitblit v1.9.1