From ff0310bfb4ed69f00cbeab89a58f958ae5091d70 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 06 七月 2023 16:24:35 +0800
Subject: [PATCH] update eend_ola

---
 funasr/build_utils/build_dataloader.py |   17 +++++++++++++++--
 1 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
index c95c40d..473097e 100644
--- a/funasr/build_utils/build_dataloader.py
+++ b/funasr/build_utils/build_dataloader.py
@@ -4,8 +4,21 @@
 
 def build_dataloader(args):
     if args.dataset_type == "small":
-        train_iter_factory = SequenceIterFactory(args, mode="train")
-        valid_iter_factory = SequenceIterFactory(args, mode="valid")
+        if args.task_name == "diar" and args.model == "eend_ola":
+            from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
+            train_iter_factory = EENDOLADataLoader(
+                data_file=args.train_data_path_and_name_and_type[0][0],
+                batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+                num_workers=args.dataset_conf["num_workers"],
+                shuffle=True)
+            valid_iter_factory = EENDOLADataLoader(
+                data_file=args.valid_data_path_and_name_and_type[0][0],
+                batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+                num_workers=0,
+                shuffle=False)
+        else:
+            train_iter_factory = SequenceIterFactory(args, mode="train")
+            valid_iter_factory = SequenceIterFactory(args, mode="valid")
     elif args.dataset_type == "large":
         train_iter_factory = LargeDataLoader(args, mode="train")
         valid_iter_factory = LargeDataLoader(args, mode="valid")

--
Gitblit v1.9.1