From 2ff405b2f4ab899eff9bece232969fbb0c8f0555 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 20 Jun 2023 00:26:37 +0800
Subject: [PATCH] Merge pull request #653 from alibaba-damo-academy/dev_wjm_infer
---
funasr/build_utils/build_streaming_iterator.py | 67 +++++++++++++++++++++++++++++++++
 1 file changed, 67 insertions(+), 0 deletions(-)
diff --git a/funasr/build_utils/build_streaming_iterator.py b/funasr/build_utils/build_streaming_iterator.py
new file mode 100644
index 0000000..1b16cf4
--- /dev/null
+++ b/funasr/build_utils/build_streaming_iterator.py
@@ -0,0 +1,67 @@
+import numpy as np
+from torch.utils.data import DataLoader
+from typeguard import check_argument_types
+
+from funasr.datasets.iterable_dataset import IterableESPnetDataset
+from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
+from funasr.datasets.small_datasets.preprocessor import build_preprocess
+
+
+def build_streaming_iterator(
+ task_name,
+ preprocess_args,
+ data_path_and_name_and_type,
+ key_file: str = None,
+ batch_size: int = 1,
+ fs: dict = None,
+ mc: bool = False,
+ dtype: str = np.float32,
+ num_workers: int = 1,
+ use_collate_fn: bool = True,
+ preprocess_fn=None,
+ ngpu: int = 0,
+ train: bool = False,
+) -> DataLoader:
+ """Build DataLoader using iterable dataset"""
+ assert check_argument_types()
+
+ # preprocess
+ if preprocess_fn is not None:
+ preprocess_fn = preprocess_fn
+ elif preprocess_args is not None:
+ preprocess_args.task_name = task_name
+ preprocess_fn = build_preprocess(preprocess_args, train)
+ else:
+ preprocess_fn = None
+
+ # collate
+ if not use_collate_fn:
+ collate_fn = None
+ elif task_name in ["punc", "lm"]:
+ collate_fn = CommonCollateFn(int_pad_value=0)
+ else:
+ collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
+ if collate_fn is not None:
+ kwargs = dict(collate_fn=collate_fn)
+ else:
+ kwargs = {}
+
+ dataset = IterableESPnetDataset(
+ data_path_and_name_and_type,
+ float_dtype=dtype,
+ fs=fs,
+ mc=mc,
+ preprocess=preprocess_fn,
+ key_file=key_file,
+ )
+ if dataset.apply_utt2category:
+ kwargs.update(batch_size=1)
+ else:
+ kwargs.update(batch_size=batch_size)
+
+ return DataLoader(
+ dataset=dataset,
+ pin_memory=ngpu > 0,
+ num_workers=num_workers,
+ **kwargs,
+ )
--
Gitblit v1.9.1