| | |
| | | |
| | | class BatchSampler(torch.utils.data.BatchSampler): |
| | | |
| | | def __init__(self, dataset, batch_size_type: str="example", batch_size: int=14, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs): |
| | | def __init__(self, dataset, batch_size_type: str="example", batch_size: int=100, sort_size: int=30, drop_last: bool=False, shuffle: bool=True, **kwargs): |
| | | |
| | | self.drop_last = drop_last |
| | | self.pre_idx = -1 |
| | |
| | | |
| | | idx_map = self.shuffle_idx[idx] |
| | | # prompt = self.dataset.indexed_dataset[idx_map]["prompt"] |
| | | sample_len_cur = self.dataset.indexed_dataset[idx_map]["source_len"] + \ |
| | | self.dataset.indexed_dataset[idx_map]["target_len"] |
| | | sample_len_cur = self.dataset.indexed_dataset.get_source_len(self.dataset.indexed_dataset[idx_map]) + \ |
| | | self.dataset.indexed_dataset.get_target_len(self.dataset.indexed_dataset[idx_map]) |
| | | |
| | | datalen_with_index.append([idx, sample_len_cur]) |
| | | |