| | |
| | | self.max_token_length = kwargs.get("max_token_length", 5000) |
| | | self.shuffle_idx = np.arange(self.total_samples) |
| | | self.shuffle = shuffle and is_training |
| | | self.length_scale_source = kwargs.get("length_scale_source", 1.0) |
| | | |
| | | |
| | | def __len__(self): |
| | | return (self.total_samples-1) // self.batch_size + 1 |
| | |
| | | |
| | | idx_map = self.shuffle_idx[idx] |
| | | # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] |
| | | sample_len_cur = self.dataset.get_source_len(idx_map) + \ |
| | | self.dataset.get_target_len(idx_map) |
| | | target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0 |
| | | source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source |
| | | sample_len_cur = source_len + target_len |
| | | |
| | | |
| | | datalen_with_index.append([idx, sample_len_cur]) |
| | | |
| | |
| | | |
| | | max_token_cur = max(max_token, sample_len_cur_raw) |
| | | max_token_padding = 1 + num_sample |
| | | if self.batch_type == 'length': |
| | | if self.batch_type != 'example': |
| | | max_token_padding *= max_token_cur |
| | | if max_token_padding <= self.batch_size: |
| | | batch.append(idx) |