From 1c8b46a233ac4a782d7170e20533f536761e25c4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 09 六月 2024 00:21:44 +0800
Subject: [PATCH] fix bug: add dist.barrier() before checkpoint saving and after validation metrics are recorded
---
funasr/train_utils/trainer_ds.py | 6 ++++++
1 files changed, 6 insertions(+), 0 deletions(-)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index ec887cc..ec76531 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -167,6 +167,8 @@
Args:
epoch (int): The epoch number at which the checkpoint is being saved.
"""
+ if self.use_ddp or self.use_fsdp:
+ dist.barrier()
step_in_epoch = None if step is None else step_in_epoch
if self.use_deepspeed:
@@ -760,6 +762,10 @@
ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
+
+ if self.use_ddp or self.use_fsdp or self.use_deepspeed:
+ dist.barrier()
+
model.train()
def log(
--
Gitblit v1.9.1