From d3ff05837bbc14749d09f44947633b87e8f2db0e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 五月 2024 11:47:27 +0800
Subject: [PATCH] deepspeed

---
 funasr/train_utils/trainer_ds.py |   52 +++++++++++++++++++---------------------------------
 1 file changed, 19 insertions(+), 33 deletions(-)

diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index f3a0309..aa3c5af 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -78,7 +78,7 @@
         self.world_size = world_size
         self.use_ddp = use_ddp
         self.use_fsdp = use_fsdp
-        self.use_deepspeed = use_deepspeed
+
         self.device = kwargs.get("device", "cuda")
 
         self.output_dir = output_dir
@@ -136,6 +136,9 @@
             self.writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
         except:
             self.writer = None
+
+        self.use_deepspeed = use_deepspeed
+        self.deepspeed_config = kwargs.get("deepspeed_config", "")
 
     def save_checkpoint(
         self,
@@ -443,7 +446,8 @@
             iterator_stop = torch.tensor(0).to(self.device)
 
     def forward_step(self, model, batch, loss_dict={}):
-        with maybe_autocast(self.use_fp16):
+        dtype = torch.bfloat16
+        with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
             retval = model(**batch)
 
         loss, stats, weight = retval
@@ -465,7 +469,7 @@
             else:
                 loss.backward()
 
-    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=loss_dict):
+    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):
 
         if self.use_deepspeed:
             model.step()
@@ -613,7 +617,7 @@
         loss = loss_dict["loss"].detach().cpu().item()
         epoch = loss_dict["epoch"]
         batch_idx = loss_dict["batch_idx"]
-        step_in_epoch = loss_dict["step_in_epoch"]
+        step_in_epoch = self.step_in_epoch
         batch_num_epoch = loss_dict["batch_num_epoch"]
         lr = loss_dict["lr"]
 
@@ -732,36 +736,18 @@
                     "find_unused_parameters", False
                 ),
             )
-        # elif self.use_fsdp:
-        #     # model = FSDP(model).cuda(local_rank)
-        #
-        #     def custom_auto_wrap_policy(
-        #         module: nn.Module,
-        #         recurse: bool,
-        #         nonwrapped_numel: int,
-        #         # Additional custom arguments
-        #         min_num_params: int = int(1e8),
-        #     ) -> bool:
-        #         # 鏍规嵁鑷畾涔夐�昏緫鍐冲畾鏄惁鍖呰妯″潡
-        #         is_large = unwrapped_params >= min_num_params
-        #         requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
-        #         return is_large and requires_grad_uniform
-        #
-        #     # Configure a custom `min_num_params`
-        #     my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
-        #     torch.cuda.set_device(local_rank)
-        #     model = FSDP(
-        #         model,
-        #         auto_wrap_policy=custom_auto_wrap_policy,
-        #         mixed_precision=None,
-        #         device_id=torch.cuda.current_device(),
-        #     )
+
         else:
             model = model.to(device=kwargs.get("device", "cuda"))
 
         return model
 
     def warp_optim_scheduler(self, model, **kwargs):
+        from funasr.optimizers import optim_classes
+        from funasr.schedulers import scheduler_classes
+        from omegaconf import OmegaConf, DictConfig
+        import json
+        import deepspeed
 
         # optim
         logging.info("Build optim")
@@ -777,15 +763,16 @@
         scheduler_class = scheduler_classes.get(scheduler)
         scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
 
-        if use_deepspeed:
-            deepspeed_config = kwargs.get("deepspeed_config", "")
-            with open(deepspeed_config, "r") as fin:
+        if self.use_deepspeed:
+
+            args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
+            with open(self.deepspeed_config, "r") as fin:
                 ds_configs = json.load(fin)
             if "optimizer" in ds_configs:
                 # NOTE(xcsong): Disable custom optimizer if it is set in ds_config,
                 # extremely useful when enable cpu_offload, DeepspeedCpuAdam
                 # could be 4~5x faster than torch native adam
-                deepspeed_config = None
+                optim = None
                 if "scheduler" in ds_configs:
                     scheduler = None
                 else:
@@ -793,7 +780,6 @@
                     def scheduler(opt):
                         return scheduler_class(opt, **kwargs.get("scheduler_conf"))
 
-            args = OmegaConf.create({"deepspeed_config": deepspeed_config})
             model, optimizer, _, scheduler = deepspeed.initialize(
                 args=args,
                 model=model,

--
Gitblit v1.9.1