嘉渊
2023-04-24 189b51d42bd29032091f1e29ae5585eb52c0af57
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import logging
 
import numpy as np
import torch
from torch.utils.data import DataLoader
 
from funasr.datasets.small_datasets.collate_fn import CommonCollateFn
from funasr.datasets.small_datasets.dataset import ESPnetDataset
from funasr.datasets.small_datasets.length_batch_sampler import LengthBatchSampler
from funasr.datasets.small_datasets.preprocessor import build_preprocess
from funasr.iterators.abs_iter_factory import AbsIterFactory
from funasr.samplers.abs_sampler import AbsSampler
 
 
class RawSampler(AbsSampler):
    """Minimal AbsSampler adapter around an already-materialized list of batches.

    Used when the batches have been computed eagerly (e.g. after splitting
    across distributed ranks) but downstream code expects the sampler API.
    """

    def __init__(self, batches):
        self.batches = batches

    def __len__(self):
        return len(self.batches)

    def __iter__(self):
        # Yield the pre-built batches in their stored order.
        yield from self.batches

    def generate(self, seed):
        # The batch list is fixed; the seed is accepted for API compatibility
        # but has no effect.  Return a fresh copy so callers may shuffle it.
        return [batch for batch in self.batches]
 
 
class SequenceIterFactory(AbsIterFactory):
    """Build a DataLoader for each epoch, modified from ESPnet.

    Constructs the dataset, collate function and length-based batch sampler
    from ``args``, then serves a freshly seeded/shuffled DataLoader per epoch
    via :meth:`build_iter`.  Shuffling is deterministic in ``epoch + seed`` so
    training is reproducible and resumable.
    """

    def __init__(self, args, mode="train"):
        """Set up dataset, collate_fn and batch sampler.

        Args:
            args: parsed training arguments; must provide task_name,
                frontend_conf, dataset_conf, shape files, ngpu, seed,
                num_workers, distributed flags, etc.
            mode: "train" or "valid"; anything else raises NotImplementedError.
        """
        # Per-sample preprocessing (augmentation is enabled only for training).
        preprocess_fn = build_preprocess(args, train=mode == "train")

        # Collate: text-only tasks pad ints with 0; speech tasks pad features
        # with 0.0 and token ids with -1 (ignore index).
        if args.task_name in ["punc", "lm"]:
            collate_fn = CommonCollateFn(int_pad_value=0)
        else:
            collate_fn = CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)

        # Dataset: resample audio to the frontend sample rate, default 16 kHz.
        dest_sample_rate = args.frontend_conf["fs"] if (
                args.frontend_conf is not None and "fs" in args.frontend_conf) else 16000
        if mode == "train":
            data_path_and_name_and_type = args.train_data_path_and_name_and_type
            shape_files = args.train_shape_file
        elif mode == "valid":
            data_path_and_name_and_type = args.valid_data_path_and_name_and_type
            shape_files = args.valid_shape_file
        else:
            raise NotImplementedError(f"mode={mode}")
        dataset = ESPnetDataset(
            data_path_and_name_and_type,
            preprocess=preprocess_fn,
            dest_sample_rate=dest_sample_rate,
        )

        # Sampler.
        dataset_conf = args.dataset_conf
        # BUG FIX: dataset_conf is a dict (it is subscripted with
        # ["batch_size"] below), so the previous hasattr(dataset_conf, key)
        # checks were always False and any configured sort_in_batch /
        # sort_batch values were silently ignored.  dict.get honors them.
        batch_sampler = LengthBatchSampler(
            batch_bins=dataset_conf["batch_size"] * args.ngpu,
            shape_files=shape_files,
            sort_in_batch=dataset_conf.get("sort_in_batch", "descending"),
            sort_batch=dataset_conf.get("sort_batch", "ascending"),
            drop_last=False,
            padding=True,
        )

        batches = list(batch_sampler)
        bs_list = [len(batch) for batch in batches]
        logging.info(f"[{mode}] dataset:\n{dataset}")
        logging.info(f"[{mode}] Batch sampler: {batch_sampler}")
        logging.info(
            f"[{mode}] mini-batch sizes summary: N-batch={len(bs_list)}, "
            f"mean={np.mean(bs_list):.1f}, min={np.min(bs_list)}, max={np.max(bs_list)}"
        )

        # tri_stage LR scheduling needs the total update count up front.
        if args.scheduler == "tri_stage" and mode == "train":
            args.max_update = len(bs_list) * args.max_epoch
            logging.info("Max update: {}".format(args.max_update))

        if args.distributed:
            world_size = torch.distributed.get_world_size()
            rank = torch.distributed.get_rank()
            for batch in batches:
                if len(batch) < world_size:
                    raise RuntimeError(
                        f"The batch-size must be equal or more than world_size: "
                        f"{len(batch)} < {world_size}"
                    )
            # Strided split: each rank keeps every world_size-th sample.
            batches = [batch[rank::world_size] for batch in batches]

        if not isinstance(batches, AbsSampler):
            self.sampler = RawSampler(batches)
        else:
            self.sampler = batches

        self.dataset = dataset
        # When set (externally), an "epoch" is a fixed number of mini-batches
        # instead of one full pass over the corpus.
        self.num_iters_per_epoch = None
        self.shuffle = mode == "train"
        self.seed = args.seed
        self.num_workers = args.num_workers
        self.collate_fn = collate_fn
        # Page-locked host memory speeds up H2D copies; only useful with GPUs.
        self.pin_memory = args.ngpu > 0

    def build_iter(self, epoch: int, shuffle: bool = None) -> DataLoader:
        """Return a DataLoader for the given epoch.

        Args:
            epoch: 1-based epoch index; drives the deterministic shuffle seed.
            shuffle: override the instance default (train=True, valid=False).

        Returns:
            torch DataLoader iterating the selected mini-batches.
        """
        if shuffle is None:
            shuffle = self.shuffle

        if self.num_iters_per_epoch is not None:
            N = len(self.sampler)
            # If corpus size is larger than the num_per_epoch: take a sliding
            # window of num_iters_per_epoch batches over the (re-shuffled)
            # corpus, so successive epochs walk through all data.
            if self.num_iters_per_epoch < N:
                real_epoch, offset = divmod(self.num_iters_per_epoch * epoch, N)

                if offset >= self.num_iters_per_epoch:
                    # Window lies entirely within one corpus pass.
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = current_batches[
                              offset - self.num_iters_per_epoch: offset
                              ]
                else:
                    # Window straddles two corpus passes: stitch the tail of
                    # the previous pass to the head of the current one.
                    prev_batches = self.sampler.generate(real_epoch - 1 + self.seed)
                    current_batches = self.sampler.generate(real_epoch + self.seed)
                    if shuffle:
                        np.random.RandomState(real_epoch - 1 + self.seed).shuffle(
                            prev_batches
                        )
                        np.random.RandomState(real_epoch + self.seed).shuffle(
                            current_batches
                        )
                    batches = (
                            prev_batches[offset - self.num_iters_per_epoch:]
                            + current_batches[:offset]
                    )

            # If corpus size is less than the num_per_epoch: wrap around the
            # corpus (re-shuffling each pass) until enough batches are taken.
            else:
                _epoch, _cursor = divmod(self.num_iters_per_epoch * (epoch - 1), N)
                _remain = self.num_iters_per_epoch
                batches = []
                current_batches = self.sampler.generate(_epoch + self.seed)
                if shuffle:
                    np.random.RandomState(_epoch + self.seed).shuffle(current_batches)
                while _remain > 0:

                    _batches = current_batches[_cursor: _cursor + _remain]
                    batches += _batches
                    if _cursor + _remain >= N:
                        # Exhausted this pass; start the next with a new seed.
                        _epoch += 1
                        _cursor = 0
                        current_batches = self.sampler.generate(_epoch + self.seed)
                        if shuffle:
                            np.random.RandomState(_epoch + self.seed).shuffle(
                                current_batches
                            )
                    else:
                        _cursor = _cursor + _remain
                    _remain -= len(_batches)

                assert len(batches) == self.num_iters_per_epoch

        else:
            # One full pass per epoch, deterministically shuffled.
            batches = self.sampler.generate(epoch + self.seed)
            if shuffle:
                np.random.RandomState(epoch + self.seed).shuffle(batches)

        # For backward compatibility for pytorch DataLoader
        if self.collate_fn is not None:
            kwargs = dict(collate_fn=self.collate_fn)
        else:
            kwargs = {}

        return DataLoader(
            dataset=self.dataset,
            batch_sampler=batches,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            **kwargs,
        )