ds
游雁
2024-05-20 3de70601df378664905c665d327c4c9d20c81598
ds
2个文件已修改
10 ■■■■ 已修改文件
docs/images/wechat.png 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/average_nbest_models.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
docs/images/wechat.png

funasr/train_utils/average_nbest_models.py
@@ -16,7 +16,7 @@
from functools import cmp_to_key
def _get_checkpoint_paths(output_dir: str, last_n: int = 5):
def _get_checkpoint_paths(output_dir: str, last_n: int = 5, use_deepspeed=False):
    """
    Get the paths of the last 'last_n' checkpoints by parsing filenames
    in the output directory.
@@ -29,7 +29,13 @@
        sorted_items = (
            sorted_items[:last_n] if avg_keep_nbest_models_type == "acc" else sorted_items[-last_n:]
        )
        checkpoint_paths = [os.path.join(output_dir, key) for key, value in sorted_items[:last_n]]
        checkpoint_paths = []
        for key, value in sorted_items[:last_n]:
            if not use_deepspeed:
                ckpt = os.path.join(output_dir, key)
            else:
                ckpt = os.path.join(output_dir, key, "mp_rank_00_model_states.pt")
    except:
        print(f"{checkpoint} does not exist, avg the lastet checkpoint.")
        # List all files in the output directory