雾聪
2023-08-07 f8d1c79fe355efb18ae49e4363307dfec3ab89ce
funasr/build_utils/build_dataloader.py
@@ -4,8 +4,21 @@
def build_dataloader(args):
    """Build the training and validation iterator factories described by ``args``.

    The loader implementation is selected from ``args.dataset_type``:

    * ``"small"``: an ``EENDOLADataLoader`` pair when ``args.task_name == "diar"``
      and ``args.model == "eend_ola"``; otherwise a ``SequenceIterFactory`` pair
      built with ``mode="train"`` / ``mode="valid"``.
    * ``"large"``: a ``LargeDataLoader`` pair built with ``mode="train"`` /
      ``mode="valid"``.

    Args:
        args: Parsed training configuration. Always reads ``dataset_type``.
            On the ``"small"`` path it also reads ``task_name`` and ``model``.
            The EEND-OLA branch additionally reads
            ``train_data_path_and_name_and_type``,
            ``valid_data_path_and_name_and_type``,
            ``dataset_conf["batch_conf"]["batch_size"]`` and
            ``dataset_conf["num_workers"]``.
    """
    if args.dataset_type == "small":
        # Construct each split's factory exactly once, inside the branch that
        # needs it, instead of building a default pair that gets thrown away.
        if args.task_name == "diar" and args.model == "eend_ola":
            # Imported lazily so the EEND-OLA module is only loaded when used.
            from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader

            batch_size = args.dataset_conf["batch_conf"]["batch_size"]
            # [0][0] takes the first field of the first configured entry;
            # presumably the data path of a (path, name, type) triple — confirm.
            train_iter_factory = EENDOLADataLoader(
                data_file=args.train_data_path_and_name_and_type[0][0],
                batch_size=batch_size,
                num_workers=args.dataset_conf["num_workers"],
                shuffle=True)
            # Validation is read in a fixed order (shuffle=False) and with
            # num_workers=0 rather than the configured worker count.
            valid_iter_factory = EENDOLADataLoader(
                data_file=args.valid_data_path_and_name_and_type[0][0],
                batch_size=batch_size,
                num_workers=0,
                shuffle=False)
        else:
            train_iter_factory = SequenceIterFactory(args, mode="train")
            valid_iter_factory = SequenceIterFactory(args, mode="valid")
    elif args.dataset_type == "large":
        train_iter_factory = LargeDataLoader(args, mode="train")
        valid_iter_factory = LargeDataLoader(args, mode="valid")