From b85e140c3e4a7a7ccba59abfc67b63aac7a28dd9 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 五月 2024 17:28:56 +0800
Subject: [PATCH] update
---
funasr/train_utils/trainer_ds.py | 12 ++++++------
1 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index aa3c5af..8afbc6d 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -397,8 +397,8 @@
time4 = time.perf_counter()
loss_dict["speed_stats"]["backward_time"] = f"{time4 - time3:0.3f}"
- self.update_step(model, optim, scheduler, scaler, loss_dict)
- total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
+ self.update_step(model, optim, scheduler, scaler, loss_dict=loss_dict)
+ total_time = f"{(time.perf_counter() - time5):0.3f}"
time5 = time.perf_counter()
loss_dict["speed_stats"]["optim_time"] = f"{time5 - time4:0.3f}"
@@ -415,7 +415,7 @@
model=model,
dataloader_val=dataloader_val,
epoch=epoch,
- writer=writer,
+ writer=self.writer,
step=batch_idx + 1,
step_in_epoch=self.step_in_epoch,
)
@@ -469,8 +469,8 @@
else:
loss.backward()
- def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):
-
+ def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
+ batch_idx = loss_dict["batch_idx"]
if self.use_deepspeed:
model.step()
else:
@@ -747,7 +747,6 @@
from funasr.schedulers import scheduler_classes
from omegaconf import OmegaConf, DictConfig
import json
- import deepspeed
# optim
logging.info("Build optim")
@@ -764,6 +763,7 @@
scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
if self.use_deepspeed:
+ import deepspeed
args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
with open(self.deepspeed_config, "r") as fin:
--
Gitblit v1.9.1