| | |
| | | from funasr.utils.types import str2triple_str |
| | | from funasr.utils.types import str_or_int |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.utils.wav_utils import calc_shape, generate_data_list |
| | | from funasr.utils.wav_utils import calc_shape, generate_data_list, filter_wav_text |
| | | from funasr.utils.yaml_no_alias_safe_dump import yaml_no_alias_safe_dump |
| | | |
| | | try:
| | |
# NOTE(review): mangled extraction — every line carries a "| | |" prefix and
# the original indentation is lost; the except/finally matching the `try:`
# above is outside this view.  Comments below describe apparent intent only.
# Scale the dynamic-batching budget by GPU count so each worker keeps the
# same per-GPU load — presumably args.ngpu >= 1; TODO confirm (ngpu == 0
# would zero out batch_bins).
| | | if args.batch_bins is not None:
| | | args.batch_bins = args.batch_bins * args.ngpu
| | |
| | | # filter samples if wav.scp and text are mismatch
# Under simple DDP only rank 0 rewrites the data lists; dist.barrier() below
# then holds the other ranks until the filtered files exist.
# NOTE(review): `or` binds looser than `and`, so the condition groups as
# (small-mode test) or (large-mode test) — apparently intended, but worth
# parenthesizing the second clause upstream for clarity.
| | | if (args.train_shape_file is None and args.dataset_type == "small") or args.train_data_file is None and args.dataset_type == "large":
| | | if not args.simple_ddp or distributed_option.dist_rank == 0:
| | | filter_wav_text(args.data_dir, args.train_set)
| | | filter_wav_text(args.data_dir, args.dev_set)
| | | if args.simple_ddp:
| | | dist.barrier()
| | |
# Pre-compute feature shapes for the small-dataset pipeline when no shape
# file was supplied; again restricted to rank 0 under simple DDP.
| | | if args.train_shape_file is None and args.dataset_type == "small":
| | | if not args.simple_ddp or distributed_option.dist_rank == 0:
| | | calc_shape(args.data_dir, args.train_set, args.frontend_conf, args.speech_length_min, args.speech_length_max)
| | |
# Build the train/valid data iterators for the "large" dataset mode.
# NOTE(review): the original span was a mangled extraction ("| | |" prefixes,
# lost indentation) and contained a duplicated
# `valid_iter_factory = ArkDataLoader(` line that left an unbalanced paren —
# the duplicate is removed here and the block reformatted at a best-guess
# indent matching the surrounding fragment.
if args.dataset_type == "large":
    from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader

    # Optional settings are forwarded only when present on `args`, so older
    # configs without frontend / seg-dict / punctuation options keep working;
    # getattr(..., None) is equivalent to the original hasattr()-guarded form.
    train_iter_factory = ArkDataLoader(
        args.train_data_file,
        args.token_list,
        args.dataset_conf,
        frontend_conf=getattr(args, "frontend_conf", None),
        seg_dict_file=getattr(args, "seg_dict_file", None),
        punc_dict_file=getattr(args, "punc_list", None),
        mode="train",
    )
    valid_iter_factory = ArkDataLoader(
        args.valid_data_file,
        args.token_list,
        args.dataset_conf,
        frontend_conf=getattr(args, "frontend_conf", None),
        seg_dict_file=getattr(args, "seg_dict_file", None),
        punc_dict_file=getattr(args, "punc_list", None),
        mode="eval",
    )
# NOTE(review): truncated fragment — the leading arguments of the
# build_iter_factory call are missing (blank line below the call opener) and
# the check_task_requirements call is never closed; the complete statements
# live outside this view.
| | | elif args.dataset_type == "small":
| | | train_iter_factory = cls.build_iter_factory(
| | |
# Visible trailing kwargs of the build_iter_factory call: preprocessing hook,
# cache limits, and the target sample rate taken from the frontend config
# ("fs" key — presumably Hz; TODO confirm).
| | | preprocess=iter_options.preprocess_fn,
| | | max_cache_size=iter_options.max_cache_size,
| | | max_cache_fd=iter_options.max_cache_fd,
| | | dest_sample_rate=args.frontend_conf["fs"],
| | | )
# Verify the dataset exposes the data keys this task requires (closing
# parenthesis of this call is outside the visible chunk).
| | | cls.check_task_requirements(
| | | dataset, args.allow_variable_data_keys, train=iter_options.train
| | |
# NOTE(review): headless fragment — these look like keyword parameters of a
# dataloader-building method whose `def` line is outside this view, followed
# by part of a call forwarding them.  Annotation issues to fix upstream:
# `key_file: str = None` and `fs: dict = None` should be Optional[...], and
# `dtype: str = np.float32` annotates str but defaults to a numpy dtype.
| | | key_file: str = None,
| | | batch_size: int = 1,
| | | fs: dict = None,
| | | mc: bool = False,
| | | dtype: str = np.float32,
| | | num_workers: int = 1,
| | | allow_variable_data_keys: bool = False,
| | |
# Call fragment: forwards the dataset spec plus the parameters above
# (the callee's name is outside the visible chunk).
| | | data_path_and_name_and_type,
| | | float_dtype=dtype,
| | | fs=fs,
| | | mc=mc,
| | | preprocess=preprocess_fn,
| | | key_file=key_file,
| | | )