From f8d1c79fe355efb18ae49e4363307dfec3ab89ce Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期一, 07 八月 2023 16:14:11 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main

---
 funasr/build_utils/build_dataloader.py |   17 +++++++++++++++--
 1 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
index c95c40d..473097e 100644
--- a/funasr/build_utils/build_dataloader.py
+++ b/funasr/build_utils/build_dataloader.py
@@ -4,8 +4,21 @@
 
 def build_dataloader(args):
     if args.dataset_type == "small":
-        train_iter_factory = SequenceIterFactory(args, mode="train")
-        valid_iter_factory = SequenceIterFactory(args, mode="valid")
+        if args.task_name == "diar" and args.model == "eend_ola":
+            from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
+            train_iter_factory = EENDOLADataLoader(
+                data_file=args.train_data_path_and_name_and_type[0][0],
+                batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+                num_workers=args.dataset_conf["num_workers"],
+                shuffle=True)
+            valid_iter_factory = EENDOLADataLoader(
+                data_file=args.valid_data_path_and_name_and_type[0][0],
+                batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+                num_workers=0,
+                shuffle=False)
+        else:
+            train_iter_factory = SequenceIterFactory(args, mode="train")
+            valid_iter_factory = SequenceIterFactory(args, mode="valid")
     elif args.dataset_type == "large":
         train_iter_factory = LargeDataLoader(args, mode="train")
         valid_iter_factory = LargeDataLoader(args, mode="valid")

--
Gitblit v1.9.1