                job_type="training",
                reinit=True,
            )
        tensorboard_dir = os.path.join(output_dir, "tensorboard")
        os.makedirs(tensorboard_dir, exist_ok=True)
        try:
            from tensorboardX import SummaryWriter

            self.writer = SummaryWriter(tensorboard_dir)
        except ImportError:
            # TensorBoard logging is optional; disable it when tensorboardX is not installed.
            self.writer = None
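        # Typically only rank 0 needs to emit TensorBoard events. A minimal sketch of that
        # variant (self.rank is the distributed rank used elsewhere in this class):
        #
        #     self.writer = SummaryWriter(tensorboard_dir) if self.rank == 0 else None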

    def save_checkpoint(
        self,

        dataloader_train=None,
        dataloader_val=None,
        epoch=None,
        writer=None,
        **kwargs,
    ):
        """

        time_beg = time.perf_counter()
        time5 = time_beg
        for batch_idx, batch in enumerate(dataloader_train):
            if self.use_ddp or self.use_fsdp:
                dist.all_reduce(iterator_stop, dist.ReduceOp.SUM)
                if iterator_stop > 0:
                    break
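            # NOTE: `iterator_stop` is created before this loop (not shown in this excerpt); the
            # all_reduce above sums it across ranks so that when any worker runs out of data,
            # every DDP/FSDP worker exits the loop together instead of blocking in a collective.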
            loss_dict = {
                "speed_stats": {},
                "epoch": epoch,
                "batch_idx": batch_idx,
                "data_split_i": kwargs.get("data_split_i", 0),
                "data_split_num": kwargs.get("data_split_num", 1),
                "log_step": batch_idx + kwargs.get("start_step", 0),
            }

            self.batch_total += 1
            self.step_in_epoch += 1
            time1 = time.perf_counter()
            speed_stats["data_load"] = f"{time1 - time_beg:0.3f}"
            loss_dict["speed_stats"]["data_load"] = f"{time1 - time_beg:0.3f}"

            batch = to_device(batch, self.device)

            my_context = model.no_sync if batch_idx % accum_grad != 0 else my_context
            with my_context():
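                # While gradients are still being accumulated (not an optimizer-step boundary),
                # `model.no_sync()` suppresses DDP's per-backward gradient all-reduce; the
                # synchronization then happens once, on the last micro-batch of the group.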
                time2 = time.perf_counter()

                self.forward_step(model, batch, loss_dict=loss_dict)

                time3 = time.perf_counter()
                speed_stats["forward_time"] = f"{time3 - time2:0.3f}"
                loss_dict["speed_stats"]["forward_time"] = f"{time3 - time2:0.3f}"
                self.backward_step(model, scaler, loss_dict=loss_dict)

                time4 = time.perf_counter()
                speed_stats["backward_and_AllReduce_time"] = f"{time4 - time3:0.3f}"

                # self.train_loss_avg = (
                #     self.train_loss_avg * (batch_idx + kwargs.get("start_step", 0))
                #     + loss.detach().cpu().item()
                # ) / (batch_idx + kwargs.get("start_step", 0) + 1)
                # if "acc" in stats:
                #     self.train_acc_avg = (
                #         self.train_acc_avg * (batch_idx + kwargs.get("start_step", 0))
                #         + stats["acc"].detach().cpu().item()
                #     ) / (batch_idx + kwargs.get("start_step", 0) + 1)
                loss_dict["speed_stats"]["backward_time"] = f"{time4 - time3:0.3f}"

            # update_step performs the optimizer/scheduler step only after accumulating
            # enough gradients.
            self.update_step(model, optim, scheduler, scaler, batch_idx=batch_idx, loss_dict=loss_dict)
            total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
            time5 = time.perf_counter()

            loss_dict["speed_stats"]["optim_time"] = f"{time5 - time4:0.3f}"

            loss_dict["speed_stats"]["total_time"] = total_time

            loss_dict["lr"] = scheduler.get_last_lr()[0]
            loss_dict["batch_num_epoch"] = len(dataloader_train)

            self.log(loss_dict=loss_dict, tag="train")

            if self.step_in_epoch % self.validate_interval == 0:
                self.validate_epoch(

    def forward_step(self, model, batch, loss_dict=None, **kwargs):

        with maybe_autocast(self.use_fp16):
            retval = model(**batch)

        if (
            self.reset_gpu_cache
            and (torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024) > 70
        ):
            torch.cuda.empty_cache()
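        # Heuristic cache reset: 70 is a hard-coded threshold in GiB of reserved CUDA memory
        # (presumably sized for ~80 GiB accelerators); emptying the cache trades a slower next
        # allocation for headroom against fragmentation-induced out-of-memory errors.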

        loss, stats, weight = retval
        stats = {k: v for k, v in stats.items() if v is not None}
        # if self.use_ddp or self.use_fsdp:
        #     # Apply weighted averaging for loss and stats
        #     loss = (loss * weight.type(loss.dtype)).sum()
        #     # if distributed, this method can also apply all_reduce()
        #     # stats, weight = recursive_average(stats, weight, distributed=True)
        #     if self.use_ddp or self.use_fsdp:
        #         dist.all_reduce(weight, op=dist.ReduceOp.SUM)
        #     # Now weight is summation over all workers
        #     loss /= weight.sum()  # shape:[1] -> shape:[]
        #     # Multiply world_size because DistributedDataParallel
        #     # automatically normalizes the gradient by world_size.
        #     loss *= self.world_size
        # Scale the loss since we're not updating for every mini-batch
        loss = loss / self.accum_grad

        loss_dict["loss"] = loss
        loss_dict["stats"] = stats

    def backward_step(self, model, scaler, loss_dict=None, **kwargs):
        loss = loss_dict["loss"]

        loss.backward()

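    # `maybe_autocast` (used in forward_step above) is not defined in this excerpt. A minimal
    # sketch of the assumed behavior, i.e. enable CUDA autocast only when fp16 training is on:
    #
    #     from contextlib import nullcontext
    #
    #     def maybe_autocast(enabled: bool):
    #         # Mixed-precision context when enabled, otherwise a no-op context manager.
    #         return torch.autocast("cuda", dtype=torch.float16) if enabled else nullcontext()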
    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):

        if self.use_deepspeed:
            # DeepSpeed handles gradient accumulation, clipping and the optimizer/scheduler
            # step internally.
            model.step()
        else:
            if (batch_idx + 1) % self.accum_grad == 0:
                # Perform gradient clipping if it is set
                if self.grad_clip > 0:
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        max_norm=self.grad_clip,
                        norm_type=self.grad_clip_type,
                    )
                    if not torch.isfinite(grad_norm):
                        logging.warning(
                            f"The grad norm is {grad_norm}. Skipping updating the model."
                        )
                        optim.zero_grad()  # Reset gradients
                        return

                optim.step()
                scheduler.step()
                # Clear gradients for the next accumulation stage
                optim.zero_grad(set_to_none=True)
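        # With accum_grad = 4, for example, forward/backward runs on every micro-batch but the
        # branch above fires only on batch_idx 3, 7, 11, ...; together with the loss scaling in
        # forward_step (loss / accum_grad) the applied update approximates the gradient averaged
        # over the 4 micro-batches, i.e. an effective batch size of 4 x the micro-batch size.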

        if self.use_ddp or self.use_fsdp:
            train_loss_avg = torch.tensor(self.train_loss_avg, dtype=torch.float32).to(
                self.device
            )
            train_acc_avg = torch.tensor(self.train_acc_avg, dtype=torch.float32).to(
                self.device
            )
            dist.all_reduce(train_loss_avg, op=dist.ReduceOp.SUM)
            dist.all_reduce(train_acc_avg, op=dist.ReduceOp.SUM)
            self.train_loss_avg = train_loss_avg.detach().cpu().item() / self.world_size
            self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size

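            # Summing with all_reduce and then dividing by world_size turns the per-rank running
            # averages above into a global mean over all workers.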
            total_time = f"{(time.perf_counter() - time5) / accum_grad:0.3f}"
            time5 = time.perf_counter()

            speed_stats["optim_time"] = f"{time5 - time4:0.3f}"

            speed_stats["total_time"] = total_time
            lr = scheduler.get_last_lr()[0]
            batch_num_epoch = 1
            if hasattr(dataloader_train, "__len__"):
                batch_num_epoch = len(dataloader_train)
            self.log(
                epoch,
                batch_idx,
                log_step=batch_idx + kwargs.get("start_step", 0),
                step_in_epoch=self.step_in_epoch,
                batch_num_epoch=batch_num_epoch,
                lr=lr,
                loss=loss.detach().cpu().item(),
                speed_stats=speed_stats,
                stats=stats,
                writer=writer,
                tag="train",
                data_split_i=kwargs.get("data_split_i", 0),
                data_split_num=kwargs.get("data_split_num", 1),
            )

    def validate_epoch(
        self,

    def log(
        self,
        epoch=0,
        batch_idx=0,
        step_in_epoch=0,
        batch_num_epoch=-1,
        lr=0.0,
        loss=0.0,
        speed_stats=None,
        stats=None,
        writer=None,
        loss_dict: dict = None,
        tag="train",
        data_split_i=0,
        data_split_num=1,
        log_step=None,
        **kwargs,
    ):
        # Prefer the packed loss_dict (new-style call); the explicit keyword arguments above are
        # kept for backward compatibility and are used only when loss_dict is not provided.
        if loss_dict is not None:
            loss = loss_dict["loss"].detach().cpu().item()
            epoch = loss_dict["epoch"]
            batch_idx = loss_dict["batch_idx"]
            step_in_epoch = loss_dict.get("step_in_epoch", step_in_epoch)
            batch_num_epoch = loss_dict["batch_num_epoch"]
            lr = loss_dict["lr"]

            speed_stats = loss_dict["speed_stats"]
            stats = loss_dict["stats"]
            data_split_i = loss_dict["data_split_i"]
            data_split_num = loss_dict["data_split_num"]
            log_step = loss_dict.get("log_step", None)

        if (batch_idx + 1) % self.log_interval == 0:
            batch_idx = log_step if log_step is not None else batch_idx

                f"rank{self.rank}_lr/{tag}": lr,
            }

            writer = self.writer
            if writer is not None:
                writer.add_scalar(f"rank{self.rank}_loss/{tag}", loss, self.batch_total)
                writer.add_scalar(f"rank{self.rank}_lr/{tag}", lr, self.batch_total)
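                # Tags are prefixed with the rank, so in multi-GPU runs each worker writes its own
                # loss and lr curves; `self.batch_total` (every training batch seen so far) serves
                # as the global step.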

        model = model.to(device=kwargs.get("device", "cuda"))

        return model

    def warp_optim_scheduler(self, model, **kwargs):

        # optim
        logging.info("Build optim")
        optim = kwargs.get("optim", "adam")
        assert optim in optim_classes
        optim_class = optim_classes.get(optim)
        optim = optim_class(model.parameters(), **kwargs.get("optim_conf", {}))

        # scheduler
        logging.info("Build scheduler")
        scheduler = kwargs.get("scheduler", "warmuplr")
        assert scheduler in scheduler_classes
        scheduler_class = scheduler_classes.get(scheduler)
        scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf", {}))

        if self.use_deepspeed:
            deepspeed_config = kwargs.get("deepspeed_config", "")
            with open(deepspeed_config, "r") as fin:
                ds_configs = json.load(fin)
            if "optimizer" in ds_configs:
                # NOTE(xcsong): Disable the custom optimizer if one is set in ds_config;
                # extremely useful when enabling cpu_offload, since DeepSpeedCPUAdam
                # can be 4~5x faster than torch's native adam.
                optim = None
            if "scheduler" in ds_configs:
                scheduler = None
            else:

                def scheduler(opt):
                    return scheduler_class(opt, **kwargs.get("scheduler_conf", {}))

            args = OmegaConf.create({"deepspeed_config": deepspeed_config})
            model, optim, _, scheduler = deepspeed.initialize(
                args=args,
                model=model,
                optimizer=optim,
                lr_scheduler=scheduler,
                model_parameters=model.parameters(),
            )

        return model, optim, scheduler
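    # Example call on a Trainer instance (hypothetical configuration values; the available
    # names come from optim_classes / scheduler_classes):
    #
    #     model, optim, scheduler = trainer.warp_optim_scheduler(
    #         model,
    #         optim="adam",
    #         optim_conf={"lr": 5e-4},
    #         scheduler="warmuplr",
    #         scheduler_conf={"warmup_steps": 10000},
    #     )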