From d3ff05837bbc14749d09f44947633b87e8f2db0e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 17 May 2024 11:47:27 +0800
Subject: [PATCH] Add DeepSpeed training support (bf16 autocast, deepspeed.initialize wiring, -inf attention mask fill)

---
 funasr/models/sense_voice/decoder.py           |    4 +
 funasr/bin/train_ds.py                         |    2 
 funasr/train_utils/trainer_ds.py               |   52 +++++++++----------------
 funasr/models/sense_voice/whisper_lib/model.py |    4 +
 funasr/models/sanm/attention.py                |   16 ++++++--
 5 files changed, 38 insertions(+), 40 deletions(-)

diff --git a/funasr/bin/train_ds.py b/funasr/bin/train_ds.py
index 032a0cf..b0931b0 100644
--- a/funasr/bin/train_ds.py
+++ b/funasr/bin/train_ds.py
@@ -133,7 +133,7 @@
     kwargs["device"] = next(model.parameters()).device
     trainer.device = kwargs["device"]
 
-    model, optim, scheduler = trainer.warp_optim_scheduler(model, kwargs)
+    model, optim, scheduler = trainer.warp_optim_scheduler(model, **kwargs)
 
     # dataset
     logging.info("Build dataloader")
diff --git a/funasr/models/sanm/attention.py b/funasr/models/sanm/attention.py
index da8850f..08f7dc7 100644
--- a/funasr/models/sanm/attention.py
+++ b/funasr/models/sanm/attention.py
@@ -100,7 +100,9 @@
         n_batch = value.size(0)
         if mask is not None:
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+            min_value = -float(
+                "inf"
+            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             scores = scores.masked_fill(mask, min_value)
             self.attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0
@@ -269,7 +271,9 @@
 
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
 
-            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+            min_value = -float(
+                "inf"
+            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             scores = scores.masked_fill(mask, min_value)
             self.attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0
@@ -673,7 +677,9 @@
         n_batch = value.size(0)
         if mask is not None:
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+            min_value = -float(
+                "inf"
+            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             # logging.info(
             #     "scores: {}, mask_size: {}".format(scores.size(), mask.size()))
             scores = scores.masked_fill(mask, min_value)
@@ -858,7 +864,9 @@
 
             mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
 
-            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
+            min_value = -float(
+                "inf"
+            )  # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
             scores = scores.masked_fill(mask, min_value)
             self.attn = torch.softmax(scores, dim=-1).masked_fill(
                 mask, 0.0
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index 03b7532..60af29a 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -146,7 +146,9 @@
                 qk = qk + mask[:n_ctx, :n_ctx]
             else:
                 mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-                min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
+                min_value = -float(
+                    "inf"
+                )  # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                 qk = qk.masked_fill(mask, min_value)
 
         qk = qk.float()
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index 40939df..8b3d3ab 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -112,7 +112,9 @@
                 qk = qk + mask[:n_ctx, :n_ctx]
             else:
                 mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-                min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
+                min_value = -float(
+                    "inf"
+                )  # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                 qk = qk.masked_fill(mask, min_value)
 
         qk = qk.float()
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index f3a0309..aa3c5af 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -78,7 +78,7 @@
         self.world_size = world_size
         self.use_ddp = use_ddp
         self.use_fsdp = use_fsdp
-        self.use_deepspeed = use_deepspeed
+
         self.device = kwargs.get("device", "cuda")
 
         self.output_dir = output_dir
@@ -136,6 +136,9 @@
             self.writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
         except:
             self.writer = None
+
+        self.use_deepspeed = use_deepspeed
+        self.deepspeed_config = kwargs.get("deepspeed_config", "")
 
     def save_checkpoint(
         self,
@@ -443,7 +446,8 @@
             iterator_stop = torch.tensor(0).to(self.device)
 
     def forward_step(self, model, batch, loss_dict={}):
-        with maybe_autocast(self.use_fp16):
+        dtype = torch.bfloat16
+        with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
             retval = model(**batch)
 
         loss, stats, weight = retval
@@ -465,7 +469,7 @@
             else:
                 loss.backward()
 
-    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=loss_dict):
+    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):
 
         if self.use_deepspeed:
             model.step()
@@ -613,7 +617,7 @@
         loss = loss_dict["loss"].detach().cpu().item()
         epoch = loss_dict["epoch"]
         batch_idx = loss_dict["batch_idx"]
-        step_in_epoch = loss_dict["step_in_epoch"]
+        step_in_epoch = self.step_in_epoch
         batch_num_epoch = loss_dict["batch_num_epoch"]
         lr = loss_dict["lr"]
 
@@ -732,36 +736,18 @@
                     "find_unused_parameters", False
                 ),
             )
-        # elif self.use_fsdp:
-        #     # model = FSDP(model).cuda(local_rank)
-        #
-        #     def custom_auto_wrap_policy(
-        #         module: nn.Module,
-        #         recurse: bool,
-        #         nonwrapped_numel: int,
-        #         # Additional custom arguments
-        #         min_num_params: int = int(1e8),
-        #     ) -> bool:
-        #         # 鏍规嵁鑷畾涔夐�昏緫鍐冲畾鏄惁鍖呰妯″潡
-        #         is_large = unwrapped_params >= min_num_params
-        #         requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
-        #         return is_large and requires_grad_uniform
-        #
-        #     # Configure a custom `min_num_params`
-        #     my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
-        #     torch.cuda.set_device(local_rank)
-        #     model = FSDP(
-        #         model,
-        #         auto_wrap_policy=custom_auto_wrap_policy,
-        #         mixed_precision=None,
-        #         device_id=torch.cuda.current_device(),
-        #     )
+
         else:
             model = model.to(device=kwargs.get("device", "cuda"))
 
         return model
 
     def warp_optim_scheduler(self, model, **kwargs):
+        from funasr.optimizers import optim_classes
+        from funasr.schedulers import scheduler_classes
+        from omegaconf import OmegaConf, DictConfig
+        import json
+        import deepspeed
 
         # optim
         logging.info("Build optim")
@@ -777,15 +763,16 @@
         scheduler_class = scheduler_classes.get(scheduler)
         scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
 
-        if use_deepspeed:
-            deepspeed_config = kwargs.get("deepspeed_config", "")
-            with open(deepspeed_config, "r") as fin:
+        if self.use_deepspeed:
+
+            args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
+            with open(self.deepspeed_config, "r") as fin:
                 ds_configs = json.load(fin)
             if "optimizer" in ds_configs:
                 # NOTE(xcsong): Disable custom optimizer if it is set in ds_config,
                 # extremely useful when enable cpu_offload, DeepspeedCpuAdam
                 # could be 4~5x faster than torch native adam
-                deepspeed_config = None
+                optim = None
                 if "scheduler" in ds_configs:
                     scheduler = None
                 else:
@@ -793,7 +780,6 @@
                     def scheduler(opt):
                         return scheduler_class(opt, **kwargs.get("scheduler_conf"))
 
-            args = OmegaConf.create({"deepspeed_config": deepspeed_config})
             model, optimizer, _, scheduler = deepspeed.initialize(
                 args=args,
                 model=model,

--
Gitblit v1.9.1