| | |
| | | with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False): |
| | | yield |
| | | else: |
| | | if dtype == torch.float16: |
| | | with autocast(enabled=True): |
| | | if dtype == torch.float16 or dtype == torch.bfloat16: |
| | | with autocast(enabled=True, dtype=dtype): |
| | | yield |
| | | else: |
| | | yield |
| | |
| | | use_ddp: bool = False, |
| | | use_fsdp: bool = False, |
| | | use_fp16: bool = False, |
| | | use_bf16: bool = False, |
| | | use_deepspeed: bool = False, |
| | | output_dir: str = "./", |
| | | **kwargs, |
| | |
| | | self.batch_total = 0 |
| | | self.dtype = torch.float32 |
| | | self.use_fp16 = use_fp16 |
| | | self.use_bf16 = use_bf16 |
| | | if self.use_fp16: |
| | | self.dtype = torch.float16 |
| | | if self.use_bf16: |
| | | self.dtype = torch.bfloat16 |
| | | self.save_checkpoint_interval = kwargs.get("save_checkpoint_interval", 5000) |
| | | self.validate_interval = kwargs.get("validate_interval", 5000) |
| | | self.keep_nbest_models = kwargs.get("keep_nbest_models", 500) |
| | |
| | | scaled_loss = model.backward(loss) |
| | | else: |
| | | loss = loss / self.accum_grad |
| | | if self.use_fp16: |
| | | if self.use_fp16 or self.use_bf16: |
| | | scaler.scale(loss).backward() |
| | | else: |
| | | loss.backward() |
| | |
| | | # Execute an optimization step (update model parameters) |
| | | if self.use_ddp or self.use_fsdp: |
| | | dist.barrier() |
| | | if self.use_fp16: |
| | | if self.use_fp16 or self.use_bf16: |
| | | scaler.step(optim) |
| | | scaler.update() |
| | | else: |