zhifu gao
2024-03-30 6fa8ee48e117fa9c3bef450e02776e8c26b278e7
funasr/train_utils/load_pretrained_model.py
@@ -47,7 +47,6 @@
   oss_bucket=None,
   scope_map=[],
   excludes=None,
   ignore_mismatch=False,
   **kwargs,
):
   """Load a model state and set it to the model.
@@ -100,7 +99,7 @@
               
      if k_src in src_state.keys():
         if ignore_init_mismatch and dst_state[k].shape != src_state[k_src].shape:
            print(f"ignore_mismatch:{ignore_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}")
            print(f"ignore_init_mismatch:{ignore_init_mismatch}, dst: {k, dst_state[k].shape}, src: {k_src, src_state[k_src].shape}")
         else:
            dst_state[k] = src_state[k_src]