ds
游雁
2024-05-20 bbd300a91184d09dec56a3b06734051dbd5812e4
ds
2个文件已修改
8 ■■■■■ 已修改文件
funasr/bin/train_ds.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/average_nbest_models.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train_ds.py
@@ -198,7 +198,9 @@
        trainer.train_loss_avg = 0.0
    if trainer.rank == 0:
-        average_checkpoints(trainer.output_dir, trainer.avg_nbest_model)
+        average_checkpoints(
+            trainer.output_dir, trainer.avg_nbest_model, use_deepspeed=trainer.use_deepspeed
+        )
    trainer.close()
funasr/train_utils/average_nbest_models.py
@@ -16,7 +16,7 @@
from functools import cmp_to_key
-def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False):
+def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False, **kwargs):
    """
    Get the paths of the last 'last_n' checkpoints by parsing filenames
    in the output directory.
@@ -55,7 +55,7 @@
    Average the last 'last_n' checkpoints' model state_dicts.
    If a tensor is of type torch.int, perform sum instead of average.
    """
-    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n)
+    checkpoint_paths = _get_checkpoint_paths(output_dir, last_n, **kwargs)
    print(f"average_checkpoints: {checkpoint_paths}")
    state_dicts = []