| | |
| | | class MaxTokenBucketizerIterDataPipe(IterableDataset): |
| | | |
def __init__(
    self,
    datapipe,
    batch_size=8000,
    len_fn=_default_len_fn,
    buffer_size=10240,
    sort_size=500,
    batch_mode="padding",
):
    """Wrap ``datapipe`` and record the bucketing configuration.

    Args:
        datapipe: Upstream pipe of samples. It must provide
            ``set_epoch(epoch)``, because ``set_epoch`` on this class
            forwards to it.
        batch_size: Per-batch budget; must be > 0. Presumably a token
            budget measured with ``len_fn`` -- TODO confirm in ``__iter__``.
        len_fn: Length function for samples. Its use is not visible here;
            presumably it returns a sample's token length -- confirm.
        buffer_size: Read-ahead buffer size; must be >= -1. The meaning of
            -1 is not visible here (possibly "unbounded") -- confirm.
        sort_size: Stored for use by the iterator; its exact role is not
            visible in this part of the class.
        batch_mode: Batching strategy. ``__iter__`` special-cases
            ``"clipping"``; the default is ``"padding"``.

    Raises:
        AssertionError: If ``batch_size <= 0`` or ``buffer_size < -1``.
            These checks are ``assert`` statements and are therefore
            skipped when Python runs with ``-O``.
    """
    super().__init__()
    assert batch_size > 0, "Batch size is required to be larger than 0!"
    # -1 is explicitly accepted, so the message says "or equal to".
    assert buffer_size >= -1, (
        "Buffer size is required to be larger than or equal to -1!"
    )

    # Keep every argument on the instance: set_epoch reads self.datapipe,
    # and the iterator reads the configuration (e.g. self.batch_mode).
    self.datapipe = datapipe
    self.batch_size = batch_size
    self.len_fn = len_fn
    self.buffer_size = buffer_size
    self.sort_size = sort_size
    self.batch_mode = batch_mode
| | | |
def set_epoch(self, epoch):
    """Store ``epoch`` on this pipe, then propagate it to the wrapped pipe.

    Args:
        epoch: Epoch value. It is assigned to ``self.epoch`` first and then
            passed unchanged to ``self.datapipe.set_epoch``.
    """
    self.epoch = epoch
    # Forward the same value so the upstream pipe is kept in step with this one.
    upstream = self.datapipe
    upstream.set_epoch(epoch)
| | | |
| | | def __iter__(self): |
| | | buffer = [] |
| | | batch = [] |
| | | bucket = [] |
| | | max_lengths = 0 |
| | | min_lengths = 999999 |
| | | batch_lengths = 0 |
| | | |
| | | if self.batch_mode == "clipping": |