From 1ca314955fbe150db9a3f40193ca10736a9a4260 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 17 May 2024 13:32:51 +0800
Subject: [PATCH] deepspeed
---
funasr/train_utils/trainer_ds.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index aa3c5af..72f7b75 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -397,7 +397,7 @@
time4 = time.perf_counter()
loss_dict["speed_stats"]["backward_time"] = f"{time4 - time3:0.3f}"
- self.update_step(model, optim, scheduler, scaler, loss_dict)
+ self.update_step(model, optim, scheduler, scaler, loss_dict=loss_dict)
total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
time5 = time.perf_counter()
@@ -415,7 +415,7 @@
model=model,
dataloader_val=dataloader_val,
epoch=epoch,
- writer=writer,
+ writer=self.writer,
step=batch_idx + 1,
step_in_epoch=self.step_in_epoch,
)
@@ -469,8 +469,8 @@
else:
loss.backward()
- def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):
-
+ def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
+ batch_idx = loss_dict["batch_idx"]
if self.use_deepspeed:
model.step()
else:
@@ -747,7 +747,6 @@
from funasr.schedulers import scheduler_classes
from omegaconf import OmegaConf, DictConfig
import json
- import deepspeed
# optim
logging.info("Build optim")
@@ -764,6 +763,7 @@
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
if self.use_deepspeed:
+ import deepspeed
args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
with open(self.deepspeed_config, "r") as fin:
--
Gitblit v1.9.1