游雁
2024-02-28 7a4816651fd59ba02f780884613c1fbf52031f76
funasr/train_utils/load_pretrained_model.py
@@ -82,7 +82,7 @@
   ignore_init_mismatch: bool,
   map_location: str = "cpu",
   oss_bucket=None,
   scope_map=None,
   scope_map="module.:none",
   excludes=None,
):
   """Load a model state and set it to the model.
@@ -108,15 +108,40 @@
   
   src_state = src_state["model"] if "model" in src_state else src_state
   
   if isinstance(scope_map, str):
      scope_map = scope_map.split(",")
   for k in dst_state.keys():
      if not k.startswith("module.") and "module." + k in src_state.keys():
         k_ddp = "module." + k
      # if not k.startswith("module.") and "module." + k in src_state.keys():
      #    k_ddp = "module." + k
      # else:
      #    k_ddp = k
      k_src = k
      if scope_map is not None:
         src_prefix = ""
         dst_prefix = ""
         for i in range(0, len(scope_map), 2):
            src_prefix = scope_map[i] if scope_map[i].lower() != "none" else ""
            dst_prefix = scope_map[i+1] if scope_map[i+1].lower() != "none" else ""
            if k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix) in src_state.keys():
               k_src = k.replace(dst_prefix, src_prefix)
               print(f"init param, map: {k} from {k_src} in ckpt")
      if k_src in src_state.keys():
         dst_state[k] = src_state[k_src]
      # if k_ddp.startswith("audio_encoder"):
      #    if k_ddp.replace("audio_encoder", "encoder.model") in src_state.keys():
      #       k_ddp = k_ddp.replace("audio_encoder", "encoder.model")
      # if k_ddp.startswith("adaptor"):
      #    if k_ddp.replace("adaptor", "encoder_projector") in src_state.keys():
      #       k_ddp = k_ddp.replace("adaptor", "encoder_projector")
      # if k_ddp in src_state:
      #    dst_state[k] = src_state[k_ddp]
      else:
         k_ddp = k
      if k_ddp in src_state:
         dst_state[k] = src_state[k_ddp]
      else:
         print(f"Warning, miss key in ckpt: {k}, mapped: {k_ddp}")
         print(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")
         
   flag = obj.load_state_dict(dst_state, strict=False)
   # print(flag)