From b3b10158097b10aa26ee3469c5ba8fd20c745de3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 20 五月 2024 11:41:53 +0800
Subject: [PATCH] ds

---
 funasr/train_utils/trainer_ds.py |   20 ++++++++++----------
 1 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 88a853c..8a52746 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -574,12 +574,12 @@
             loss_dict["lr"] = scheduler.get_last_lr()[0]
             loss_dict["batch_num_epoch"] = len(dataloader_train)
 
-            self.val_loss_avg = (
-                self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
+            self.train_loss_avg = (
+                self.train_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
             ) / (batch_idx + 1)
             if "acc" in loss_dict["stats"]:
-                self.val_acc_avg = (
-                    self.val_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
+                self.train_acc_avg = (
+                    self.train_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
                 ) / (batch_idx + 1)
 
             self.log(loss_dict, tag="train")
@@ -612,12 +612,12 @@
             time_beg = time.perf_counter()
 
         if self.use_ddp or self.use_fsdp or self.use_deepspeed:
-            val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
-            val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
-            dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
-            dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
-            self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
-            self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+            train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+            train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+            dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+            dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+            self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+            self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
 
     def forward_step(self, model, batch, loss_dict={}):
         dtype = torch.bfloat16

--
Gitblit v1.9.1