From 9be30f99dd09cfe0de929266ec43c1b95abb6d96 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 10 五月 2024 10:16:28 +0800
Subject: [PATCH] update avg slice

---
 funasr/train_utils/trainer.py |   44 ++++++++++++++++++++++++++++++--------------
 1 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index a28ca51..33dd351 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -169,6 +169,8 @@
                 "data_split_i": kwargs.get("data_split_i", 0),
                 "data_split_num": kwargs.get("data_split_num", 1),
                 "batch_total": self.batch_total,
+                "train_loss_avg": kwargs.get("train_loss_avg", 0),
+                "train_acc_avg": kwargs.get("train_acc_avg", 0),
             }
             step = step_in_epoch
             if hasattr(model, "module"):
@@ -306,7 +308,13 @@
                     checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
                 )
                 self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
-
+                print(checkpoint["train_acc_avg"])
+                self.train_acc_avg = (
+                    checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
+                )
+                self.train_loss_avg = (
+                    checkpoint["train_loss_avg"] if "train_loss_avg" in checkpoint else 0
+                )
                 model.to(self.device)
                 print(f"Checkpoint loaded successfully from '{ckpt}'")
             else:
@@ -374,8 +382,6 @@
                     ):
                         torch.cuda.empty_cache()
 
-                time3 = time.perf_counter()
-                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 loss, stats, weight = retval
                 stats = {k: v for k, v in stats.items() if v is not None}
                 if self.use_ddp or self.use_fsdp:
@@ -390,22 +396,27 @@
                     # Multiply world_size because DistributedDataParallel
                     # automatically normalizes the gradient by world_size.
                     loss *= self.world_size
+                # loss *= self.world_size
                 # Scale the loss since we're not updating for every mini-batch
                 loss = loss / accum_grad
+
+                time3 = time.perf_counter()
+                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                 if self.use_fp16:
                     scaler.scale(loss).backward()
                 else:
                     loss.backward()
                 time4 = time.perf_counter()
-                speed_stats["backward_time"] = f"{time4 - time3:0.3f}"
+                speed_stats["backward_and_AllReaduce_time"] = f"{time4 - time3:0.3f}"
 
                 self.train_loss_avg = (
-                    self.train_loss_avg * batch_idx + loss.detach().cpu().item()
-                ) / (batch_idx + 1)
+                    self.train_loss_avg * (self.step_in_epoch - 1) + loss.detach().cpu().item()
+                ) / self.step_in_epoch
                 if "acc" in stats:
                     self.train_acc_avg = (
-                        self.train_acc_avg * batch_idx + stats["acc"].detach().cpu().item()
-                    ) / (batch_idx + 1)
+                        self.train_acc_avg * (self.step_in_epoch - 1)
+                        + stats["acc"].detach().cpu().item()
+                    ) / self.step_in_epoch
                 if self.use_ddp or self.use_fsdp:
                     train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
                         self.device
@@ -445,8 +456,9 @@
                 scheduler.step()
                 # Clear gradients for the next accumulation stage
                 optim.zero_grad(set_to_none=True)
-                total_time = f"{time.perf_counter() - time5:0.3f}"
+                total_time = f"{(time.perf_counter() - time5)/accum_grad:0.3f}"
                 time5 = time.perf_counter()
+
                 speed_stats["optim_time"] = f"{time5 - time4:0.3f}"
 
                 speed_stats["total_time"] = total_time
@@ -456,7 +468,8 @@
                     batch_num_epoch = len(dataloader_train)
                 self.log(
                     epoch,
-                    batch_idx + kwargs.get("start_step", 0),
+                    batch_idx,
+                    log_step=batch_idx + kwargs.get("start_step", 0),
                     step_in_epoch=self.step_in_epoch,
                     batch_num_epoch=batch_num_epoch,
                     lr=lr,
@@ -490,6 +503,8 @@
                     step_in_epoch=self.step_in_epoch,
                     data_split_i=kwargs.get("data_split_i", 0),
                     data_split_num=kwargs.get("data_split_num", 1),
+                    train_loss_avg=self.train_loss_avg,
+                    train_acc_avg=self.train_acc_avg,
                 )
 
             time_beg = time.perf_counter()
@@ -623,11 +638,12 @@
         tag="train",
         data_split_i=0,
         data_split_num=1,
+        log_step=None,
         **kwargs,
     ):
 
         if (batch_idx + 1) % self.log_interval == 0:
-
+            batch_idx = log_step if log_step is not None else batch_idx
             gpu_info = (
                 "GPU, memory: usage: {:.3f} GB, "
                 "peak: {:.3f} GB, "
@@ -649,9 +665,9 @@
                 f"data_slice: {data_split_i}/{data_split_num}, "
                 f"step_in_slice: {batch_idx + 1}/{batch_num_epoch}, step_in_epoch: {step_in_epoch}, total step: {self.batch_total}, "
                 f"(loss_avg_rank: {loss:.3f}), "
-                f"(loss_avg_epoch: {loss_avg_epoch:.3f}), "
-                f"(ppl_avg_epoch: {math.exp(loss_avg_epoch):.3e}), "
-                f"(acc_avg_epoch: {acc_avg_epoch:.3f}), "
+                f"(loss_avg_slice: {loss_avg_epoch:.3f}), "
+                f"(ppl_avg_slice: {math.exp(loss_avg_epoch):.3e}), "
+                f"(acc_avg_slice: {acc_avg_epoch:.3f}), "
                 f"(lr: {lr:.3e}), "
                 f"{[(k, round(v.detach().cpu().item(), 3)) for k, v in stats.items()]}, "
                 f"{speed_stats}, "

--
Gitblit v1.9.1