        self.world_size = world_size
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.use_deepspeed = use_deepspeed

        self.device = kwargs.get("device", "cuda")

        self.output_dir = output_dir

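        # TensorBoard writer for training metrics; SummaryWriter is assumed to be
        # imported from torch.utils.tensorboard at module level. Fall back to None
        # if it cannot be constructed (e.g. tensorboard is not installed).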
        try:
            self.writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
        except Exception:
            self.writer = None

        self.deepspeed_config = kwargs.get("deepspeed_config", "")

    def save_checkpoint(
        self,

        iterator_stop = torch.tensor(0).to(self.device)
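        # iterator_stop is presumably all-reduced across ranks each step so that
        # every worker leaves the data loop together once any rank exhausts its
        # shard, avoiding a hang in subsequent collective ops.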

    def forward_step(self, model, batch, loss_dict=None):
        if loss_dict is None:  # avoid the shared-mutable-default pitfall
            loss_dict = {}
        dtype = torch.bfloat16
        # autocast(enabled=True) overrides any outer conditional autocast context,
        # so the bf16 context is applied directly here.
        with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
            retval = model(**batch)

        # FunASR/ESPnet-style models return (loss, stats, weight): stats is a dict
        # of scalars for logging and weight is the batch-size normalizer.
        loss, stats, weight = retval

        if self.use_deepspeed:
            # DeepSpeed engines run backward through the engine wrapper.
            model.backward(loss)
        else:
            loss.backward()

    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):

        if self.use_deepspeed:
            model.step()
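            # DeepSpeed's engine.step() applies the optimizer update, steps any
            # engine-owned LR scheduler, zeroes gradients, and honors gradient
            # accumulation boundaries, so no explicit optim/scheduler calls follow.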

        loss = loss_dict["loss"].detach().cpu().item()
        epoch = loss_dict["epoch"]
        batch_idx = loss_dict["batch_idx"]
        step_in_epoch = loss_dict["step_in_epoch"]
        batch_num_epoch = loss_dict["batch_num_epoch"]
        lr = loss_dict["lr"]

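        # These values are presumably forwarded to console logging and, when the
        # writer exists, to TensorBoard, e.g. (hypothetical tag names):
        #   if self.writer is not None:
        #       step = step_in_epoch + epoch * batch_num_epoch
        #       self.writer.add_scalar("train/loss", loss, step)
        #       self.writer.add_scalar("train/lr", lr, step)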

        if self.use_ddp:
            # DDP wrapping (the head of this call is assumed; only the
            # find_unused_parameters fragment survives in the source).
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                find_unused_parameters=kwargs.get(
                    "find_unused_parameters", False
                ),
            )
        # elif self.use_fsdp:
        #     # model = FSDP(model).cuda(local_rank)
        #
        #     def custom_auto_wrap_policy(
        #         module: nn.Module,
        #         recurse: bool,
        #         nonwrapped_numel: int,
        #         # Additional custom arguments
        #         min_num_params: int = int(1e8),
        #     ) -> bool:
        #         # Decide whether to wrap this module based on custom logic:
        #         # wrap large modules whose parameters agree on requires_grad.
        #         is_large = nonwrapped_numel >= min_num_params
        #         requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
        #         return is_large and requires_grad_uniform
        #
        #     # Configure a custom `min_num_params`
        #     my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
        #     torch.cuda.set_device(local_rank)
        #     model = FSDP(
        #         model,
        #         auto_wrap_policy=my_auto_wrap_policy,
        #         mixed_precision=None,
        #         device_id=torch.cuda.current_device(),
        #     )
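        # Note: torch.distributed.fsdp.wrap.size_based_auto_wrap_policy ships with
        # torch and implements the same size-based test, e.g.
        #   functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e5))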

        else:
            model = model.to(device=kwargs.get("device", "cuda"))

        return model

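    # warp_optim_scheduler builds the optimizer and LR scheduler from the FunASR
    # registries and, when DeepSpeed is enabled, may null either one out so that
    # deepspeed.initialize constructs its own from the ds_config instead.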
    def warp_optim_scheduler(self, model, **kwargs):
        from funasr.optimizers import optim_classes
        from funasr.schedulers import scheduler_classes
        from omegaconf import OmegaConf, DictConfig
        import json
        import deepspeed

        # optim
        logging.info("Build optim")
        # NOTE: the optimizer construction is assumed to mirror the scheduler path
        # below, with "optim"/"optim_conf" as the corresponding kwargs keys.
        optim = kwargs.get("optim")
        optim_class = optim_classes.get(optim)
        optim = optim_class(model.parameters(), **kwargs.get("optim_conf"))

        # scheduler
        scheduler = kwargs.get("scheduler")
        scheduler_class = scheduler_classes.get(scheduler)
        scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))

        if self.use_deepspeed:
            args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
            with open(self.deepspeed_config, "r") as fin:
                ds_configs = json.load(fin)
            if "optimizer" in ds_configs:
                # NOTE(xcsong): Disable the custom optimizer if one is set in the
                # ds_config; extremely useful when enabling cpu_offload, since
                # DeepSpeedCPUAdam can be 4~5x faster than torch-native Adam.
                optim = None
            if "scheduler" in ds_configs:
                scheduler = None
            else:

                def scheduler(opt):
                    return scheduler_class(opt, **kwargs.get("scheduler_conf"))

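            # deepspeed.initialize returns (engine, optimizer, dataloader, lr_scheduler);
            # the returned engine wraps the model and owns the backward()/step()
            # calls used in forward_step/update_step above.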
            model, optimizer, _, scheduler = deepspeed.initialize(
                args=args,
                model=model,