shixian.shi
2023-06-27 6569b950257415ea7b6c21fef013da65a43772f7
update
1个文件已修改
7 ■■■■ 已修改文件
funasr/tasks/abs_task.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/abs_task.py
@@ -1150,7 +1150,6 @@
    def main_worker(cls, args: argparse.Namespace):
        assert check_argument_types()
        args.ngpu = 0
        # 0. Init distributed process
        distributed_option = build_dataclass(DistributedOption, args)
        # Setting distributed_option.dist_rank, etc.
@@ -1253,13 +1252,9 @@
            raise RuntimeError(
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        #model = model.to(
        #    dtype=getattr(torch, args.train_dtype),
        #    device="cuda" if args.ngpu > 0 else "cpu",
        #)
        model = model.to(
            dtype=getattr(torch, args.train_dtype),
            device="cpu",
            device="cuda" if args.ngpu > 0 else "cpu",
        )
        for t in args.freeze_param:
            for k, p in model.named_parameters():