From 012903e42ec890ab5c50137beb365c3d94e731d1 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期五, 30 六月 2023 11:21:28 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
---
funasr/build_utils/build_streaming_iterator.py | 65 ++++++++++++++++++++++++++++++++
1 files changed, 65 insertions(+), 0 deletions(-)
diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py
new file mode 100644
index 0000000..02fc263
--- /dev/null
+++ b/funasr/build_utils/build_streaming_iterator.py
@@ -0,0 +1,65 @@
+import numpy as np
+from torch.utils.data import DataLoader
+
+from funasr.datasets.iterable_dataset import IterableESPnetDataset
+from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
+from funasr.datasets.small_datasets.preprocessor import build_preprocess
+
+
+def build_streaming_iterator(
+ task_name,
+ preprocess_args,
+ data_path_and_name_and_type,
+ key_file: str = None,
+ batch_size: int = 1,
+ fs: dict = None,
+ mc: bool = False,
+ dtype: str = np.float32,
+ num_workers: int = 1,
+ use_collate_fn: bool = True,
+ preprocess_fn=None,
+ ngpu: int = 0,
+ train: bool = False,
+) -> DataLoader:
+ """Build DataLoader using iterable dataset"""
+
+ # preprocess
+ if preprocess_fn is not None:
+ preprocess_fn = preprocess_fn
+ elif preprocess_args is not None:
+ preprocess_args.task_name = task_name
+ preprocess_fn = build_preprocess(preprocess_args, train)
+ else:
+ preprocess_fn = None
+
+ # collate
+ if not use_collate_fn:
+ collate_fn = None
+ elif task_name in ["punc", "lm"]:
+ collate_fn = CommonCollateFn(int_pad_value=0)
+ else:
+ collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+ if collate_fn is not None:
+ kwargs = dict(collate_fn=collate_fn)
+ else:
+ kwargs = {}
+
+ dataset = IterableESPnetDataset(
+ data_path_and_name_and_type,
+ float_dtype=dtype,
+ fs=fs,
+ mc=mc,
+ preprocess=preprocess_fn,
+ key_file=key_file,
+ )
+ if dataset.apply_utt2category:
+ kwargs.update(batch_size=1)
+ else:
+ kwargs.update(batch_size=batch_size)
+
+ return DataLoader(
+ dataset=dataset,
+ pin_memory=ngpu > 0,
+ num_workers=num_workers,
+ **kwargs,
+ )
--
Gitblit v1.9.1