游雁
2024-05-17 86ada491e01691c53fd72fa76b8c77294042f938
deepspeed
1个文件已修改
4 ■■■■ 已修改文件
funasr/bin/train_ds.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train_ds.py
@@ -130,8 +130,8 @@
    model = trainer.warp_model(model)
    kwargs["device"] = next(model.parameters()).device
    trainer.device = kwargs["device"]
    kwargs["device"] = int(os.environ.get("LOCAL_RANK", 0))
    trainer.device = int(os.environ.get("LOCAL_RANK", 0))
    model, optim, scheduler = trainer.warp_optim_scheduler(model, **kwargs)