游雁
2024-05-17 d3ff05837bbc14749d09f44947633b87e8f2db0e
funasr/train_utils/trainer_ds.py
@@ -78,7 +78,7 @@
        self.world_size = world_size
        self.use_ddp = use_ddp
        self.use_fsdp = use_fsdp
        self.use_deepspeed = use_deepspeed
        self.device = kwargs.get("device", "cuda")
        self.output_dir = output_dir
@@ -136,6 +136,9 @@
            self.writer = SummaryWriter(tensorboard_dir)  # if trainer.rank == 0 else None
        except Exception:
            self.writer = None
        self.use_deepspeed = use_deepspeed
        self.deepspeed_config = kwargs.get("deepspeed_config", "")
    def save_checkpoint(
        self,
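For context: `deepspeed_config` here is expected to be a path to a DeepSpeed JSON config. A minimal, illustrative sketch of such a file written from Python (the keys are standard DeepSpeed options; the values are placeholders, not FunASR defaults):

import json

# Illustrative only: the kind of JSON the `deepspeed_config` path points at.
ds_config = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_accumulation_steps": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},
    # If an "optimizer" section is present, the trainer below hands optimizer
    # construction over to DeepSpeed (see warp_optim_scheduler).
    # "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)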
@@ -443,7 +446,8 @@
            iterator_stop = torch.tensor(0).to(self.device)
    def forward_step(self, model, batch, loss_dict={}):
-        with maybe_autocast(self.use_fp16):
+        dtype = torch.bfloat16
+        with torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False):
            retval = model(**batch)
        loss, stats, weight = retval
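A self-contained sketch of the autocast behavior this hunk switches to (toy model, not FunASR's): under `torch.cuda.amp.autocast` with `dtype=torch.bfloat16`, the forward pass runs and returns bf16 activations.

import torch
import torch.nn as nn

model = nn.Linear(8, 8).cuda()           # toy stand-in for the ASR model
x = torch.randn(2, 8, device="cuda")

with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=False):
    y = model(x)

print(y.dtype)  # torch.bfloat16: activations are cast inside the context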
@@ -465,7 +469,7 @@
            else:
                loss.backward()
-    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=loss_dict):
+    def update_step(self, model, optim, scheduler, scaler, batch_idx=0, loss_dict=None):
        if self.use_deepspeed:
            model.step()
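A minimal sketch, not FunASR's exact code, of the two update paths this hunk distinguishes: with DeepSpeed the engine owns backward/step, otherwise the usual optimizer/scheduler calls apply.

def update_step_sketch(model, optim, scheduler, loss, use_deepspeed=False):
    if use_deepspeed:
        # `model` is a DeepSpeed engine here; backward() and step() handle
        # gradient accumulation, clipping and the LR schedule internally.
        model.backward(loss)
        model.step()
    else:
        loss.backward()
        optim.step()
        scheduler.step()
        optim.zero_grad()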
@@ -613,7 +617,7 @@
        loss = loss_dict["loss"].detach().cpu().item()
        epoch = loss_dict["epoch"]
        batch_idx = loss_dict["batch_idx"]
-        step_in_epoch = loss_dict["step_in_epoch"]
+        step_in_epoch = self.step_in_epoch
        batch_num_epoch = loss_dict["batch_num_epoch"]
        lr = loss_dict["lr"]
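For orientation, an illustrative sketch of the fields this logging code reads from `loss_dict` (placeholder values); note that after this change `step_in_epoch` comes from `self.step_in_epoch` instead of the dict.

import torch

loss_dict = {
    "loss": torch.tensor(1.23),  # tensor, consumed via .detach().cpu().item()
    "epoch": 0,
    "batch_idx": 100,
    "batch_num_epoch": 5000,     # batches per epoch, used for progress reporting
    "lr": 1e-4,
}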
@@ -732,36 +736,18 @@
                    "find_unused_parameters", False
                ),
            )
        # elif self.use_fsdp:
        #     # model = FSDP(model).cuda(local_rank)
        #
        #     def custom_auto_wrap_policy(
        #         module: nn.Module,
        #         recurse: bool,
        #         nonwrapped_numel: int,
        #         # Additional custom arguments
        #         min_num_params: int = int(1e8),
        #     ) -> bool:
        #         # Decide whether to wrap this module based on custom logic
        #         is_large = nonwrapped_numel >= min_num_params
        #         requires_grad_uniform = len({p.requires_grad for p in module.parameters()}) == 1
        #         return is_large and requires_grad_uniform
        #
        #     # Configure a custom `min_num_params`
        #     my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5))
        #     torch.cuda.set_device(local_rank)
        #     model = FSDP(
        #         model,
        #         auto_wrap_policy=my_auto_wrap_policy,
        #         mixed_precision=None,
        #         device_id=torch.cuda.current_device(),
        #     )
        else:
            model = model.to(device=kwargs.get("device", "cuda"))
        return model
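The deleted comment block above hand-rolled a size-based wrap policy; for reference only (not part of this commit), a sketch of the stock torch equivalent, assuming torch>=1.12, an already-initialized process group, and `local_rank`/`model` in scope:

import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap every submodule holding >= 1e5 parameters in its own FSDP unit,
# mirroring what the removed custom_auto_wrap_policy was approximating.
auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=int(1e5))

torch.cuda.set_device(local_rank)  # local_rank: this process's GPU index (assumed in scope)
model = FSDP(
    model,
    auto_wrap_policy=auto_wrap_policy,
    device_id=torch.cuda.current_device(),
)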
    def warp_optim_scheduler(self, model, **kwargs):
        from funasr.optimizers import optim_classes
        from funasr.schedulers import scheduler_classes
        from omegaconf import OmegaConf, DictConfig
        import json
        import deepspeed
        # optim
        logging.info("Build optim")
@@ -777,15 +763,16 @@
        scheduler_class = scheduler_classes.get(scheduler)
        scheduler = scheduler_class(optim, **kwargs.get("scheduler_conf"))
-        if use_deepspeed:
-            deepspeed_config = kwargs.get("deepspeed_config", "")
-            with open(deepspeed_config, "r") as fin:
+        if self.use_deepspeed:
+            args = OmegaConf.create({"deepspeed_config": self.deepspeed_config})
+            with open(self.deepspeed_config, "r") as fin:
                ds_configs = json.load(fin)
            if "optimizer" in ds_configs:
                # NOTE(xcsong): Disable the custom optimizer if one is set in ds_config.
                # This is extremely useful when cpu_offload is enabled, since
                # DeepSpeedCPUAdam can be 4~5x faster than torch's native Adam.
                deepspeed_config = None
                optim = None
                if "scheduler" in ds_configs:
                    scheduler = None
                else:
@@ -793,7 +780,6 @@
                    def scheduler(opt):
                        return scheduler_class(opt, **kwargs.get("scheduler_conf"))
-            args = OmegaConf.create({"deepspeed_config": deepspeed_config})
            model, optimizer, _, scheduler = deepspeed.initialize(
                args=args,
                model=model,