speech_asr
2023-03-14 4d2bf9fe3cc385b441e94c3000b34a44cac8a8db
funasr/tasks/abs_task.py
@@ -1348,11 +1348,13 @@
            if args.dataset_type == "large":
                from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
                train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
                                                   frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                               "seg_dict_file") else None,
                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
                                                   mode="train")
                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
                                                   frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
                                                   seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                               "seg_dict_file") else None,
                                                   punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
@@ -1574,6 +1576,7 @@
            preprocess=iter_options.preprocess_fn,
            max_cache_size=iter_options.max_cache_size,
            max_cache_fd=iter_options.max_cache_fd,
            dest_sample_rate=args.frontend_conf["fs"],
        )
        cls.check_task_requirements(
            dataset, args.allow_variable_data_keys, train=iter_options.train
@@ -1845,6 +1848,7 @@
            key_file: str = None,
            batch_size: int = 1,
            fs: dict = None,
            mc: bool = False,
            dtype: str = np.float32,
            num_workers: int = 1,
            allow_variable_data_keys: bool = False,
@@ -1863,6 +1867,7 @@
            data_path_and_name_and_type,
            float_dtype=dtype,
            fs=fs,
            mc=mc,
            preprocess=preprocess_fn,
            key_file=key_file,
        )