From 1e1500adadf5c7ed3622efa0f48f51b48a78b31e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 20 五月 2024 11:33:14 +0800
Subject: [PATCH] ds

---
 funasr/train_utils/trainer_ds.py |    4 ++--
 1 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 78cfceb..88a853c 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -577,7 +577,7 @@
             self.val_loss_avg = (
                 self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
             ) / (batch_idx + 1)
-            if "acc" in stats:
+            if "acc" in loss_dict["stats"]:
                 self.val_acc_avg = (
                     self.val_acc_avg * batch_idx + loss_dict["stats"]["acc"].detach().cpu().item()
                 ) / (batch_idx + 1)
@@ -740,7 +740,7 @@
                 self.val_loss_avg = (
                     self.val_loss_avg * batch_idx + loss_dict["loss"].detach().cpu().item()
                 ) / (batch_idx + 1)
-                if "acc" in stats:
+                if "acc" in loss_dict["stats"]:
                     self.val_acc_avg = (
                         self.val_acc_avg * batch_idx
                         + loss_dict["stats"]["acc"].detach().cpu().item()

--
Gitblit v1.9.1