From f8d1c79fe355efb18ae49e4363307dfec3ab89ce Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期一, 07 八月 2023 16:14:11 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main
---
funasr/build_utils/build_dataloader.py | 17 +++++++++++++++--
1 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
index c95c40d..473097e 100644
--- a/funasr/build_utils/build_dataloader.py
+++ b/funasr/build_utils/build_dataloader.py
@@ -4,8 +4,21 @@
def build_dataloader(args):
if args.dataset_type == "small":
- train_iter_factory = SequenceIterFactory(args, mode="train")
- valid_iter_factory = SequenceIterFactory(args, mode="valid")
+ if args.task_name == "diar" and args.model == "eend_ola":
+ from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
+ train_iter_factory = EENDOLADataLoader(
+ data_file=args.train_data_path_and_name_and_type[0][0],
+ batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+ num_workers=args.dataset_conf["num_workers"],
+ shuffle=True)
+ valid_iter_factory = EENDOLADataLoader(
+ data_file=args.valid_data_path_and_name_and_type[0][0],
+ batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+ num_workers=0,
+ shuffle=False)
+ else:
+ train_iter_factory = SequenceIterFactory(args, mode="train")
+ valid_iter_factory = SequenceIterFactory(args, mode="valid")
elif args.dataset_type == "large":
train_iter_factory = LargeDataLoader(args, mode="train")
valid_iter_factory = LargeDataLoader(args, mode="valid")
--
Gitblit v1.9.1