funasr/train_utils/average_nbest_models.py
@@ -143,7 +143,7 @@ return checkpoint_paths @torch.no_grad() def average_checkpoints(output_dir: str, last_n: int=5): def average_checkpoints(output_dir: str, last_n: int=5, val_acc_list=[]): """ Average the last 'last_n' checkpoints' model state_dicts. If a tensor is of type torch.int, perform sum instead of average.