| | |
| | | drop_last=False, |
| | | is_training: bool = True, |
| | | sort_size: int = 1024, |
| | | start_step: int = 0, |
| | | **kwargs, |
| | | ): |
| | | |
| | |
| | | self.sort_size = sort_size * num_replicas |
| | | self.max_token_length = kwargs.get("max_token_length", 2048) |
| | | self.length_scale_source = kwargs.get("length_scale_source", 1.0) |
| | | self.start_step = kwargs.get("start_step", 2048) |
| | | self.batch_size_sample_max = kwargs.get("batch_size_sample_max", 200) |
| | | |
| | | super().__init__( |
| | | dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last |
| | | ) |
| | | self.start_step = start_step |
| | | self.batch_num = 1 |
| | | if self.start_step > 0: |
| | | logging.info(f"Warning, start_step > 0, dataloader start from step: {self.start_step}") |
| | | # super().__init__( |
| | | # dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last |
| | | # ) |
| | | |
| | | def __iter__(self): |
| | | if self.shuffle: |
| | |
| | | rank_batches[i % self.num_replicas].append(batch) |
| | | |
| | | # Assign all batches for the current rank directly |
| | | final_batches = rank_batches[self.rank] # [self.start_step :] |
| | | final_batches = rank_batches[self.rank][self.start_step :] |
| | | self.batch_num = len(final_batches) |
| | | |
| | | logging.info( |
| | | f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {self.batch_num}" |
| | | f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {rank_batches[self.rank]}, after: {self.batch_num}" |
| | | ) |
| | | return iter(final_batches) |
| | | |