From 97522b10f661b004fbdbe234aa55ffd192578ce0 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 20 May 2024 17:05:31 +0800
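Subject: [PATCH] bugfix: load DeepSpeed model states and collect checkpoint paths when averaging n-best models; restore start_epoch from loaded checkpoint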
---
funasr/train_utils/average_nbest_models.py | 9 ++++++++-
funasr/train_utils/trainer_ds.py | 2 +-
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index 20da130..67f1e55 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -22,7 +22,13 @@
in the output directory.
"""
try:
- checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
+ if not use_deepspeed:
+ checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
+ else:
+ checkpoint = torch.load(
+ os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"),
+ map_location="cpu",
+ )
avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
@@ -35,6 +41,7 @@
ckpt = os.path.join(output_dir, key)
else:
ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt")
+ checkpoint_paths.append(ckpt)
except:
print(f"{checkpoint} does not exist, avg the lastet checkpoint.")
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index db92bc8..1a553f8 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -388,7 +388,7 @@
ckpt = os.path.join(self.output_dir, "model.pt")
if os.path.exists(ckpt):
_, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
-
+ self.start_epoch = checkpoint["epoch"]
self.saved_ckpts = checkpoint["saved_ckpts"]
self.val_acc_step_or_eoch = (
checkpoint["val_acc_step_or_eoch"]
--
Gitblit v1.9.1