游雁
2024-03-01 cc41a9ee88a8dca027a34c37cc1c67f8198958b9
funasr/train_utils/load_pretrained_model.py
@@ -68,9 +68,9 @@
   else:
      buffer = BytesIO(oss_bucket.get_object(path).read())
      src_state = torch.load(buffer, map_location=map_location)
   if "state_dict" in src_state:
      src_state = src_state["state_dict"]
   src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
   src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
   src_state = src_state["model"] if "model" in src_state else src_state
   
   if isinstance(scope_map, str):