| funasr/tasks/abs_task.py | ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史 |
funasr/tasks/abs_task.py
@@ -1252,9 +1252,13 @@ raise RuntimeError( f"model must inherit {FunASRModel.__name__}, but got {type(model)}" ) #model = model.to( # dtype=getattr(torch, args.train_dtype), # device="cuda" if args.ngpu > 0 else "cpu", #) model = model.to( dtype=getattr(torch, args.train_dtype), device="cuda" if args.ngpu > 0 else "cpu", device="cpu", ) for t in args.freeze_param: for k, p in model.named_parameters():