From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/train_utils/average_nbest_models.py | 9 ++++++++-
1 files changed, 8 insertions(+), 1 deletions(-)
diff --git a/funasr/train_utils/average_nbest_models.py b/funasr/train_utils/average_nbest_models.py
index 20da130..67f1e55 100644
--- a/funasr/train_utils/average_nbest_models.py
+++ b/funasr/train_utils/average_nbest_models.py
@@ -22,7 +22,13 @@
in the output directory.
"""
try:
- checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
+ if not use_deepspeed:
+ checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
+ else:
+ checkpoint = torch.load(
+ os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"),
+ map_location="cpu",
+ )
avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
@@ -35,6 +41,7 @@
ckpt = os.path.join(output_dir, key)
else:
ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt")
+ checkpoint_paths.append(ckpt)
except:
print(f"{checkpoint} does not exist, avg the lastet checkpoint.")
--
Gitblit v1.9.1