ds
游雁
2024-05-20 bbd300a91184d09dec56a3b06734051dbd5812e4
funasr/train_utils/average_nbest_models.py
@@ -16,7 +16,7 @@
from functools import cmp_to_key
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False):
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs):
    """
    Get the paths of the last 'last_n' checkpoints by parsing filenames
    in the output directory.
@@ -55,7 +55,7 @@
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, sum its values instead of averaging them.
    """
    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
    print(f"average_checkpoints: {checkpoint_paths}")
    state_dicts = []