From a0f03bd2a87d97d47a1636bbe6f0855a43160331 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 15 五月 2024 19:48:50 +0800
Subject: [PATCH] Dev gzf deepspeed (#1732)

---
 funasr/train_utils/trainer.py |   50 ++++++++++++++++++++++++++++----------------------
 1 files changed, 28 insertions(+), 22 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index 01e2924..50f99f0 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -382,8 +382,6 @@
                     ):
                         torch.cuda.empty_cache()
 
-                time3 = time.perf_counter()
-                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 loss, stats, weight = retval
                 stats = {k: v for k, v in stats.items() if v is not None}
                 if self.use_ddp or self.use_fsdp:
@@ -398,34 +396,28 @@
                     # Multiply world_size because DistributedDataParallel
                     # automatically normalizes the gradient by world_size.
                     loss *= self.world_size
+                # loss *= self.world_size
                 # Scale the loss since we're not updating for every mini-batch
                 loss = loss / accum_grad
+
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 if self.use_fp16:
                     scaler.scale(loss).backward()
                 else:
                     loss.backward()
                 time4 = time.perf_counter()
-                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+                speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
 
                 self.train_loss_avg = (
-                    self.train_loss_avg * (self.step_in_epoch - 1) + loss.detach().cpu().item()
-                ) / self.step_in_epoch
+                    self.train_loss_avg * (batch_idx + kwargs.get("start_step", 0))
+                    + loss.detach().cpu().item()
+                ) / (batch_idx + kwargs.get("start_step", 0) + 1)
                 if "acc" in stats:
                     self.train_acc_avg = (
-                        self.train_acc_avg * (self.step_in_epoch - 1)
+                        self.train_acc_avg * (batch_idx + kwargs.get("start_step", 0))
                         + stats["acc"].detach().cpu().item()
-                    ) / self.step_in_epoch
-                if self.use_ddp or self.use_fsdp:
-                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
-                        self.device
-                    )
-                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
-                        self.device
-                    )
-                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
-                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
-                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
-                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+                    ) / (batch_idx + kwargs.get("start_step", 0) + 1)
 
             # Perform an optimizer step only after accumulating enough gradients
             if (batch_idx + 1) % accum_grad == 0:
@@ -454,8 +446,22 @@
                 scheduler.step()
                 # Clear gradients for the next accumulation stage
                 optim.zero_grad(set_to_none=True)
-                total_time = f"{time.perf_counter() - time5:0.3f}"
+
+                if self.use_ddp or self.use_fsdp:
+                    train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
+                        self.device
+                    )
+                    train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
+                        self.device
+                    )
+                    dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+                    dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+                    self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+                    self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
+
+                total_time = f"{(time.perf_counter() - time5)/accum_grad:0.3f}"
                 time5 = time.perf_counter()
+
                 speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
 
                 speed_stats["total_time"] = total_time
@@ -662,9 +668,9 @@
                 f"data_slice: {data_split_i}/{data_split_num}, "
                 f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
                 f"(loss_avg_rank: {loss:.3f}), "
-                f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
-                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), "
-                f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
+                f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
+                f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
+                f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
                 f"(lr: {lr:.3e}), "
                 f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
                 f"{speed_stats}, "

--
Gitblit v1.9.1