| | |
| | | |
| | | def build_dataloader(args): |
| | | if args.dataset_type == "small": |
| | | train_iter_factory = SequenceIterFactory(args, mode="train") |
| | | valid_iter_factory = SequenceIterFactory(args, mode="valid") |
| | | if args.task_name == "diar" and args.model == "eend_ola": |
| | | from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader |
| | | train_iter_factory = EENDOLADataLoader( |
| | | data_file=args.train_data_path_and_name_and_type[0][0], |
| | | batch_size=args.dataset_conf["batch_conf"]["batch_size"], |
| | | num_workers=args.dataset_conf["num_workers"], |
| | | shuffle=True) |
| | | valid_iter_factory = EENDOLADataLoader( |
| | | data_file=args.valid_data_path_and_name_and_type[0][0], |
| | | batch_size=args.dataset_conf["batch_conf"]["batch_size"], |
| | | num_workers=0, |
| | | shuffle=False) |
| | | else: |
| | | train_iter_factory = SequenceIterFactory(args, mode="train") |
| | | valid_iter_factory = SequenceIterFactory(args, mode="valid") |
| | | elif args.dataset_type == "large": |
| | | train_iter_factory = LargeDataLoader(args, mode="train") |
| | | valid_iter_factory = LargeDataLoader(args, mode="valid") |