| | |
| | | trainer.train_loss_avg = 0.0 |
| | | |
| | | if trainer.rank == 0: |
| | | average_checkpoints(trainer.output_dir, trainer.avg_nbest_model) |
| | | average_checkpoints( |
| | | trainer.output_dir, trainer.avg_nbest_model, use_deepspeed=trainer.use_deepspeed |
| | | ) |
| | | |
| | | trainer.close() |
| | | |
| | |
| | | from functools import cmp_to_key |
| | | |
| | | |
| | | def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False): |
| | | def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs): |
| | | """ |
| | | Get the paths of the last 'last_n' checkpoints by parsing filenames |
| | | in the output directory. |
| | |
| | | Average the last 'last_n' checkpoints' model state_dicts. |
| | | If a tensor is of type torch.int, perform sum instead of average. |
| | | """ |
| | | checkpoint_paths = _get_checkpoint_paths(output_dir, last_n) |
| | | checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs) |
| | | print(f"average_checkpoints: {checkpoint_paths}") |
| | | state_dicts = [] |
| | | |