| | |
| | | use_ddp=use_ddp, |
| | | use_fsdp=use_fsdp, |
| | | device=kwargs["device"], |
| | | excludes=kwargs.get("excludes", None), |
| | | output_dir=kwargs.get("output_dir", "./exp"), |
| | | **kwargs.get("train_conf"), |
| | | ) |
| | |
| | | |
| | | self.use_deepspeed = use_deepspeed |
| | | self.deepspeed_config = kwargs.get("deepspeed_config", "") |
| | | self.excludes = kwargs.get("excludes", None) |
| | | if self.excludes is not None: |
| | | if isinstance(self.excludes, str): |
| | | self.excludes = self.excludes.split(",") |
| | | |
| | | def save_checkpoint( |
| | | self, |
| | |
| | | src_state = checkpoint["state_dict"] |
| | | dst_state = model.state_dict() |
| | | for k in dst_state.keys(): |
| | | if excludes is not None: |
| | | for k_ex in excludes: |
| | | k_tmp = k.replace("module.", "") |
| | | if k_tmp.startswith(k_ex): |
| | | logging.info(f"key: {{k}} matching: {k_ex}, excluded") |
| | | continue |
| | | if not k.startswith("module.") and "module." + k in src_state.keys(): |
| | | k_ddp = "module." + k |
| | | elif k.startswith("module.") and "module." + k not in src_state.keys(): |