| | |
| | | self.sort_size = sort_size * num_replicas |
| | | self.max_token_length = kwargs.get("max_token_length", 2048) |
| | | self.length_scale_source = kwargs.get("length_scale_source", 1.0) |
| | | self.start_step = kwargs.get("start_step", 2048) |
| | | self.batch_size_sample_max = kwargs.get("batch_size_sample_max", 200) |
| | | |
| | | super().__init__( |
| | | dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last |
| | | ) |
| | |
| | | ) |
| | | batch = [] |
| | | max_len_in_batch = 0 |
| | | count = 0 |
| | | for idx in buffer: |
| | | original_sample_length = self.dataset.get_source_len(idx) |
| | | if original_sample_length > self.max_token_length: |
| | | continue |
| | | sample_length = 1 if self.batch_type == "example" else original_sample_length |
| | | potential_batch_length = max(max_len_in_batch, sample_length) * (len(batch) + 1) |
| | | if potential_batch_length <= self.batch_size: |
| | | if ( |
| | | potential_batch_length <= self.batch_size |
| | | and count <= self.batch_size_sample_max |
| | | ): |
| | | batch.append(idx) |
| | | max_len_in_batch = max(max_len_in_batch, sample_length) |
| | | count += 1 |
| | | else: |
| | | buffer_batches.append(batch) |
| | | batch = [idx] |
| | | max_len_in_batch = sample_length |
| | | count = 0 |
| | | if batch: |
| | | buffer_batches.append(batch) |
| | | |
| | |
| | | rank_batches[i % self.num_replicas].append(batch) |
| | | |
| | | # Assign all batches for the current rank directly |
| | | final_batches = rank_batches[self.rank] |
| | | final_batches = rank_batches[self.rank] # [self.start_step :] |
| | | self.batch_num = len(final_batches) |
| | | |
| | | logging.info( |
| | | f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {self.batch_num}" |
| | | ) |
| | | return iter(final_batches) |
| | | |
def __len__(self):
    """Return the number of batches the current rank yields per epoch.

    ``self.batch_num`` is stored when the per-rank batch list is built
    for iteration, so it may not exist yet when ``len()`` is called on a
    freshly constructed sampler.

    Returns:
        int: ``self.batch_num`` if it has been computed, otherwise ``1``.
    """
    # Fall back to 1 (the former constant result) so len() still works
    # before any batches have been built.
    return getattr(self, "batch_num", 1)
| | | |
def set_epoch(self, epoch):
    """Store ``epoch`` on the sampler as ``self.epoch``.

    Args:
        epoch: Epoch value to record. How it is consumed is not visible
            in this chunk (presumably for per-epoch shuffling; confirm
            against the iterator method).
    """
    self.epoch = epoch