| | |
| | | max_len_in_batch = 0 # Tracks the max sample length within the current batch |
| | | |
| | | for idx in sorted_indices: |
| | | original_sample_length = self.dataset.get_source_len(idx) |
| | | if ( |
| | | original_sample_length < self.min_token_length |
| | | or original_sample_length > self.max_token_length |
| | | ): # Skip samples that exceed the max length |
| | | continue |
| | | |
| | | # original_sample_length = self.dataset.get_source_len(idx) |
| | | # if ( |
| | | # original_sample_length < self.min_token_length |
| | | # or original_sample_length > self.max_token_length |
| | | # ): # Skip samples that exceed the max length |
| | | # continue |
| | | |
| | | # sample_length = 1 if self.batch_type == "example" else original_sample_length |
| | | |
| | | # Set sample_length based on the batch type |
| | | sample_length = 1 if self.batch_type == "example" else original_sample_length |
| | | if self.batch_type == "example": |
| | | sample_length = 1 |
| | | elif self.batch_type == "token": |
| | | sample_length = self.dataset.get_source_len(idx) + int( |
| | | self.dataset.get_target_len(idx) * 1.2 |
| | | ) |
| | | else: |
| | | sample_length = self.dataset.get_source_len(idx) |
| | | # Calculate potential batch size with the new sample |
| | | potential_batch_length = max(max_len_in_batch, sample_length) * (len(batch) + 1) |
| | | # Add index to batch if it doesn't exceed batch size limit |