| | |
| | | ) |
| | | batch = [] |
| | | max_len_in_batch = 0 |
| | | count = 0 |
| | | count = 1 |
| | | for idx in buffer: |
| | | original_sample_length = self.dataset.get_source_len(idx) |
| | | if original_sample_length > self.max_token_length: |
| | |
| | | buffer_batches.append(batch) |
| | | batch = [idx] |
| | | max_len_in_batch = sample_length |
| | | count = 0 |
| | | count = 1 |
| | | if batch: |
| | | buffer_batches.append(batch) |
| | | |
| | |
| | | self.batch_num = len(final_batches) |
| | | |
| | | logging.info( |
| | | f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {rank_batches[self.rank]}, after: {self.batch_num}" |
| | | f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {len(rank_batches[self.rank])}, after: {self.batch_num}" |
| | | ) |
| | | return iter(final_batches) |
| | | |