| | |
| | | def main_worker(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | |
| | | args.ngpu = 0 |
| | | # 0. Init distributed process |
| | | distributed_option = build_dataclass(DistributedOption, args) |
| | | # Setting distributed_option.dist_rank, etc. |
| | |
| | | raise RuntimeError( |
| | | f"model must inherit {FunASRModel.__name__}, but got {type(model)}" |
| | | ) |
| | | #model = model.to( |
| | | # dtype=getattr(torch, args.train_dtype), |
| | | # device="cuda" if args.ngpu > 0 else "cpu", |
| | | #) |
| | | model = model.to( |
| | | dtype=getattr(torch, args.train_dtype), |
| | | device="cuda" if args.ngpu > 0 else "cpu", |
| | | device="cpu", |
| | | ) |
| | | for t in args.freeze_param: |
| | | for k, p in model.named_parameters(): |
| | |
| | | |
| | | # 7. Build iterator factories |
| | | if args.dataset_type == "large": |
| | | from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader |
| | | train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, |
| | | "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, |
| | | "punc_list") else None, |
| | | bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, |
| | | mode="train") |
| | | valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, |
| | | frontend_conf=args.frontend_conf if hasattr(args, |
| | | "frontend_conf") else None, |
| | | seg_dict_file=args.seg_dict_file if hasattr(args, |
| | | "seg_dict_file") else None, |
| | | punc_dict_file=args.punc_list if hasattr(args, |
| | | "punc_list") else None, |
| | | bpemodel_file=args.bpemodel if hasattr(args, "bpemodel") else None, |
| | | mode="eval") |
| | | from funasr.datasets.large_datasets.build_dataloader import LargeDataLoader |
| | | train_iter_factory = LargeDataLoader(args, mode="train") |
| | | valid_iter_factory = LargeDataLoader(args, mode="eval") |
| | | |
| | | elif args.dataset_type == "small": |
| | | train_iter_factory = cls.build_iter_factory( |
| | | args=args, |