游雁
2024-06-09 b75d1e89bb2f513a79bb07e9100ba1cd2bbcf40c
funasr/train_utils/average_nbest_models.py
@@ -16,20 +16,33 @@
from functools import cmp_to_key
def _get_checkpoint_paths(output_dir: str, last_n: int = 5):
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs):
    """
    Get the paths of the last 'last_n' checkpoints by parsing filenames
    in the output directory.
    """
    try:
        checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
        if not use_deepspeed:
            checkpoint = torch.load(os.path.join(output_dir, "model.pt"), map_location="cpu")
        else:
            checkpoint = torch.load(
                os.path.join(output_dir, "model.pt", "mp_rank_00_model_states.pt"),
                map_location="cpu",
            )
        avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
        val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
        sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
        sorted_items = (
            sorted_items[:last_n] if avg_keep_nbest_models_type == "acc" else sorted_items[-last_n:]
        )
        checkpoint_paths = [os.path.join(output_dir, key) for key, value in sorted_items[:last_n]]
        checkpoint_paths = []
        for key, value in sorted_items[:last_n]:
            if not use_deepspeed:
                ckpt = os.path.join(output_dir, key)
            else:
                ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt")
            checkpoint_paths.append(ckpt)
    except:
        print(f"{checkpoint} does not exist, avg the lastet checkpoint.")
        # List all files in the output directory
@@ -49,7 +62,7 @@
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.
    """
    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
    print(f"average_checkpoints: {checkpoint_paths}")
    state_dicts = []