        # --- Tail of __init__ (the def line and its parameters start before this view) ---
        # Optional hyper-parameters pulled from **kwargs with defaults:
        self.max_token_length = kwargs.get("max_token_length", 2048)  # samples longer than this are skipped during batching
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        # NOTE(review): start_step shares the 2048 default with max_token_length —
        # presumably a training-step threshold; confirm the default is intended.
        self.start_step = kwargs.get("start_step", 2048)
        self.batch_size_sample_max = kwargs.get("batch_size_sample_max", 200)  # hard cap on samples per batch

        # Delegate replica/rank/shuffle/drop_last handling to the (distributed)
        # sampler base class; these names come from the unseen method signature.
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last
        )
        # --- Batch-packing fragment (enclosing method starts before this view;
        # `buffer` is an iterable of dataset indices and `buffer_batches` is the
        # output list, both defined outside this chunk). ---
        batch = []            # indices accumulated for the batch currently being built
        max_len_in_batch = 0  # longest sample length seen in the current batch (padding cost)
        count = 0             # number of samples placed in the current batch
        for idx in buffer:
            original_sample_length = self.dataset.get_source_len(idx)
            # Drop samples that exceed the hard token-length limit entirely.
            if original_sample_length > self.max_token_length:
                continue
            # "example" batching counts samples; otherwise the cost is length-based.
            sample_length = 1 if self.batch_type == "example" else original_sample_length
            # Cost of the batch if idx were added: padded length * new sample count.
            potential_batch_length = max(max_len_in_batch, sample_length) * (len(batch) + 1)
            if potential_batch_length <= self.batch_size:
                # NOTE(review): the first clause below repeats the outer condition,
                # so the only effective check here is the sample-count cap. When
                # count > self.batch_size_sample_max, idx is silently dropped —
                # it is neither appended nor used to start a new batch. Likely a
                # bug (a new batch should probably be started instead); confirm
                # intended behavior before changing.
                if (
                    potential_batch_length <= self.batch_size
                    and count <= self.batch_size_sample_max
                ):
                    batch.append(idx)
                    max_len_in_batch = max(max_len_in_batch, sample_length)
                    count += 1
            else:
                # Adding idx would overflow the batch budget: flush the current
                # batch and start a new one seeded with idx.
                buffer_batches.append(batch)
                batch = [idx]
                max_len_in_batch = sample_length
                # NOTE(review): the new batch already holds one sample, yet count
                # is reset to 0 (count += 1 above only runs on the append path) —
                # count = 1 may be intended; confirm.
                count = 0
        if batch:
            # Flush the final partial batch so trailing samples are not lost.
            buffer_batches.append(batch)