| | |
| | | yield |
| | | else: |
| | | if dtype == torch.float16 or dtype == torch.bfloat16: |
| | | yield |
| | | # with autocast(enabled=True, dtype=dtype): |
| | | # yield |
| | | with autocast(enabled=True, dtype=dtype): |
| | | yield |
| | | else: |
| | | yield |
| | | |
| | |
| | | ) |
| | | else: |
| | | print("Undo") |
| | | self.saved_ckpts[ckpt_name] = getattr( |
| | | self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch" |
| | | )[ckpt_name] |
| | | if self.keep_nbest_models > 0: |
| | | if len(self.saved_ckpts) > self.keep_nbest_models: |
| | | if self.avg_keep_nbest_models_type == "acc": |
| | | key = min(self.saved_ckpts, key=self.saved_ckpts.get) |
| | | else: |
| | | key = max(self.saved_ckpts, key=self.saved_ckpts.get) |
| | | if key in self.saved_ckpts: |
| | | del self.saved_ckpts[key] |
| | | filename = os.path.join(self.output_dir, key) |
| | | logging.info(f"Delete: {filename}") |
| | | if os.path.exists(filename): |
| | | # os.remove(filename) |
| | | misc_utils.smart_remove(filename) |
| | | if self.rank == 0: |
| | | self.saved_ckpts[ckpt_name] = getattr( |
| | | self, f"val_{self.avg_keep_nbest_models_type}_step_or_epoch" |
| | | )[ckpt_name] |
| | | if self.keep_nbest_models > 0: |
| | | if len(self.saved_ckpts) > self.keep_nbest_models: |
| | | if self.avg_keep_nbest_models_type == "acc": |
| | | key = min(self.saved_ckpts, key=self.saved_ckpts.get) |
| | | else: |
| | | key = max(self.saved_ckpts, key=self.saved_ckpts.get) |
| | | if key in self.saved_ckpts: |
| | | del self.saved_ckpts[key] |
| | | filename = os.path.join(self.output_dir, key) |
| | | logging.info(f"Delete: {filename}") |
| | | if os.path.exists(filename): |
| | | # os.remove(filename) |
| | | misc_utils.smart_remove(filename) |
| | | |
| | | elif self.use_fsdp: |
| | | pass |
| | |
| | | scaled_loss = model.backward(loss) |
| | | else: |
| | | loss = loss / self.accum_grad |
| | | if self.use_fp16 or self.use_bf16: |
| | | if scaler: |
| | | scaler.scale(loss).backward() |
| | | else: |
| | | loss.backward() |
| | |
| | | # Execute an optimization step (update model parameters) |
| | | if self.use_ddp or self.use_fsdp: |
| | | dist.barrier() |
| | | if self.use_fp16 or self.use_bf16: |
| | | if scaler: |
| | | scaler.step(optim) |
| | | scaler.update() |
| | | else: |
| | |
| | | Args: |
| | | epoch (int): The current epoch number. |
| | | """ |
| | | self.val_loss_avg = 0.0 |
| | | self.val_acc_avg = 0.0 |
| | | |
| | | if self.use_ddp or self.use_fsdp or self.use_deepspeed: |
| | | dist.barrier() |
| | | logging.info(f"Validate epoch: {epoch}, rank: {self.rank}\n") |
| | |
| | | "data_split_i": kwargs.get("data_split_i", 0), |
| | | "data_split_num": kwargs.get("data_split_num", 1), |
| | | "log_step": batch_idx + kwargs.get("start_step", 0), |
| | | "batch_total": batch_idx + 1, |
| | | "batch_total": self.batch_total, |
| | | "step_in_epoch": batch_idx + 1, |
| | | "lr": 0.0, |
| | | } |
| | |
| | | if self.use_wandb and wandb is not None: |
| | | wandb.log( |
| | | description_dict, |
| | | setp=batch_total, |
| | | step=batch_total, |
| | | ) |
| | | |
| | | def close(self, writer=None): |