From 9ba0dbd98bf69c830dfcfde8f109a400cb65e4e5 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 29 三月 2024 17:24:59 +0800
Subject: [PATCH] fix func Forward

---
 funasr/train_utils/load_pretrained_model.py |  189 +++++++++++++++++++++-------------------------
 1 files changed, 87 insertions(+), 102 deletions(-)

diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index a6596a0..0c46449 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -7,121 +7,106 @@
 import torch
 import torch.nn
 import torch.optim
-
+import pdb
 
 def filter_state_dict(
-    dst_state: Dict[str, Union[float, torch.Tensor]],
-    src_state: Dict[str, Union[float, torch.Tensor]],
+	dst_state: Dict[str, Union[float, torch.Tensor]],
+	src_state: Dict[str, Union[float, torch.Tensor]],
 ):
-    """Filter name, size mismatch instances between dicts.
+	"""Filter name, size mismatch instances between dicts.
 
-    Args:
-        dst_state: reference state dict for filtering
-        src_state: target state dict for filtering
+	Args:
+		dst_state: reference state dict for filtering
+		src_state: target state dict for filtering
 
-    """
-    match_state = {}
-    for key, value in src_state.items():
-        if key in dst_state and (dst_state[key].size() == src_state[key].size()):
-            match_state[key] = value
-        else:
-            if key not in dst_state:
-                logging.warning(
-                    f"Filter out {key} from pretrained dict"
-                    + " because of name not found in target dict"
-                )
-            else:
-                logging.warning(
-                    f"Filter out {key} from pretrained dict"
-                    + " because of size mismatch"
-                    + f"({dst_state[key].size()}-{src_state[key].size()})"
-                )
-    return match_state
+	"""
+	match_state = {}
+	for key, value in src_state.items():
+		if key in dst_state and (dst_state[key].size() == src_state[key].size()):
+			match_state[key] = value
+		else:
+			if key not in dst_state:
+				logging.warning(
+					f"Filter out {key} from pretrained dict"
+					+ " because of name not found in target dict"
+				)
+			else:
+				logging.warning(
+					f"Filter out {key} from pretrained dict"
+					+ " because of size mismatch"
+					+ f"({dst_state[key].size()}-{src_state[key].size()})"
+				)
+	return match_state
 
 
 def load_pretrained_model(
-    init_param: str,
-    model: torch.nn.Module,
-    ignore_init_mismatch: bool,
-    map_location: str = "cpu",
-    oss_bucket=None,
+	path: str,
+	model: torch.nn.Module,
+	ignore_init_mismatch: bool=True,
+	map_location: str = "cpu",
+	oss_bucket=None,
+	scope_map=[],
+	excludes=None,
+	ignore_mismatch=False,
+	**kwargs,
 ):
-    """Load a model state and set it to the model.
+	"""Load a model state and set it to the model.
 
-    Args:
-        init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
+	Args:
+		init_param: <file_path>:<src_key>:<dst_key>:<exclude_Keys>
 
-    Examples:
-        >>> load_pretrained_model("somewhere/model.pb", model)
-        >>> load_pretrained_model("somewhere/model.pb:decoder:decoder", model)
-        >>> load_pretrained_model("somewhere/model.pb:decoder:decoder:", model)
-        >>> load_pretrained_model(
-        ...     "somewhere/model.pb:decoder:decoder:decoder.embed", model
-        ... )
-        >>> load_pretrained_model("somewhere/decoder.pb::decoder", model)
-    """
-    sps = init_param.split(":", 4)
-    if len(sps) == 4:
-        path, src_key, dst_key, excludes = sps
-    elif len(sps) == 3:
-        path, src_key, dst_key = sps
-        excludes = None
-    elif len(sps) == 2:
-        path, src_key = sps
-        dst_key, excludes = None, None
-    else:
-        (path,) = sps
-        src_key, dst_key, excludes = None, None, None
-    if src_key == "":
-        src_key = None
-    if dst_key == "":
-        dst_key = None
+	Examples:
 
-    if dst_key is None:
-        obj = model
-    else:
+	"""
+	
+	obj = model
+	dst_state = obj.state_dict()
+	
+	print(f"ckpt: {path}")
 
-        def get_attr(obj: Any, key: str):
-            """Get an nested attribute.
+	if oss_bucket is None:
+		src_state = torch.load(path, map_location=map_location)
+	else:
+		buffer = BytesIO(oss_bucket.get_object(path).read())
+		src_state = torch.load(buffer, map_location=map_location)
+		
+	src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
+	src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
+	src_state = src_state["model"] if "model" in src_state else src_state
+	
+	if isinstance(scope_map, str):
+		scope_map = scope_map.split(",")
+	scope_map += ["module.", "None"]
+	
+	for k in dst_state.keys():
+		
+		k_src = k
 
-            >>> class A(torch.nn.Module):
-            ...     def __init__(self):
-            ...         super().__init__()
-            ...         self.linear = torch.nn.Linear(10, 10)
-            >>> a = A()
-            >>> assert A.linear.weight is get_attr(A, 'linear.weight')
+		if scope_map is not None:
+			src_prefix = ""
+			dst_prefix = ""
+			for i in range(0, len(scope_map), 2):
+				src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
+				dst_prefix = scope_map[i+1] if scope_map[i+1].lower() != "none" else ""
+				
+				if dst_prefix == "" and (src_prefix + k) in src_state.keys():
+					k_src = src_prefix + k
+					if not k_src.startswith("module."):
+						print(f"init param, map: {k} from {k_src} in ckpt")
+				elif k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state.keys():
+					k_src = k.replace(dst_prefix, src_prefix, 1)
+					if not k_src.startswith("module."):
+						print(f"init param, map: {k} from {k_src} in ckpt")
+					
+		if k_src in src_state.keys():
+			if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
+				print(f"ignore_mismatch:{ignore_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}")
+			else:
+				dst_state[k] = src_state[k_src]
 
-            """
-            if key.strip() == "":
-                return obj
-            for k in key.split("."):
-                obj = getattr(obj, k)
-            return obj
 
-        obj = get_attr(model, dst_key)
-
-    if oss_bucket is None:
-        src_state = torch.load(path, map_location=map_location)
-    else:
-        buffer = BytesIO(oss_bucket.get_object(path).read())
-        src_state = torch.load(buffer, map_location=map_location)
-    src_state = src_state["model"] if "model" in src_state else src_state
-    if excludes is not None:
-        for e in excludes.split(","):
-            src_state = {k: v for k, v in src_state.items() if not k.startswith(e)}
-
-    if src_key is not None:
-        src_state = {
-            k[len(src_key) + 1 :]: v
-            for k, v in src_state.items()
-            if k.startswith(src_key)
-        }
-
-    dst_state = obj.state_dict()
-    if ignore_init_mismatch:
-        src_state = filter_state_dict(dst_state, src_state)
-
-    logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
-    logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
-    dst_state.update(src_state)
-    obj.load_state_dict(dst_state)
+		else:
+			print(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")
+			
+	flag = obj.load_state_dict(dst_state, strict=True)
+	# print(flag)

--
Gitblit v1.9.1