游雁
2024-05-20 97522b10f661b004fbdbe234aa55ffd192578ce0
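bugfix (average_nbest_models.py: when use_deepspeed is set, load model.pt/mp_rank_00_model_states.pt instead of model.pt)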
2个文件已修改
11 ■■■■ 已修改文件
funasr/train_utils/average_nbest_models.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer_ds.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/average_nbest_models.py
@@ -22,7 +22,13 @@
    in the output directory.
    """
    try:
        checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
        if not use_deepspeed:
            checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
        else:
            checkpoint = torch.load(
                os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"),
                map_location="cpu",
            )
        avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
        val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
        sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
@@ -35,6 +41,7 @@
                ckpt = os.path.join(output_dir, key)
            else:
                ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt")
            checkpoint_paths.append(ckpt)
    except:
        print(f"{checkpoint} does not exist, avg the lastet checkpoint.")
funasr/train_utils/trainer_ds.py
@@ -388,7 +388,7 @@
                ckpt = os.path.join(self.output_dir, "model.pt")
                if os.path.exists(ckpt):
                    _, checkpoint = model.load_checkpoint(self.output_dir, "model.pt")
                    self.start_epoch = checkpoint["epoch"]
                    self.saved_ckpts = checkpoint["saved_ckpts"]
                    self.val_acc_step_or_eoch = (
                        checkpoint["val_acc_step_or_eoch"]