
time1 = time.perf_counter()

for data_split_i in range(trainer.start_data_split_i, dataloader.data_split_num):
    time_slice_i = time.perf_counter()

    dataloader_tr, dataloader_val = dataloader.build_iter(
        epoch, data_split_i=data_split_i, start_step=trainer.start_step
    )

    # ... (training over this data split elided) ...

    # Release cached GPU memory between data splits to limit fragmentation.
    torch.cuda.empty_cache()

    # Estimate remaining time from the average hours per completed split
    # rather than the cumulative elapsed time.
    time_elapsed = (time.perf_counter() - time1) / 3600.0
    time_per_split = time_elapsed / (data_split_i - trainer.start_data_split_i + 1)
    logging.info(
        f"rank: {local_rank}, "
        f"time_elapsed_epoch: {time_elapsed:.3f} hours, "
        f"estimated to finish {dataloader.data_split_num} data_splits, "
        f"remaining in epoch: {(dataloader.data_split_num - data_split_i - 1) * time_per_split:.3f} hours, "
        f"remaining in training: {((trainer.max_epoch - epoch - 1) * dataloader.data_split_num + dataloader.data_split_num - data_split_i - 1) * time_per_split:.3f} hours\n"
    )

trainer.start_data_split_i = 0
trainer.validate_epoch(model=model, dataloader_val=dataloader_val, epoch=epoch + 1)
scheduler.step()
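
# A minimal, runnable sketch of the resume pattern above (semantics assumed
# from this excerpt, not the project's actual Trainer API): only the first
# split after a restart honors start_data_split_i/start_step; both offsets
# are then reset so later splits and epochs run from the beginning.

def run_epochs(max_epoch, data_split_num, start_epoch=0, start_data_split_i=0, start_step=0):
    for epoch in range(start_epoch, max_epoch):
        for data_split_i in range(start_data_split_i, data_split_num):
            # Only the very first split after a resume skips already-seen batches.
            print(f"epoch {epoch}, split {data_split_i}, skipping {start_step} batches")
            start_step = 0
        start_data_split_i = 0  # subsequent epochs cover every split

# Resuming from epoch 1, split 2, step 100:
run_epochs(max_epoch=3, data_split_num=4, start_epoch=1, start_data_split_i=2, start_step=100)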

    drop_last=False,
    is_training: bool = True,
    sort_size: int = 1024,
    start_step: int = 0,
    **kwargs,
):

    # The sort buffer is shared across replicas, so scale it by num_replicas.
    self.sort_size = sort_size * num_replicas
    self.max_token_length = kwargs.get("max_token_length", 2048)
    self.length_scale_source = kwargs.get("length_scale_source", 1.0)
    self.batch_size_sample_max = kwargs.get("batch_size_sample_max", 200)

    super().__init__(
        dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle, drop_last=drop_last
    )
    self.start_step = start_step
    self.batch_num = 1
    if self.start_step > 0:
        logging.warning(f"start_step > 0, dataloader starts from step: {self.start_step}")

def __iter__(self):
    if self.shuffle:
        # ... (shuffling, length-sorting, and batch construction elided) ...

        rank_batches[i % self.num_replicas].append(batch)

    # Assign this rank's batches directly, skipping the first start_step
    # batches when resuming from a checkpoint.
    final_batches = rank_batches[self.rank][self.start_step :]
    self.batch_num = len(final_batches)

    logging.info(
        f"rank: {self.rank}, dataloader starts from step: {self.start_step}, "
        f"batch_num: {len(rank_batches[self.rank])}, after skipping: {self.batch_num}"
    )
    return iter(final_batches)
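
# Hypothetical usage sketch (ToyDataset/SkippingBatchSampler are illustrative
# stand-ins, not this repo's classes): a batch sampler that yields lists of
# dataset indices plugs into PyTorch through DataLoader's batch_sampler
# argument, and resuming amounts to slicing off already-consumed batches.

import torch
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return torch.tensor(idx)

class SkippingBatchSampler:
    """Yield fixed-size index batches, skipping the first start_step of them."""

    def __init__(self, n, batch_size, start_step=0):
        self.batches = [
            list(range(i, min(i + batch_size, n))) for i in range(0, n, batch_size)
        ]
        self.start_step = start_step

    def __iter__(self):
        return iter(self.batches[self.start_step :])

    def __len__(self):
        return len(self.batches) - self.start_step

loader = DataLoader(ToyDataset(), batch_sampler=SkippingBatchSampler(16, 4, start_step=2))
for batch in loader:
    print(batch)  # resumes at the third batch: tensor([ 8,  9, 10, 11]), then tensor([12, 13, 14, 15])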
def __init__(self, frontend=None, tokenizer=None, **kwargs):
    # dataset
    logging.info("Build dataloader")

    dataset_class = tables.dataset_classes.get(kwargs.get("dataset", "AudioDataset"))

    # split dataset: with data_split_num > 1 the training set is built lazily,
    # one split at a time, inside build_iter(); only build it up front when
    # there is a single split.
    dataset_tr = None
    self.data_split_num = kwargs["dataset_conf"].get("data_split_num", 1)
    if self.data_split_num == 1:
        dataset_tr = dataset_class(
            kwargs.get("train_data_set_list"),
            frontend=frontend,
            tokenizer=tokenizer,
            is_training=True,
            **kwargs.get("dataset_conf"),
        )
    dataset_val = dataset_class(
        kwargs.get("valid_data_set_list"),
        frontend=frontend,
        tokenizer=tokenizer,
        is_training=False,
        **kwargs.get("dataset_conf"),
    )

    self.dataset_tr = dataset_tr
    self.dataset_val = dataset_val
    self.kwargs = kwargs

    self.dataset_class = dataset_class
    self.frontend = frontend
    self.tokenizer = tokenizer
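
# Hedged sketch of the data_split_num idea (split_manifest is a hypothetical
# helper, not this repo's API): the full training manifest is partitioned into
# data_split_num contiguous shards, and build_iter(epoch, data_split_i=...)
# materializes a dataset over just one shard at a time.

def split_manifest(lines, data_split_num, data_split_i):
    """Return the data_split_i-th of data_split_num contiguous shards of lines."""
    shard = (len(lines) + data_split_num - 1) // data_split_num  # ceiling division
    return lines[data_split_i * shard : (data_split_i + 1) * shard]

manifest = [f"utt{i}.wav" for i in range(10)]
for i in range(3):
    print(i, split_manifest(manifest, data_split_num=3, data_split_i=i))
# 0 ['utt0.wav', 'utt1.wav', 'utt2.wav', 'utt3.wav']
# 1 ['utt4.wav', 'utt5.wav', 'utt6.wav', 'utt7.wav']
# 2 ['utt8.wav', 'utt9.wav']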

    Args:
        epoch (int): The epoch number at which the checkpoint is being saved.
    """
    if self.use_ddp or self.use_fsdp:
        dist.barrier()
    # A step-level checkpoint name only makes sense when a step was given.
    step_in_epoch = None if step is None else step_in_epoch
    if self.use_deepspeed:
        # ... (DeepSpeed save path elided) ...

    # Record validation metrics under the checkpoint's name so the best
    # checkpoint can be looked up later.
    ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
    self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
    self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg

    if self.use_ddp or self.use_fsdp or self.use_deepspeed:
        dist.barrier()

    # Restore training mode after validation and checkpointing.
    model.train()

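# Hedged sketch (dictionary layout borrowed from above, selection logic
# assumed): with per-checkpoint validation metrics recorded in
# val_acc_step_or_eoch, the best checkpoint is the key with the highest
# average validation accuracy.

val_acc_step_or_eoch = {
    "model.pt.ep0.None": 0.81,
    "model.pt.ep1.2000": 0.86,
    "model.pt.ep1.None": 0.84,
}
best_ckpt = max(val_acc_step_or_eoch, key=val_acc_step_or_eoch.get)
print(best_ckpt)  # -> model.pt.ep1.2000
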
def log(