游雁
2024-06-12 407625a73478c2bb5e20c62bfbdf53a55a1e6575
funasr/train_utils/trainer_ds.py
@@ -147,6 +147,10 @@
        self.use_deepspeed = use_deepspeed
        self.deepspeed_config = kwargs.get("deepspeed_config", "")
        self.excludes = kwargs.get("excludes", None)
        if self.excludes is not None:
            if isinstance(self.excludes, str):
                self.excludes = self.excludes.split(",")
    def save_checkpoint(
        self,
@@ -167,6 +171,8 @@
        Args:
            epoch (int): The epoch number at which the checkpoint is being saved.
        """
        if self.use_ddp or self.use_fsdp:
            dist.barrier()
        step_in_epoch = None if step is None else step_in_epoch
        if self.use_deepspeed:
@@ -438,6 +444,12 @@
                    src_state = checkpoint["state_dict"]
                    dst_state = model.state_dict()
                    for k in dst_state.keys():
                        # Optional exclusion filter: each destination key, with every
                        # "module." substring removed, is compared against the prefixes
                        # in `excludes` (bound outside the lines visible here).
                        if excludes is not None:
                            for k_ex in excludes:
                                # str.replace removes every "module." occurrence in the
                                # key, not only a leading one.
                                k_tmp = k.replace("module.", "")
                                if k_tmp.startswith(k_ex):
                                    # NOTE(review): `{{k}}` is an escaped brace pair inside the
                                    # f-string, so this logs the literal text "{k}" rather than
                                    # the key itself -- presumably `{k}` was intended; confirm.
                                    logging.info(f"key: {{k}} matching: {k_ex}, excluded")
                                    # NOTE(review): this `continue` belongs to the inner
                                    # `for k_ex` loop, so it only advances to the next prefix;
                                    # it does not skip the rest of the outer `for k` iteration.
                                    # Presumably the intent was to skip excluded keys -- confirm
                                    # against the remainder of the loop body.
                                    continue
                        if not k.startswith("module.") and "module." + k in src_state.keys():
                            k_ddp = "module." + k
                        elif k.startswith("module.") and "module." + k not in src_state.keys():
@@ -621,7 +633,6 @@
            self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
    # NOTE(review): `loss_dict={}` is a mutable default argument; the same dict
    # object is shared by every call that omits `loss_dict` -- confirm callers
    # always pass their own dict.
    def forward_step(self, model, batch, loss_dict={}):
        # NOTE(review): in the lines visible here this local `dtype` is never read;
        # the autocast call below passes `self.dtype` instead.
        dtype = torch.bfloat16
        # The forward pass runs inside the `maybe_autocast` context, configured from
        # `self.dtype` and `self.use_deepspeed`.
        with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
            # `batch` is unpacked into keyword arguments for the model call.
            retval = model(**batch)
@@ -761,6 +772,10 @@
            ckpt_name = f'model.pt.ep{epoch}.{kwargs.get("step_in_epoch")}'
        self.val_acc_step_or_eoch[ckpt_name] = self.val_acc_avg
        self.val_loss_step_or_eoch[ckpt_name] = self.val_loss_avg
        if self.use_ddp or self.use_fsdp or self.use_deepspeed:
            dist.barrier()
        model.train()
    def log(