shixian.shi
2023-06-27 25de54910e7b48552bdba2dd655fbcd64a07668e
funasr/tasks/abs_task.py
@@ -1150,6 +1150,7 @@
     def main_worker(cls, args: argparse.Namespace):
         assert check_argument_types()
+        args.ngpu = 0
         # 0. Init distributed process
         distributed_option = build_dataclass(DistributedOption, args)
         # Setting distributed_option.dist_rank, etc.
@@ -1252,9 +1253,13 @@
             raise RuntimeError(
                 f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
+        #model = model.to(
+        #    dtype=getattr(torch, args.train_dtype),
+        #    device="cuda" if args.ngpu > 0 else "cpu",
+        #)
         model = model.to(
             dtype=getattr(torch, args.train_dtype),
-            device="cuda" if args.ngpu > 0 else "cpu",
+            device="cpu",
         )
         for t in args.freeze_param:
             for k, p in model.named_parameters():
@@ -1376,25 +1381,10 @@
             # 7. Build iterator factories
             if args.dataset_type == "large":
                 from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
-                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
-                                                   frontend_conf=args.frontend_conf if hasattr(args,
-                                                                                               "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
-                                                                                               "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args,
-                                                                                            "punc_list") else None,
-                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
-                                                   mode="train")
-                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
-                                                   frontend_conf=args.frontend_conf if hasattr(args,
-                                                                                               "frontend_conf") else None,
-                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
-                                                                                               "seg_dict_file") else None,
-                                                   punc_dict_file=args.punc_list if hasattr(args,
-                                                                                            "punc_list") else None,
-                                                   bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None,
-                                                   mode="eval")
+                from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader
+                train_iter_factory = LargeDataLoader(args, mode="train")
+                valid_iter_factory = LargeDataLoader(args, mode="eval")
             elif args.dataset_type == "small":
                 train_iter_factory = cls.build_iter_factory(
                     args=args,