From 00d0df3a1018c63ec8c5d13e611f53c564c0a7e2 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 06 五月 2024 22:17:25 +0800
Subject: [PATCH] Dev gzf decoding (#1695)
---
funasr/train_utils/trainer.py | 7 +++++--
1 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/funasr/train_utils/trainer.py b/funasr/train_utils/trainer.py
index dd0ac7a..01e2924 100644
--- a/funasr/train_utils/trainer.py
+++ b/funasr/train_utils/trainer.py
@@ -308,6 +308,7 @@
checkpoint["step_in_epoch"] if "step_in_epoch" in checkpoint else 0
)
self.step_in_epoch = 0 if self.step_in_epoch is None else self.step_in_epoch
+ print(checkpoint["train_acc_avg"])
self.train_acc_avg = (
checkpoint["train_acc_avg"] if "train_acc_avg" in checkpoint else 0
)
@@ -464,7 +465,8 @@
batch_num_epoch = len(dataloader_train)
self.log(
epoch,
- batch_idx + kwargs.get("start_step", 0),
+ batch_idx,
+ log_step=batch_idx + kwargs.get("start_step", 0),
step_in_epoch=self.step_in_epoch,
batch_num_epoch=batch_num_epoch,
lr=lr,
@@ -633,11 +635,12 @@
tag="train",
data_split_i=0,
data_split_num=1,
+ log_step=None,
**kwargs,
):
if (batch_idx + 1) % self.log_interval == 0:
-
+ batch_idx = log_step if log_step is not None else batch_idx
gpu_info = (
"GPU, memory: usage: {:.3f} GB, "
"peak: {:.3f} GB, "
--
Gitblit v1.9.1