游雁
2024-06-12 407625a73478c2bb5e20c62bfbdf53a55a1e6575
funasr/train_utils/trainer_ds.py
@@ -147,6 +147,10 @@
        self.use_deepspeed = use_deepspeed
        self.deepspeed_config = kwargs.get("deepspeed_config", "")
        self.excludes = kwargs.get("excludes", None)
        if self.excludes is not None:
            if isinstance(self.excludes, str):
                self.excludes = self.excludes.split(",")
    def save_checkpoint(
        self,
@@ -440,6 +444,12 @@
                    src_state = checkpoint["state_dict"]
                    dst_state = model.state_dict()
                    for k in dst_state.keys():
                        if excludes is not None:
                            for k_ex in excludes:
                                k_tmp = k.replace("module.", "")
                                if k_tmp.startswith(k_ex):
                                    logging.info(f"key: {{k}} matching: {k_ex}, excluded")
                                    continue
                        if not k.startswith("module.") and "module." + k in src_state.keys():
                            k_ddp = "module." + k
                        elif k.startswith("module.") and "module." + k not in src_state.keys():