        # Buffer size for length-sorting, scaled by the number of replicas so
        # every rank draws from the same sorted window.
        self.sort_size = sort_size * num_replicas
        self.max_token_length = kwargs.get("max_token_length", 2048)
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
        # Number of already-consumed batches to skip when resuming training.
        self.start_step = kwargs.get("start_step", 0)

        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last
        )
        # Finalized in __iter__; initialized here so __len__ is safe before iteration.
        self.batch_num = 0

            # Distribute the assembled batches across ranks round-robin.
            rank_batches[i % self.num_replicas].append(batch)

        # Keep only this rank's batches, skipping the first `start_step` of them
        # so a resumed run does not repeat already-consumed batches.
        final_batches = rank_batches[self.rank][self.start_step :]
        self.batch_num = len(final_batches)

        logging.info(
            f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {self.batch_num}"
        )
        return iter(final_batches)

    def __len__(self):
        # Number of batches per epoch for the current rank (computed in __iter__).
        return self.batch_num

    def set_epoch(self, epoch):
        # Called by the training loop so shuffling can be reseeded per epoch.
        self.epoch = epoch
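

# Usage sketch (not from the original source): one way to wire this sampler into
# a torch DataLoader and a training loop. The class name
# `CustomDistributedBatchSampler`, the `world_size`/`rank` arguments, and
# `collate_fn` are illustrative assumptions; only parameters that appear above
# (sort_size, num_replicas, rank, shuffle, drop_last, max_token_length,
# start_step) are passed.
def example_training_loop(dataset, collate_fn, world_size, rank, num_epochs=10):
    from torch.utils.data import DataLoader

    sampler = CustomDistributedBatchSampler(  # assumed class name
        dataset,
        sort_size=512,
        num_replicas=world_size,
        rank=rank,
        shuffle=True,
        drop_last=False,
        max_token_length=2048,
        start_step=0,  # set to a saved step count to skip batches on resume
    )
    # __iter__ yields lists of dataset indices, so pass it as batch_sampler.
    loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # reseed shuffling so batch order differs per epoch
        for batch in loader:
            pass  # training step goes here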