| | |
| | | from functools import cmp_to_key |
| | | |
| | | |
| | | def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False): |
| | | def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs): |
| | | """ |
| | | Get the paths of the last 'last_n' checkpoints by parsing filenames |
| | | in the output directory. |
| | |
| | | Average the last 'last_n' checkpoints' model state_dicts. |
| | | If a tensor is of type torch.int, perform sum instead of average. |
| | | """ |
| | | checkpoint_paths = _get_checkpoint_paths(output_dir, last_n) |
| | | checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs) |
| | | print(f"average_checkpoints: {checkpoint_paths}") |
| | | state_dicts = [] |
| | | |