游雁
2024-06-12 407625a73478c2bb5e20c62bfbdf53a55a1e6575
decoding
2个文件已修改
11 ■■■■■ 已修改文件
funasr/bin/train_ds.py 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/train_utils/trainer_ds.py 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/train_ds.py
@@ -124,6 +124,7 @@
        use_ddp=use_ddp,
        use_fsdp=use_fsdp,
        device=kwargs["device"],
        excludes=kwargs.get("excludes", None),
        output_dir=kwargs.get("output_dir", "./exp"),
        **kwargs.get("train_conf"),
    )
funasr/train_utils/trainer_ds.py
@@ -147,6 +147,10 @@
        self.use_deepspeed = use_deepspeed
        self.deepspeed_config = kwargs.get("deepspeed_config", "")
        self.excludes = kwargs.get("excludes", None)
        if self.excludes is not None:
            if isinstance(self.excludes, str):
                self.excludes = self.excludes.split(",")
    def save_checkpoint(
        self,
@@ -440,6 +444,12 @@
                    src_state = checkpoint["state_dict"]
                    dst_state = model.state_dict()
                    for k in dst_state.keys():
                        if excludes is not None:
                            for k_ex in excludes:
                                k_tmp = k.replace("module.", "")
                                if k_tmp.startswith(k_ex):
                                    logging.info(f"key: {{k}} matching: {k_ex}, excluded")
                                    continue
                        if not k.startswith("module.") and "module." + k in src_state.keys():
                            k_ddp = "module." + k
                        elif k.startswith("module.") and "module." + k not in src_state.keys():