From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 27 Mar 2024 16:05:29 +0800
Subject: [PATCH] train update
---
funasr/datasets/large_datasets/build_dataloader.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/funasr/datasets/large_datasets/build_dataloader.py b/funasr/datasets/large_datasets/build_dataloader.py
index 134b20a..8a255f9 100644
--- a/funasr/datasets/large_datasets/build_dataloader.py
+++ b/funasr/datasets/large_datasets/build_dataloader.py
@@ -8,9 +8,10 @@
from torch.utils.data import DataLoader
from funasr.datasets.large_datasets.dataset import Dataset
-from funasr.iterators.abs_iter_factory import AbsIterFactory
+from funasr.datasets.large_datasets.abs_iter_factory import AbsIterFactory
from funasr.tokenizer.abs_tokenizer import AbsTokenizer
+from funasr.register import tables
def read_symbol_table(symbol_table_file):
if isinstance(symbol_table_file, str):
@@ -61,7 +62,7 @@
self._build_sentence_piece_processor()
return self.sp.DecodePieces(list(tokens))
-
+@tables.register("dataset_classes", "LargeDataset")
class LargeDataLoader(AbsIterFactory):
def __init__(self, args, mode="train"):
symbol_table, seg_dict, punc_dict, bpe_tokenizer = None, None, None, None
--
Gitblit v1.9.1