| | |
| | | batch_size=8000, |
| | | len_fn=_default_len_fn, |
| | | buffer_size=10240, |
| | | sort_size=500 |
| | | sort_size=500, |
| | | batch_mode="padding", |
| | | ): |
| | | assert batch_size > 0, "Batch size is required to be larger than 0!" |
| | | assert buffer_size >= -1, "Buffer size is required to be larger than -1!" |
| | |
| | | self.batch_size = batch_size |
| | | self.buffer_size = buffer_size |
| | | self.sort_size = sort_size |
| | | self.batch_mode = batch_mode |
| | | |
| | | # Store the caller-supplied epoch value on the instance as self.epoch.
| | | # NOTE(review): nothing in the visible part of this file reads self.epoch;
| | | # presumably it is meant to vary shuffling per epoch - confirm against callers.
| | | def set_epoch(self, epoch):
| | | self.epoch = epoch
| | |
| | | max_lengths = 0 |
| | | batch_lengths = 0 |
| | | |
| | | if self.buffer_size == -1: |
| | | for d in self.datapipe: |
| | | if d[0] > self.batch_size: |
| | | continue |
| | | buffer.append(d) |
| | | buffer.sort() |
| | | for sample in buffer: |
| | | length, _, token = sample |
| | | if length > max_lengths: |
| | | max_lengths = length |
| | | batch_lengths = max_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | bucket.append(batch) |
| | | batch = [] |
| | | max_lengths = length |
| | | batch.append(token) |
| | | random.shuffle(bucket) |
| | | if bucket: |
| | | for batch_sample in bucket: |
| | | yield batch_sample |
| | | if batch: |
| | | yield batch |
| | | |
| | | elif self.buffer_size == 0: |
| | | for d in self.datapipe: |
| | | if d[0] > self.batch_size: |
| | | continue |
| | | length, _, token = d |
| | | if length > self.batch_size: |
| | | continue |
| | | if length > max_lengths: |
| | | max_lengths = length |
| | | batch_lengths = max_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | max_lengths = length |
| | | batch.append(token) |
| | | if batch: |
| | | yield batch |
| | | |
| | | else: |
| | | if self.batch_mode == "clipping": |
| | | assert self.buffer_size > 0, "for clipping batch_mode, buffer_size must be > 1" |
| | | for d in self.datapipe: |
| | | if d[0] > self.batch_size: |
| | | continue |
| | | buffer.append(d) |
| | | if len(buffer) == self.buffer_size: |
| | | random.shuffle(buffer) |
| | | for sample in buffer: |
| | | bucket.append(sample) |
| | | if len(bucket) == self.sort_size: |
| | | bucket.sort() |
| | | for x in bucket: |
| | | length, _, token = x |
| | | if length < min_lengths: |
| | | min_lengths = length |
| | | batch_lengths = min_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | min_lengths = length |
| | | batch.append(token) |
| | | bucket = [] |
| | | buffer = [] |
| | | |
| | | if buffer: |
| | | random.shuffle(buffer) |
| | | for sample in buffer: |
| | | bucket.append(sample) |
| | | if len(bucket) == self.sort_size: |
| | | bucket.sort() |
| | | for x in bucket: |
| | | length, _, token = x |
| | | if length < min_lengths: |
| | | min_lengths = length |
| | | batch_lengths = min_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | min_lengths = length |
| | | batch.append(token) |
| | | bucket = [] |
| | | buffer = [] |
| | | |
| | | if bucket: |
| | | bucket.sort() |
| | | for x in bucket: |
| | | length, _, token = x |
| | | if length < min_lengths: |
| | | min_lengths = length |
| | | batch_lengths = min_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | min_lengths = length |
| | | batch.append(token) |
| | | bucket = [] |
| | | |
| | | if batch: |
| | | yield batch |
| | | |
| | | else: |
| | | if self.buffer_size == -1: |
| | | for d in self.datapipe: |
| | | if d[0] > self.batch_size: |
| | | continue |
| | | buffer.append(d) |
| | | buffer.sort() |
| | | for sample in buffer: |
| | | length, _, token = sample |
| | | if length > max_lengths: |
| | | max_lengths = length |
| | | batch_lengths = max_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | bucket.append(batch) |
| | | batch = [] |
| | | max_lengths = length |
| | | batch.append(token) |
| | | random.shuffle(bucket) |
| | | if bucket: |
| | | for batch_sample in bucket: |
| | | yield batch_sample |
| | | if batch: |
| | | yield batch |
| | | |
| | | elif self.buffer_size == 0: |
| | | for d in self.datapipe: |
| | | if d[0] > self.batch_size: |
| | | continue |
| | | length, _, token = d |
| | | if length > self.batch_size: |
| | | continue |
| | | if length > max_lengths: |
| | | max_lengths = length |
| | | batch_lengths = max_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | max_lengths = length |
| | | batch.append(token) |
| | | if batch: |
| | | yield batch |
| | | |
| | | else: |
| | | for d in self.datapipe: |
| | | if d[0] > self.batch_size: |
| | | continue |
| | | buffer.append(d) |
| | | if len(buffer) == self.buffer_size: |
| | | random.shuffle(buffer) |
| | | for sample in buffer: |
| | | bucket.append(sample) |
| | | if len(bucket) == self.sort_size: |
| | | bucket.sort() |
| | | for x in bucket: |
| | | length, _, token = x |
| | | if length > max_lengths: |
| | | max_lengths = length |
| | | batch_lengths = max_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | max_lengths = length |
| | | batch.append(token) |
| | | bucket = [] |
| | | buffer = [] |
| | | |
| | | if buffer: |
| | | random.shuffle(buffer) |
| | | for sample in buffer: |
| | | bucket.append(sample) |
| | |
| | | bucket = [] |
| | | buffer = [] |
| | | |
| | | if buffer: |
| | | random.shuffle(buffer) |
| | | for sample in buffer: |
| | | bucket.append(sample) |
| | | if len(bucket) == self.sort_size: |
| | | bucket.sort() |
| | | for x in bucket: |
| | | length, _, token = x |
| | | if length > max_lengths: |
| | | max_lengths = length |
| | | batch_lengths = max_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | max_lengths = length |
| | | batch.append(token) |
| | | bucket = [] |
| | | buffer = [] |
| | | if bucket: |
| | | bucket.sort() |
| | | for x in bucket: |
| | | length, _, token = x |
| | | if length > max_lengths: |
| | | max_lengths = length |
| | | batch_lengths = max_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | max_lengths = length |
| | | batch.append(token) |
| | | bucket = [] |
| | | |
| | | if bucket: |
| | | bucket.sort() |
| | | for x in bucket: |
| | | length, _, token = x |
| | | if length > max_lengths: |
| | | max_lengths = length |
| | | batch_lengths = max_lengths * (len(batch) + 1) |
| | | if batch_lengths > self.batch_size: |
| | | yield batch |
| | | batch = [] |
| | | max_lengths = length |
| | | batch.append(token) |
| | | bucket = [] |
| | | |
| | | if batch: |
| | | yield batch |
| | | if batch: |
| | | yield batch |