zhifu gao
2024-04-26 1cdb3cc28d4d89a576cc06e5cd8eb80da1f3a3aa
funasr/train_utils/average_nbest_models.py
@@ -16,8 +16,7 @@
from functools import cmp_to_key
def _get_checkpoint_paths(output_dir: str, last_n: int=5):
def _get_checkpoint_paths(output_dir: str, last_n: int = 5):
    """
    Get the paths of the last 'last_n' checkpoints by parsing filenames
    in the output directory.
@@ -27,7 +26,9 @@
        avg_keep_nbest_models_type = checkpoint["avg_keep_nbest_models_type"]
        val_step_or_eoch = checkpoint[f"val_{avg_keep_nbest_models_type}_step_or_eoch"]
        sorted_items = sorted(val_step_or_eoch.items(), key=lambda x: x[1], reverse=True)
        sorted_items = sorted_items[:last_n] if avg_keep_nbest_models_type == "acc" else sorted_items[-last_n:]
        sorted_items = (
            sorted_items[:last_n] if avg_keep_nbest_models_type == "acc" else sorted_items[-last_n:]
        )
        checkpoint_paths = [os.path.join(output_dir, key) for key, value in sorted_items[:last_n]]
    except:
        print(f"{checkpoint} does not exist, avg the lastet checkpoint.")
@@ -36,13 +37,14 @@
        # Filter out checkpoint files and extract epoch numbers
        checkpoint_files = [f for f in files if f.startswith("model.pt.e")]
        # Sort files by epoch number in descending order
        checkpoint_files.sort(key=lambda x: int(re.search(r'(\d+)', x).group()), reverse=True)
        checkpoint_files.sort(key=lambda x: int(re.search(r"(\d+)", x).group()), reverse=True)
        # Get the last 'last_n' checkpoint paths
        checkpoint_paths = [os.path.join(output_dir, f) for f in checkpoint_files[:last_n]]
    return checkpoint_paths
@torch.no_grad()
def average_checkpoints(output_dir: str, last_n: int=5, **kwargs):
def average_checkpoints(output_dir: str, last_n: int = 5, **kwargs):
    """
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.
@@ -54,7 +56,7 @@
    # Load state_dicts from checkpoints
    for path in checkpoint_paths:
        if os.path.isfile(path):
            state_dicts.append(torch.load(path, map_location='cpu')['state_dict'])
            state_dicts.append(torch.load(path, map_location="cpu")["state_dict"])
        else:
            print(f"Checkpoint file {path} not found.")
@@ -76,5 +78,5 @@
            stacked_tensors = torch.stack(tensors)
            avg_state_dict[key] = torch.mean(stacked_tensors, dim=0)
    checkpoint_outpath = os.path.join(output_dir, f"model.pt.avg{last_n}")
    torch.save({'state_dict': avg_state_dict}, checkpoint_outpath)
    return checkpoint_outpath
    torch.save({"state_dict": avg_state_dict}, checkpoint_outpath)
    return checkpoint_outpath