From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 27 Mar 2024 16:05:29 +0800
Subject: [PATCH] train update

---
 funasr/datasets/large_datasets/build_dataloader.py |    5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 134b20a..8a255f9 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -8,9 +8,10 @@
 from torch.utils.data import DataLoader
 
 from funasr.datasets.large_datasets.dataset import Dataset
-from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.datasets.large_datasets.abs_iter_factory import AbsIterFactory
 from funasr.tokenizer.abs_tokenizer import AbsTokenizer
 
+from funasr.register import tables
 
 def read_symbol_table(symbol_table_file):
     if isinstance(symbol_table_file, str):
@@ -61,7 +62,7 @@
         self._build_sentence_piece_processor()
         return self.sp.DecodePieces(list(tokens))
 
-
+@tables.register("dataset_classes", "LargeDataset")
 class LargeDataLoader(AbsIterFactory):
     def __init__(self, args, mode="train"):
         symbol_table, seg_dict, punc_dict, bpe_tokenizer = None, None, None, None

--
Gitblit v1.9.1