zhifu gao
2024-04-24 861147c7308b91068ffa02724fdf74ee623a909e
funasr/train_utils/load_pretrained_model.py
@@ -9,6 +9,7 @@
import torch.optim
import pdb
def filter_state_dict(
   dst_state: Dict[str, Union[float, torch.Tensor]],
   src_state: Dict[str, Union[float, torch.Tensor]],
@@ -92,17 +93,21 @@
               k_src = src_prefix + k
               if not k_src.startswith("module."):
                  print(f"init param, map: {k} from {k_src} in ckpt")
            elif k.startswith(dst_prefix) and k.replace(dst_prefix, src_prefix, 1) in src_state.keys():
                elif (
                    k.startswith(dst_prefix)
                    and k.replace(dst_prefix, src_prefix, 1) in src_state.keys()
                ):
               k_src = k.replace(dst_prefix, src_prefix, 1)
               if not k_src.startswith("module."):
                  print(f"init param, map: {k} from {k_src} in ckpt")
               
      if k_src in src_state.keys():
         if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
            print(f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}")
                print(
                    f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}"
                )
         else:
            dst_state[k] = src_state[k_src]
      else:
         print(f"Warning, miss key in ckpt: {k}, mapped: {k_src}")