| | |
| | | ) |
| | | return match_state |
| | | |
def assigment_scope_map(dst_state: dict, src_state: dict, scope_map: str = None):
    """Select (and optionally rename) checkpoint entries that have a slot in ``dst_state``.

    Args:
        dst_state: State dict of the module being initialised; only its keys are used.
        src_state: State dict read from the checkpoint.
        scope_map: Optional scope remapping given as flat pairs
            ``"src_scope1,dst_scope1,src_scope2,dst_scope2,..."`` (surrounding
            whitespace of each entry is ignored). A list/tuple holding the same
            flat pairs is also accepted. For every pair, a source key containing
            ``src_scope`` is renamed to ``dst_scope`` + the remainder of the key
            after the first occurrence of ``src_scope`` (anything before that
            occurrence is dropped). A trailing unpaired entry is ignored.

    Returns:
        dict mapping *destination* key -> source value, restricted to keys present
        in ``dst_state``. Without any complete scope pair, source keys are matched
        by identical name; with scope pairs, only successfully remapped keys are
        returned.
    """
    # Normalise scope_map into a flat list of entries.
    if scope_map is None:
        scope_items = []
    elif isinstance(scope_map, str):
        scope_items = [s.strip() for s in scope_map.split(",")] if scope_map.strip() else []
    else:
        scope_items = list(scope_map)

    if len(scope_items) % 2 == 1:
        logging.warning(
            "scope_map has an odd number of entries; ignoring dangling entry {!r}".format(scope_items[-1])
        )
    # zip() over even/odd positions silently drops a dangling last entry.
    scope_pairs = list(zip(scope_items[0::2], scope_items[1::2]))

    for src_scope, dst_scope in scope_pairs:
        logging.info('assignment_map from scope {} to {}'.format(src_scope, dst_scope))

    assignment_map = {}
    for name, var in src_state.items():
        if not scope_pairs:
            # No remapping requested: keep entries whose name exists in the destination.
            if name in dst_state:
                assignment_map[name] = var
            continue

        for src_scope, dst_scope in scope_pairs:
            idx = name.find(src_scope)
            if idx < 0:
                continue
            new_name = dst_scope + name[idx + len(src_scope):]
            if new_name in dst_state:
                # Key by the destination name so the result can be merged into dst_state.
                assignment_map[new_name] = var

    return assignment_map
| | | |
def load_pretrained_model(
    init_param: str,
    path: str,
    model: torch.nn.Module,
    ignore_init_mismatch: bool,
    map_location: str = "cpu",
    oss_bucket=None,
    scope_map=None,
    excludes=None,
):
    """Load (part of) a checkpoint into ``model`` or one of its sub-modules.

    Args:
        init_param: ``<file_path>:<src_key>:<dst_key>:<exclude_keys>``. Only
            ``<file_path>`` is required; trailing fields may be omitted and
            empty fields are treated as absent.

            * ``<src_key>``: keep only checkpoint entries under ``<src_key>.``
              and strip that prefix from their names.
            * ``<dst_key>``: dotted attribute path of the sub-module of
              ``model`` to load into (the whole ``model`` when absent).
            * ``<exclude_keys>``: comma-separated checkpoint key prefixes to
              drop; applied to the full checkpoint keys, before ``<src_key>``.
        path: Checkpoint path used only when ``<file_path>`` is empty.
        model: Module whose (sub-module) state is updated in place.
        ignore_init_mismatch: If True, the selected entries are passed through
            ``filter_state_dict(dst_state, src_state)`` before being merged.
        map_location: Forwarded to ``torch.load``.
        oss_bucket: Optional bucket object; when given, the checkpoint bytes
            are read via ``oss_bucket.get_object(path).read()`` instead of
            from the local filesystem.
        scope_map: Forwarded to ``assigment_scope_map`` to rename checkpoint
            scopes onto destination scopes.
        excludes: Exclude prefixes (comma-separated string or list) used only
            when ``init_param`` has no ``<exclude_keys>`` field.

    Raises:
        ValueError: If ``init_param`` has more than four ``:``-separated fields.

    Examples (``path`` is unused because ``init_param`` names the file)::

        load_pretrained_model("somewhere/model.pb", None, model, False)
        load_pretrained_model("somewhere/model.pb:decoder:decoder", None, model, False)
        load_pretrained_model("somewhere/model.pb:decoder:decoder:", None, model, False)
        load_pretrained_model(
            "somewhere/model.pb:decoder:decoder:decoder.embed", None, model, False
        )
        load_pretrained_model("somewhere/decoder.pb::decoder", None, model, False)
    """
    fields = init_param.split(":")
    if len(fields) > 4:
        raise ValueError(
            "init_param must look like <file_path>:<src_key>:<dst_key>:<exclude_keys>, "
            "got {!r}".format(init_param)
        )
    # Pad missing trailing fields and treat empty fields ("") as absent.
    fields = [f or None for f in fields] + [None] * (4 - len(fields))
    file_path, src_key, dst_key, init_excludes = fields
    if file_path is not None:
        path = file_path
    if init_excludes is not None:
        excludes = init_excludes

    # Resolve the destination (sub-)module from a dotted attribute path, e.g. "decoder.embed".
    obj = model
    if dst_key is not None and dst_key.strip() != "":
        for attr in dst_key.split("."):
            obj = getattr(obj, attr)

    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    if oss_bucket is None:
        src_state = torch.load(path, map_location=map_location)
    else:
        buffer = BytesIO(oss_bucket.get_object(path).read())
        src_state = torch.load(buffer, map_location=map_location)
    # Some checkpoints wrap the weights under a "model" entry.
    src_state = src_state["model"] if "model" in src_state else src_state

    if excludes is not None:
        if isinstance(excludes, str):
            excludes = excludes.split(",")
        # Empty entries must be skipped: startswith("") is True for every key.
        prefixes = tuple(e.strip() for e in excludes if e.strip())
        src_state = {k: v for k, v in src_state.items() if not k.startswith(prefixes)}

    if src_key is not None:
        # Match on "<src_key>." so e.g. "decoder" does not also select "decoder2.*".
        prefix = src_key.rstrip(".") + "."
        src_state = {
            k[len(prefix):]: v
            for k, v in src_state.items()
            if k.startswith(prefix)
        }

    dst_state = obj.state_dict()
    src_state = assigment_scope_map(dst_state, src_state, scope_map)

    if ignore_init_mismatch:
        src_state = filter_state_dict(dst_state, src_state)

    logging.debug("Loaded src_state keys: {}".format(src_state.keys()))
    logging.debug("Loaded dst_state keys: {}".format(dst_state.keys()))
    # Entries not provided by the checkpoint keep their current values.
    dst_state.update(src_state)
    obj.load_state_dict(dst_state)
| | | |