        drop_last=False,
        is_training: bool = True,
        sort_size: int = 1024,
        start_step: int = 0,
        **kwargs,
    ):
        self.max_token_length = kwargs.get("max_token_length", 2048)
        self.min_token_length = kwargs.get("min_token_length", 0)
        self.length_scale_source = kwargs.get("length_scale_source", 1.0)

        # Record the step to resume from so __iter__ can skip batches that
        # were already consumed before a restart.
        self.start_step = start_step
        if self.start_step > 0:
            logging.info(f"Warning: start_step > 0, dataloader will start from step {self.start_step}")
        # super().__init__(dataset, num_replicas=num_replicas, rank=rank,
        #                  shuffle=shuffle, drop_last=drop_last)
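        # Usage sketch (hedged: the class name `BatchSampler` and `my_dataset`
        # are assumptions, not shown in this fragment):
        #
        #   sampler = BatchSampler(
        #       my_dataset,
        #       start_step=500,          # resume mid-epoch after a restart
        #       max_token_length=2048,   # forwarded via **kwargs
        #       min_token_length=10,
        #   )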
        max_len_in_batch = 0  # Tracks the max sample length within the current batch

        for idx in sorted_indices:

            # original_sample_length = self.dataset.get_source_len(idx)
            # if (
            #     original_sample_length < self.min_token_length
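            # A plausible completion of the commented-out filter above (an
            # assumption: it appears meant to skip samples whose length falls
            # outside the configured bounds):
            #
            #   original_sample_length = self.dataset.get_source_len(idx)
            #   if (
            #       original_sample_length < self.min_token_length
            #       or original_sample_length > self.max_token_length
            #   ):
            #       continue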
        # Allocate this rank's contiguous slice of batches, offset by
        # self.start_step so already-consumed batches are skipped on resume
        start_idx = self.rank * batches_per_rank
        end_idx = start_idx + batches_per_rank
        rank_batches = buffer_batches[start_idx + self.start_step : end_idx]

        # Return an iterator over the batches for the current rank
        return iter(rank_batches)
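
# Standalone sketch (not part of the original file; the batch contents are
# made up) showing how the slicing above splits buffered batches across ranks
# and how start_step drops batches consumed before a restart.
if __name__ == "__main__":
    buffer_batches = [[i] for i in range(8)]  # 8 toy batches
    num_replicas, start_step = 2, 1
    batches_per_rank = len(buffer_batches) // num_replicas  # -> 4
    for rank in range(num_replicas):
        start_idx = rank * batches_per_rank
        end_idx = start_idx + batches_per_rank
        # rank 0 yields batches 1-3; rank 1 yields batches 5-7
        print(rank, buffer_batches[start_idx + start_step : end_idx])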