From b3b10158097b10aa26ee3469c5ba8fd20c745de3 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 20 五月 2024 11:41:53 +0800
Subject: [PATCH] ds
---
funasr/train_utils/trainer_ds.py | 20 ++++++++++----------
1 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 88a853c..8a52746 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -574,12 +574,12 @@
loss_dict["lr"] = scheduler.get_last_lr()[0]
loss_dict["batch_num_epoch"] = len(dataloader_train)
- self.val_loss_avg = (
- self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
+ self.train_loss_avg = (
+ self.train_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
) / (batch_idx + 1)
if "acc" in loss_dict["stats"]:
- self.val_acc_avg = (
- self.val_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
+ self.train_acc_avg = (
+ self.train_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
) / (batch_idx + 1)
self.log(loss_dict, tag="train")
@@ -612,12 +612,12 @@
time_beg = time.perf_counter()
if self.use_ddp or self.use_fsdp or self.use_deepspeed:
- val_loss_avg = torch.tensor(self.val_loss_avg, dtype=torch.float32).to(self.device)
- val_acc_avg = torch.tensor(self.val_acc_avg, dtype=torch.float32).to(self.device)
- dist.all_reduce(val_loss_avg, op=dist.ReduceOp.SUM)
- dist.all_reduce(val_acc_avg, op=dist.ReduceOp.SUM)
- self.val_loss_avg = val_loss_avg.detach().cpu().item() / self.world_size
- self.val_acc_avg = val_acc_avg.detach().cpu().item() / self.world_size
+ train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(self.device)
+ train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(self.device)
+ dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
+ dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
+ self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
+ self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
def forward_step(self, model, batch, loss_dict={}):
dtype = torch.bfloat16
--
Gitblit v1.9.1