From 9be8a443d74d68f179de88fff13b4e8424579d7b Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 10 三月 2023 18:24:39 +0800
Subject: [PATCH] Merge pull request #207 from alibaba-damo-academy/dev_dzh

---
 funasr/tasks/abs_task.py |    6 +++++-
 1 files changed, 5 insertions(+), 1 deletions(-)

diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index d2a00b2..723a67c 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -1348,11 +1348,13 @@
             if args.dataset_type == "large":
                 from funasr.datasets.large_datasets.build_dataloader import ArkDataLoader
                 train_iter_factory = ArkDataLoader(args.train_data_file, args.token_list, args.dataset_conf,
+                                                   frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
                                                    seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                                "seg_dict_file") else None,
                                                    punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
                                                    mode="train")
-                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf,
+                valid_iter_factory = ArkDataLoader(args.valid_data_file, args.token_list, args.dataset_conf, 
+                                                   frontend_conf=args.frontend_conf if hasattr(args, "frontend_conf") else None,
                                                    seg_dict_file=args.seg_dict_file if hasattr(args,
                                                                                                "seg_dict_file") else None,
                                                    punc_dict_file=args.punc_list if hasattr(args, "punc_list") else None,
@@ -1845,6 +1847,7 @@
             key_file: str = None,
             batch_size: int = 1,
             fs: dict = None,
+            mc: bool = False,
             dtype: str = np.float32,
             num_workers: int = 1,
             allow_variable_data_keys: bool = False,
@@ -1863,6 +1866,7 @@
             data_path_and_name_and_type,
             float_dtype=dtype,
             fs=fs,
+            mc=mc,
             preprocess=preprocess_fn,
             key_file=key_file,
         )

--
Gitblit v1.9.1