游雁
2024-06-12 6a74eb706c1f26023608b06013f6ae38bc569d6f
decoding
2个文件已修改
45 ■■■■ 已修改文件
funasr/train_utils/load_pretrained_model.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer_ds.py 41 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/load_pretrained_model.py
@@ -56,10 +56,14 @@
    logging.info(f"excludes: {excludes}")
    for k in dst_state.keys():
        excludes_flag = False
        if excludes is not None:
            for k_ex in excludes:
                if k.startswith(k_ex):
                    logging.info(f"key: {k} matching: {k_ex}, excluded")
                    excludes_flag = True
                    break
        if excludes_flag:
                    continue
        k_src = k
funasr/train_utils/trainer_ds.py
@@ -147,10 +147,16 @@
        self.use_deepspeed = use_deepspeed
        self.deepspeed_config = kwargs.get("deepspeed_config", "")
        self.excludes = kwargs.get("excludes", None)
        if self.excludes is not None:
            if isinstance(self.excludes, str):
                self.excludes = self.excludes.split(",")
        excludes = kwargs.get("excludes", None)
        if excludes is not None:
            if isinstance(excludes, str):
                excludes = excludes.split(",")
        self.excludes = excludes
        effective_save_name_excludes = kwargs.get("effective_save_name_excludes", None)
        if effective_save_name_excludes is not None:
            if isinstance(effective_save_name_excludes, str):
                effective_save_name_excludes = effective_save_name_excludes.split(",")
        self.effective_save_name_excludes = effective_save_name_excludes
    def save_checkpoint(
        self,
@@ -285,7 +291,6 @@
            # self.step_or_epoch += 1
            state = {
                "epoch": epoch,
                "state_dict": model.state_dict(),
                "optimizer": optim.state_dict(),
                "scheduler": scheduler.state_dict(),
                "saved_ckpts": self.saved_ckpts,
@@ -303,7 +308,23 @@
            }
            step = step_in_epoch
            if hasattr(model, "module"):
                state["state_dict"] = model.module.state_dict()
                state_dict = model.module.state_dict()
            else:
                state_dict = model.state_dict()
            if self.effective_save_name_excludes is not None:
                dst_state_dict = {}
                for k in state_dict.keys():
                    for k_ex in self.effective_save_name_excludes:
                        k_tmp = k.replace("module.", "")
                        if k.startswith(k_ex):
                            logging.info(f"key: {k} matching: {k_ex}, not save it")
                            break
                    else:
                        dst_state_dict[k] = state_dict[k]
                state["state_dict"] = dst_state_dict
            else:
                state["state_dict"] = state_dict
            if scaler:
                state["scaler_state"] = scaler.state_dict()
@@ -444,11 +465,15 @@
                    src_state = checkpoint["state_dict"]
                    dst_state = model.state_dict()
                    for k in dst_state.keys():
                        if excludes is not None:
                            for k_ex in excludes:
                        excludes_flag = False
                        if self.excludes is not None:
                            for k_ex in self.excludes:
                                k_tmp = k.replace("module.", "")
                                if k_tmp.startswith(k_ex):
                                    logging.info(f"key: {k} matching: {k_ex}, excluded")
                                    excludes_flag = True
                                    break
                        if excludes_flag:
                                    continue
                        if not k.startswith("module.") and "module." + k in src_state.keys():
                            k_ddp = "module." + k