From 1c8b46a233ac4a782d7170e20533f536761e25c4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 09 六月 2024 00:21:44 +0800
Subject: [PATCH] fix bug

---
 funasr/datasets/dataloader_entry.py |   21 ++++++++++++---------
 1 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/funasr/datasets/dataloader_entry.py b/funasr/datasets/dataloader_entry.py
index 925b1d3..055e4c8 100644
--- a/funasr/datasets/dataloader_entry.py
+++ b/funasr/datasets/dataloader_entry.py
@@ -49,14 +49,19 @@
     def __init__(self, frontend=None, tokenizer=None, **kwargs):
         # dataset
         logging.info("Build dataloader")
+
         dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))
-        dataset_tr = dataset_class(
-            kwargs.get("train_data_set_list"),
-            frontend=frontend,
-            tokenizer=tokenizer,
-            is_training=True,
-            **kwargs.get("dataset_conf"),
-        )
+        dataset_tr = None
+        # split dataset
+        self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
+        if self.data_split_num == 1:
+            dataset_tr = dataset_class(
+                kwargs.get("train_data_set_list"),
+                frontend=frontend,
+                tokenizer=tokenizer,
+                is_training=True,
+                **kwargs.get("dataset_conf"),
+            )
         dataset_val = dataset_class(
             kwargs.get("valid_data_set_list"),
             frontend=frontend,
@@ -69,8 +74,6 @@
         self.dataset_val = dataset_val
         self.kwargs = kwargs
 
-        # split dataset
-        self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
         self.dataset_class = dataset_class
         self.frontend = frontend
         self.tokenizer = tokenizer

--
Gitblit v1.9.1