funasr/train_utils/load_pretrained_model.py
@@ -68,9 +68,9 @@ else: buffer = BytesIO(oss_bucket.get_object(path).read()) src_state = torch.load(buffer, map_location=map_location) if "state_dict" in src_state: src_state = src_state["state_dict"] src_state = src_state["state_dict"] if "state_dict" in src_state else src_state src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state src_state = src_state["model"] if "model" in src_state else src_state if isinstance(scope_map, str):