1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
| import numpy as np
| from torch.utils.data import DataLoader
| from typeguard import check_argument_types
|
| from funasr.datasets.iterable_dataset import IterableESPnetDataset
| from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
| from funasr.datasets.small_datasets.preprocessor import build_preprocess
|
|
def build_streaming_iterator(
    task_name,
    preprocess_args,
    data_path_and_name_and_type,
    key_file: str = None,
    batch_size: int = 1,
    fs: dict = None,
    mc: bool = False,
    dtype: str = "float32",
    num_workers: int = 1,
    ngpu: int = 0,
    train: bool = False,
) -> DataLoader:
    """Build a DataLoader on top of an ``IterableESPnetDataset``.

    Args:
        task_name: Selects the padding used by the collate function.
            ``"punc"`` and ``"lm"`` pad integers with 0; every other task
            pads floats with 0.0 and integers with -1.
        preprocess_args: Forwarded to ``build_preprocess``.
        data_path_and_name_and_type: Forwarded to ``IterableESPnetDataset``
            as its first positional argument.
        key_file: Forwarded to the dataset as ``key_file``.
        batch_size: Batch size given to the ``DataLoader``. Ignored (forced
            to 1) when the dataset reports ``apply_utt2category``.
        fs: Forwarded to the dataset as ``fs``.
        mc: Forwarded to the dataset as ``mc``.
        dtype: Forwarded to the dataset as ``float_dtype``. The default is
            the string ``"float32"`` so that it satisfies the ``str``
            annotation checked by ``check_argument_types()``.
        num_workers: Number of ``DataLoader`` worker processes.
        ngpu: Pinned memory is enabled when this is greater than 0.
        train: Forwarded to ``build_preprocess`` as its second argument.

    Returns:
        A ``DataLoader`` wrapping the iterable dataset.
    """
    # NOTE: the previous default was the ``np.float32`` class, which is not a
    # ``str``; typeguard validates defaulted arguments as well, so relying on
    # the default raised TypeError here.
    assert check_argument_types()

    # preprocess
    preprocess_fn = build_preprocess(preprocess_args, train)

    # collate
    if task_name in ("punc", "lm"):
        collate_fn = CommonCollateFn(int_pad_value=0)
    else:
        collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

    dataset = IterableESPnetDataset(
        data_path_and_name_and_type,
        float_dtype=dtype,
        fs=fs,
        mc=mc,
        preprocess=preprocess_fn,
        key_file=key_file,
    )

    # When the dataset applies utt2category the batch size is pinned to 1
    # regardless of what the caller asked for.
    effective_batch_size = 1 if dataset.apply_utt2category else batch_size

    return DataLoader(
        dataset=dataset,
        pin_memory=ngpu > 0,
        num_workers=num_workers,
        collate_fn=collate_fn,
        batch_size=effective_batch_size,
    )
|
|