From ff0310bfb4ed69f00cbeab89a58f958ae5091d70 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 06 七月 2023 16:24:35 +0800
Subject: [PATCH] update eend_ola
---
funasr/build_utils/build_dataloader.py | 17 +++++++++++++++--
1 files changed, 15 insertions(+), 2 deletions(-)
diff --git a/funasr/build_utils/build_dataloader.py b/funasr/build_utils/build_dataloader.py
index c95c40d..473097e 100644
--- a/funasr/build_utils/build_dataloader.py
+++ b/funasr/build_utils/build_dataloader.py
@@ -4,8 +4,21 @@
def build_dataloader(args):
if args.dataset_type == "small":
- train_iter_factory = SequenceIterFactory(args, mode="train")
- valid_iter_factory = SequenceIterFactory(args, mode="valid")
+ if args.task_name == "diar" and args.model == "eend_ola":
+ from funasr.modules.eend_ola.eend_ola_dataloader import EENDOLADataLoader
+ train_iter_factory = EENDOLADataLoader(
+ data_file=args.train_data_path_and_name_and_type[0][0],
+ batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+ num_workers=args.dataset_conf["num_workers"],
+ shuffle=True)
+ valid_iter_factory = EENDOLADataLoader(
+ data_file=args.valid_data_path_and_name_and_type[0][0],
+ batch_size=args.dataset_conf["batch_conf"]["batch_size"],
+ num_workers=0,
+ shuffle=False)
+ else:
+ train_iter_factory = SequenceIterFactory(args, mode="train")
+ valid_iter_factory = SequenceIterFactory(args, mode="valid")
elif args.dataset_type == "large":
train_iter_factory = LargeDataLoader(args, mode="train")
valid_iter_factory = LargeDataLoader(args, mode="valid")
--
Gitblit v1.9.1