From abb33d6b2097e5b0643326bc1b376a63cdc2f967 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 17:06:21 +0800
Subject: [PATCH] Dev gzf deepspeed (#1844)
---
funasr/train_utils/trainer.py | 28 ++++++++++++++--------------
1 files changed, 14 insertions(+), 14 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index afc632d..665a7af 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -362,10 +362,10 @@
time_beg = time.perf_counter()
time5 = time_beg
for batch_idx, batch in enumerate(dataloader_train):
- if self.use_ddp or self.use_fsdp:
- dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
- if iterator_stop > 0:
- break
+ # if self.use_ddp or self.use_fsdp:
+ # dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+ # if iterator_stop > 0:
+ # break
self.batch_total += 1
self.step_in_epoch += 1
time1 = time.perf_counter()
@@ -381,11 +381,11 @@
with maybe_autocast(self.use_fp16):
retval = model(**batch)
- if (
- self.reset_gpu_cache
- and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
- ):
- torch.cuda.empty_cache()
+ # if (
+ # self.reset_gpu_cache
+ # and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
+ # ):
+ # torch.cuda.empty_cache()
loss, stats, weight = retval
stats = {k: v for k, v in stats.items() if v is not None}
@@ -516,14 +516,14 @@
)
time_beg = time.perf_counter()
- else:
- if self.use_ddp or self.use_fsdp:
- iterator_stop.fill_(1)
- dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
+ # else:
+ # if self.use_ddp or self.use_fsdp:
+ # iterator_stop.fill_(1)
+ # dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
if self.use_ddp or self.use_fsdp:
dist.barrier()
- iterator_stop = torch.tensor(0).to(self.device)
+ # iterator_stop = torch.tensor(0).to(self.device)
def validate_epoch(
self,
--
Gitblit v1.9.1