| | |
| | | frontend=frontend, |
| | | tokenizer=tokenizer, |
| | | is_training=True, |
| | | **kwargs.get("dataset_conf") |
| | | **kwargs.get("dataset_conf"), |
| | | ) |
| | | dataset_val = dataset_class( |
| | | kwargs.get("valid_data_set_list"), |
| | | frontend=frontend, |
| | | tokenizer=tokenizer, |
| | | is_training=False, |
| | | **kwargs.get("dataset_conf") |
| | | **kwargs.get("dataset_conf"), |
| | | ) |
| | | |
| | | # dataloader |
| | |
| | | frontend=frontend, |
| | | tokenizer=tokenizer, |
| | | is_training=True, |
| | | **kwargs.get("dataset_conf") |
| | | **kwargs.get("dataset_conf"), |
| | | ) |
| | | dataset_val = dataset_class( |
| | | kwargs.get("valid_data_set_list"), |
| | | frontend=frontend, |
| | | tokenizer=tokenizer, |
| | | is_training=False, |
| | | **kwargs.get("dataset_conf") |
| | | **kwargs.get("dataset_conf"), |
| | | ) |
| | | |
| | | self.dataset_tr = dataset_tr |
| | |
| | | self.tokenizer = tokenizer |
| | | self.kwargs = kwargs |
| | | |
| | | def build_iter(self, epoch=0, data_split_i=0, **kwargs): |
| | | def build_iter(self, epoch=0, data_split_i=0, start_step=0, **kwargs): |
| | | |
| | | # reload dataset slice |
| | | if self.data_split_num > 1: |
| | |
| | | tokenizer=self.tokenizer, |
| | | is_training=True, |
| | | **self.kwargs.get("dataset_conf"), |
| | | data_split_i=data_split_i |
| | | data_split_i=data_split_i, |
| | | ) |
| | | |
| | | # dataloader |
| | |
| | | batch_sampler_val = None |
| | | if batch_sampler is not None: |
| | | batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler) |
| | | batch_sampler = batch_sampler_class(self.dataset_tr, **self.kwargs.get("dataset_conf")) |
| | | batch_sampler = batch_sampler_class( |
| | | self.dataset_tr, start_step=start_step, **self.kwargs.get("dataset_conf") |
| | | ) |
| | | batch_sampler_val = batch_sampler_class( |
| | | self.dataset_val, is_training=False, **self.kwargs.get("dataset_conf") |
| | | ) |
| | |
| | | frontend=frontend, |
| | | tokenizer=tokenizer, |
| | | is_training=True, |
| | | **kwargs.get("dataset_conf") |
| | | **kwargs.get("dataset_conf"), |
| | | ) |
| | | dataset_val = dataset_class( |
| | | kwargs.get("valid_data_set_list"), |
| | | frontend=frontend, |
| | | tokenizer=tokenizer, |
| | | is_training=False, |
| | | **kwargs.get("dataset_conf") |
| | | **kwargs.get("dataset_conf"), |
| | | ) |
| | | |
| | | return dataset_tr, dataset_val |