游雁
2024-03-21 7a4d17b54df09b15c40d0fd7c3af6f9ee3a25e73
funasr/train_utils/average_nbest_models.py
@@ -143,7 +143,7 @@
    return checkpoint_paths
@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int=5):
def average_checkpoints(output_dir: str, last_n: int=5, val_acc_list=[]):
    """
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.