shixian.shi
2023-06-27 6f08e9040825fb120abb2d9f386bd91bbaea3f80
debug
1个文件已修改
6 ■■■■ 已修改文件
funasr/tasks/abs_task.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/abs_task.py
@@ -1252,9 +1252,13 @@
            raise RuntimeError(
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        #model = model.to(
        #    dtype=getattr(torch, args.train_dtype),
        #    device="cuda" if args.ngpu > 0 else "cpu",
        #)
        model = model.to(
            dtype=getattr(torch, args.train_dtype),
            device="cuda" if args.ngpu > 0 else "cpu",
            device="cpu",
        )
        for t in args.freeze_param:
            for k, p in model.named_parameters():