From 9fcb3cc06b4e324f0913d2f61b89becc2baeef1b Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 11 九月 2023 17:40:03 +0800
Subject: [PATCH] Merge pull request #932 from alibaba-damo-academy/dev_lhn
---
funasr/build_utils/build_dataloader.py | 17 +++++++++++++++--
1 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
index c95c40d..473097e 100644
--- a/funasr/build_utils/build_dataloader.py
+++ b/funasr/build_utils/build_dataloader.py
@@ -4,8 +4,21 @@
def build_dataloader(args):
if args.dataset_type == "small":
- train_iter_factory = SequenceIterFactory(args, mode="train")
- valid_iter_factory = SequenceIterFactory(args, mode="valid")
+ if args.task_name == "diar" and args.model == "eend_ola":
+ from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
+ train_iter_factory = EENDOLADataLoader(
+ data_file=args.train_data_path_and_name_and_type[0][0],
+ batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+ num_workers=args.dataset_conf["num_workers"],
+ shuffle=True)
+ valid_iter_factory = EENDOLADataLoader(
+ data_file=args.valid_data_path_and_name_and_type[0][0],
+ batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+ num_workers=0,
+ shuffle=False)
+ else:
+ train_iter_factory = SequenceIterFactory(args, mode="train")
+ valid_iter_factory = SequenceIterFactory(args, mode="valid")
elif args.dataset_type == "large":
train_iter_factory = LargeDataLoader(args, mode="train")
valid_iter_factory = LargeDataLoader(args, mode="valid")
--
Gitblit v1.9.1