zhifu gao
2024-05-20 963ba1a7717c785d6e20ccb0d3cee9b59d5365e3
funasr/train_utils/average_nbest_models.py
@@ -22,7 +22,13 @@
    in the output directory.
    """
    try:
        checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
        if not use_deepspeed:
            checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
        else:
            checkpoint = torch.load(
                os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"),
                map_location="cpu",
            )
        avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
        val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
        sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
@@ -35,6 +41,7 @@
                ckpt = os.path.join(output_dir, key)
            else:
                ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt")
            checkpoint_paths.append(ckpt)
    except:
        print(f"{checkpoint} does not exist, avg the lastet checkpoint.")