游雁
2024-03-26 abf5af40e934216b397c5331e0a68dc92f0a4f4e
funasr/train_utils/average_nbest_models.py
@@ -143,7 +143,7 @@
    return checkpoint_paths
@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int=5):
def average_checkpoints(output_dir: str, last_n: int=5, val_acc_list=[]):
    """
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.