From 1e1500adadf5c7ed3622efa0f48f51b48a78b31e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 20 五月 2024 11:33:14 +0800
Subject: [PATCH] ds
---
funasr/train_utils/trainer_ds.py | 4 ++--
1 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 78cfceb..88a853c 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -577,7 +577,7 @@
self.val_loss_avg = (
self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
) / (batch_idx + 1)
- if "acc" in stats:
+ if "acc" in loss_dict["stats"]:
self.val_acc_avg = (
self.val_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
) / (batch_idx + 1)
@@ -740,7 +740,7 @@
self.val_loss_avg = (
self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
) / (batch_idx + 1)
- if "acc" in stats:
+ if "acc" in loss_dict["stats"]:
self.val_acc_avg = (
self.val_acc_avg * batch_idx
+ loss_dict["stats"]["acc"].detach().cpu().item()
--
Gitblit v1.9.1