From 1ca314955fbe150db9a3f40193ca10736a9a4260 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 五月 2024 13:32:51 +0800
Subject: [PATCH] deepspeed

---
 funasr/train_utils/trainer_ds.py |   10 +++++-----
 1 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index aa3c5af..72f7b75 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -397,7 +397,7 @@
                 time4 = time.perf_counter()
                 loss_dict["speed_stats"]["backward_time"] = f"{time4 - time3:0.3f}"
 
-            self.update_step(model, optim, scheduler, scaler, loss_dict)
+            self.update_step(model, optim, scheduler, scaler, loss_dict=loss_dict)
             total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
             time5 = time.perf_counter()
 
@@ -415,7 +415,7 @@
                     model=model,
                     dataloader_val=dataloader_val,
                     epoch=epoch,
-                    writer=writer,
+                    writer=self.writer,
                     step=batch_idx + 1,
                     step_in_epoch=self.step_in_epoch,
                 )
@@ -469,8 +469,8 @@
             else:
                 loss.backward()
 
-    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):
-
+    def update_step(self, model, optim, scheduler, scaler, loss_dict=None):
+        batch_idx = loss_dict["batch_idx"]
         if self.use_deepspeed:
             model.step()
         else:
@@ -747,7 +747,6 @@
         from funasr.schedulers import scheduler_classes
         from omegaconf import OmegaConf, DictConfig
         import json
-        import deepspeed
 
         # optim
         logging.info("Build optim")
@@ -764,6 +763,7 @@
         scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
 
         if self.use_deepspeed:
+            import deepspeed
 
             args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
             with open(self.deepspeed_config, "r") as fin:

--
Gitblit v1.9.1