From 012903e42ec890ab5c50137beb365c3d94e731d1 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: Fri, 30 Jun 2023 11:21:28 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR

---
 funasr/build_utils/build_streaming_iterator.py |   65 ++++++++++++++++++++++++++++++++
 1 files changed, 65 insertions(+), 0 deletions(-)

diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py
new file mode 100644
index 0000000..02fc263
--- /dev/null
+++ b/funasr/build_utils/build_streaming_iterator.py
@@ -0,0 +1,65 @@
+import numpy as np
+from torch.utils.data import DataLoader
+
+from funasr.datasets.iterable_dataset import IterableESPnetDataset
+from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
+from funasr.datasets.small_datasets.preprocessor import build_preprocess
+
+
+def build_streaming_iterator(
+        task_name,
+        preprocess_args,
+        data_path_and_name_and_type,
+        key_file: str = None,
+        batch_size: int = 1,
+        fs: dict = None,
+        mc: bool = False,
+        dtype: str = np.float32,
+        num_workers: int = 1,
+        use_collate_fn: bool = True,
+        preprocess_fn=None,
+        ngpu: int = 0,
+        train: bool = False,
+) -> DataLoader:
+    """Create a ``DataLoader`` backed by an ``IterableESPnetDataset``.
+
+    Args:
+        task_name: Task identifier; "punc" and "lm" use ``int_pad_value=0``.
+        preprocess_args: Settings for ``build_preprocess``, used only when
+            ``preprocess_fn`` is None; ``task_name`` is written onto it.
+        data_path_and_name_and_type: Forwarded to the dataset unchanged.
+        key_file, fs, mc: Forwarded to the dataset unchanged.
+        batch_size: Loader batch size; replaced by 1 when the dataset
+            reports ``apply_utt2category``.
+        dtype: Forwarded to the dataset as ``float_dtype``.
+        num_workers: Worker count for the ``DataLoader``.
+        use_collate_fn: When false, no ``collate_fn`` is given to the loader.
+        preprocess_fn: Ready-made preprocessor; overrides ``preprocess_args``.
+        ngpu: Memory pinning is enabled when this is positive.
+        train: Forwarded to ``build_preprocess``.
+
+    Returns:
+        The configured ``DataLoader``.
+    """
+    # A caller-supplied preprocess_fn wins; otherwise derive one from the args.
+    if preprocess_fn is None and preprocess_args is not None:
+        preprocess_args.task_name = task_name
+        preprocess_fn = build_preprocess(preprocess_args, train)
+
+    loader_options = {}
+    if use_collate_fn:
+        if task_name in ["punc", "lm"]:
+            pad_values = dict(int_pad_value=0)
+        else:
+            pad_values = dict(float_pad_value=0.0, int_pad_value=-1)
+        loader_options["collate_fn"] = CommonCollateFn(**pad_values)
+
+    dataset = IterableESPnetDataset(
+        data_path_and_name_and_type, float_dtype=dtype, fs=fs, mc=mc,
+        preprocess=preprocess_fn, key_file=key_file,
+    )
+    # Datasets that apply utt2category are always loaded one item per batch.
+    loader_options["batch_size"] = 1 if dataset.apply_utt2category else batch_size
+    return DataLoader(
+        dataset=dataset, pin_memory=ngpu > 0, num_workers=num_workers, **loader_options
+    )

--
Gitblit v1.9.1