From 5d916b5a8a68fa85e79f16c2df5c45871b5298e6 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 28 二月 2024 15:24:30 +0800
Subject: [PATCH] init param
---
funasr/train_utils/load_pretrained_model.py | 57 +++++++++++++++++++++++++++++++++++++++++++--------------
1 files changed, 43 insertions(+), 14 deletions(-)
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 16feabd..520aaca 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -75,13 +75,14 @@
return assignment_map
+
def load_pretrained_model(
path: str,
model: torch.nn.Module,
ignore_init_mismatch: bool,
map_location: str = "cpu",
oss_bucket=None,
- scope_map=None,
+ scope_map="module.:none",
excludes=None,
):
"""Load a model state and set it to the model.
@@ -94,25 +95,53 @@
"""
obj = model
+ dst_state = obj.state_dict()
+ print(f"ckpt: {path}")
if oss_bucket is None:
src_state = torch.load(path, map_location=map_location)
else:
buffer = BytesIO(oss_bucket.get_object(path).read())
src_state = torch.load(buffer, map_location=map_location)
+ if "state_dict" in src_state:
+ src_state = src_state["state_dict"]
+
src_state = src_state["model"] if "model" in src_state else src_state
- if excludes is not None:
- for e in excludes.split(","):
- src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
+ if isinstance(scope_map, str):
+ scope_map = scope_map.split(",")
- dst_state = obj.state_dict()
- src_state = assigment_scope_map(dst_state, src_state, scope_map)
-
- if ignore_init_mismatch:
- src_state = filter_state_dict(dst_state, src_state)
-
- logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
- logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
- # dst_state.update(src_state)
- obj.load_state_dict(dst_state)
\ No newline at end of file
+ for k in dst_state.keys():
+ # if not k.startswith("module.") and "module." + k in src_state.keys():
+ # k_ddp = "module." + k
+ # else:
+ # k_ddp = k
+ k_src = k
+
+ if scope_map is not None:
+ src_prefix = ""
+ dst_prefix = ""
+ for i in range(0, len(scope_map), 2):
+ src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
+ dst_prefix = scope_map[i+1] if scope_map[i+1].lower() != "none" else ""
+
+ if k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix) in src_state.keys():
+ k_src = k.replace(dst_prefix, src_prefix)
+ print(f"init param, map: {k} from {k_src} in ckpt")
+
+ if k_src in src_state.keys():
+ dst_state[k] = src_state[k_src]
+
+ # if k_ddp.startswith("audio_encoder"):
+ # if k_ddp.replace("audio_encoder", "encoder.model") in src_state.keys():
+ # k_ddp = k_ddp.replace("audio_encoder", "encoder.model")
+ # if k_ddp.startswith("adaptor"):
+ # if k_ddp.replace("adaptor", "encoder_projector") in src_state.keys():
+ # k_ddp = k_ddp.replace("adaptor", "encoder_projector")
+ # if k_ddp in src_state:
+ # dst_state[k] = src_state[k_ddp]
+ else:
+ print(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")
+
+ flag = obj.load_state_dict(dst_state, strict=False)
+ # print(flag)
--
Gitblit v1.9.1