From 1c8b46a233ac4a782d7170e20533f536761e25c4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 09 六月 2024 00:21:44 +0800
Subject: [PATCH] fix bug
---
funasr/datasets/dataloader_entry.py | 21 ++++++++++++---------
1 files changed, 12 insertions(+), 9 deletions(-)
diff --git a/funasr/datasets/dataloader_entry.py b/funasr/datasets/dataloader_entry.py
index 925b1d3..055e4c8 100644
--- a/funasr/datasets/dataloader_entry.py
+++ b/funasr/datasets/dataloader_entry.py
@@ -49,14 +49,19 @@
def __init__(self, frontend=None, tokenizer=None, **kwargs):
# dataset
logging.info("Build dataloader")
+
dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
- dataset_tr = dataset_class(
- kwargs.get("train_data_set_list"),
- frontend=frontend,
- tokenizer=tokenizer,
- is_training=True,
- **kwargs.get("dataset_conf"),
- )
+ dataset_tr = None
+ # split dataset
+ self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
+ if self.data_split_num == 1:
+ dataset_tr = dataset_class(
+ kwargs.get("train_data_set_list"),
+ frontend=frontend,
+ tokenizer=tokenizer,
+ is_training=True,
+ **kwargs.get("dataset_conf"),
+ )
dataset_val = dataset_class(
kwargs.get("valid_data_set_list"),
frontend=frontend,
@@ -69,8 +74,6 @@
self.dataset_val = dataset_val
self.kwargs = kwargs
- # split dataset
- self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
self.dataset_class = dataset_class
self.frontend = frontend
self.tokenizer = tokenizer
--
Gitblit v1.9.1