            yield
        else:
            if dtype == torch.float16 or dtype == torch.bfloat16:
                # Run the wrapped block under autocast so ops execute in half precision
                with autocast(enabled=True, dtype=dtype):
                    yield
            else:
                yield

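# Illustrative usage of the context manager above: a minimal sketch, assuming
# it is exposed on the trainer as `maybe_autocast` (the attribute name and the
# call site below are assumptions, not taken from this file):
#
#     with self.maybe_autocast(dtype=torch.bfloat16):
#         retval = model(**batch)  # forward ops run in bf16 where autocast allows
#     loss = retval["loss"]
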
                scaled_loss = model.backward(loss)
            else:
                # Normalize by the gradient accumulation factor
                loss = loss / self.accum_grad
                if self.use_fp16 or self.use_bf16:
                    if scaler:
                        # Scale the loss so small fp16 gradients do not underflow
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()
                else:
                    # Full-precision path: plain backward
                    loss.backward()

            # Execute an optimization step (update model parameters)
            if self.use_ddp or self.use_fsdp:
                dist.barrier()
            if self.use_fp16 or self.use_bf16:
                if scaler:
                    # scaler.step() skips the update when grads contain inf/nan;
                    # scaler.update() then adjusts the loss scale for the next step
                    scaler.step(optim)
                    scaler.update()
                else:
                    optim.step()

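# A self-contained sketch of the GradScaler lifecycle the two fragments above
# implement (scale -> backward -> step -> update). The model, optimizer, and
# data are placeholders, not part of this trainer:
#
#     import torch
#     from torch.cuda.amp import GradScaler, autocast
#
#     model = torch.nn.Linear(16, 4).cuda()
#     optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
#     scaler = GradScaler()
#
#     for _ in range(10):
#         x = torch.randn(8, 16, device="cuda")
#         optim.zero_grad()
#         with autocast(enabled=True, dtype=torch.float16):
#             loss = model(x).pow(2).mean()
#         scaler.scale(loss).backward()  # backward on the scaled loss
#         scaler.step(optim)             # unscales grads; skips step on inf/nan
#         scaler.update()                # adjusts the loss scale for next step
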
        Args:
            epoch (int): The current epoch number.
        """
        self.val_loss_avg = 0.0
        self.val_acc_avg = 0.0

        if self.use_ddp or self.use_fsdp or self.use_deepspeed:
            dist.barrier()
        logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n")

| | | "data_split_i": kwargs.get("data_split_i", 0), |
| | | "data_split_num": kwargs.get("data_split_num", 1), |
| | | "log_step": batch_idx + kwargs.get("start_step", 0), |
| | | "batch_total": batch_idx + 1, |
| | | "batch_total": self.batch_total, |
| | | "step_in_epoch": batch_idx + 1, |
| | | "lr": 0.0, |
| | | } |

            if self.use_wandb and wandb is not None:
                wandb.log(
                    description_dict,
                    step=batch_total,
                )

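# Note: wandb.log's `step` must increase monotonically across calls; metrics
# logged with an earlier step are dropped with a warning. A minimal guard,
# assuming a hypothetical `last_step` attribute not present in this trainer:
#
#     step = max(batch_total, getattr(self, "last_step", 0) + 1)
#     wandb.log(description_dict, step=step)
#     self.last_step = step
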
    def close(self, writer=None):