zhifu gao
2024-03-21 3ac03e448b7673604eb86f619b27521fca55f34d
funasr/train_utils/average_nbest_models.py
@@ -143,7 +143,7 @@
    return checkpoint_paths
@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int=5):
def average_checkpoints(output_dir: str, last_n: int=5, val_acc_list=[]):
    """
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.