游雁
2024-03-25 817ff41fbc5afbde346db62ad5e28e33178a622a
funasr/train_utils/average_nbest_models.py
@@ -143,7 +143,7 @@
    return checkpoint_paths
@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int=5):
def average_checkpoints(output_dir: str, last_n: int=5, val_acc_list=[]):
    """
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.