1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
| import numpy as np
| from torch.utils.data import DataLoader
|
| from funasr.datasets.iterable_dataset import IterableESPnetDataset
| from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
| from funasr.datasets.small_datasets.preprocessor import build_preprocess
|
|
def build_streaming_iterator(
    task_name,
    preprocess_args,
    data_path_and_name_and_type,
    key_file: str = None,
    batch_size: int = 1,
    fs: dict = None,
    mc: bool = False,
    dtype: str = np.float32,
    num_workers: int = 1,
    use_collate_fn: bool = True,
    preprocess_fn=None,
    ngpu: int = 0,
    train: bool = False,
) -> DataLoader:
    """Build a DataLoader over a streaming (iterable) dataset.

    Args:
        task_name: Task identifier (e.g. ``"punc"``, ``"lm"``); selects the
            collate function and is written onto ``preprocess_args``.
        preprocess_args: Config object from which a preprocess function is
            built when ``preprocess_fn`` is not supplied; its ``task_name``
            attribute is overwritten with ``task_name``.
        data_path_and_name_and_type: Data specification forwarded to
            ``IterableESPnetDataset``.
        key_file: Optional key file forwarded to the dataset.
        batch_size: Requested batch size; forced to 1 when the dataset
            applies utt2category (see below).
        fs: Sampling-frequency info forwarded to the dataset.
        mc: Multi-channel flag forwarded to the dataset.
        dtype: Float dtype passed to the dataset as ``float_dtype``.
            NOTE(review): annotated ``str`` but defaults to ``np.float32``;
            presumably the dataset accepts both — confirm against
            ``IterableESPnetDataset``.
        num_workers: Number of DataLoader worker processes.
        use_collate_fn: If False, fall back to torch's default collation.
        preprocess_fn: Pre-built preprocess callable; takes precedence over
            ``preprocess_args``.
        ngpu: Number of GPUs; ``pin_memory`` is enabled when > 0.
        train: Training-mode flag passed to ``build_preprocess``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the streaming dataset.
    """
    # An explicitly supplied preprocess_fn wins; otherwise build one from
    # the config (mutating its task_name in place). If neither is given,
    # preprocess_fn stays None. (The original's `preprocess_fn =
    # preprocess_fn` / `preprocess_fn = None` branches were no-ops.)
    if preprocess_fn is None and preprocess_args is not None:
        preprocess_args.task_name = task_name
        preprocess_fn = build_preprocess(preprocess_args, train)

    # Collation: token-level tasks (punc/lm) pad ints with 0; other tasks
    # pad floats with 0.0 and ints with -1. When use_collate_fn is False,
    # omit the key entirely so DataLoader uses its default collation —
    # CommonCollateFn(...) is always a real object, so the original's
    # `collate_fn is not None` check only ever caught that case.
    kwargs = {}
    if use_collate_fn:
        if task_name in ["punc", "lm"]:
            kwargs["collate_fn"] = CommonCollateFn(int_pad_value=0)
        else:
            kwargs["collate_fn"] = CommonCollateFn(
                float_pad_value=0.0, int_pad_value=-1
            )

    dataset = IterableESPnetDataset(
        data_path_and_name_and_type,
        float_dtype=dtype,
        fs=fs,
        mc=mc,
        preprocess=preprocess_fn,
        key_file=key_file,
    )
    # utt2category mixes heterogeneous categories per utterance; batching
    # across them is unsafe, so force batch_size=1 in that case.
    kwargs["batch_size"] = 1 if dataset.apply_utt2category else batch_size

    return DataLoader(
        dataset=dataset,
        pin_memory=ngpu > 0,  # page-locked host memory speeds H2D copies
        num_workers=num_workers,
        **kwargs,
    )
|
|