From 590dfdefe39baf7da18693228e1ce6bf60b23bee Mon Sep 17 00:00:00 2001 From: Shi Xian <40013335+R1ckShi@users.noreply.github.com> Date: 星期五, 01 三月 2024 15:09:55 +0800 Subject: [PATCH] Merge pull request #1411 from alibaba-damo-academy/dev_gzf --- funasr/train_utils/load_pretrained_model.py | 6 +++--- 1 files changed, 3 insertions(+), 3 deletions(-) diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py index 84c6320..ea23725 100644 --- a/funasr/train_utils/load_pretrained_model.py +++ b/funasr/train_utils/load_pretrained_model.py @@ -68,9 +68,9 @@ else: buffer = BytesIO(oss_bucket.get_object(path).read()) src_state = torch.load(buffer, map_location=map_location) - if "state_dict" in src_state: - src_state = src_state["state_dict"] - + + src_state = src_state["state_dict"] if "state_dict" in src_state else src_state + src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state src_state = src_state["model"] if "model" in src_state else src_state if isinstance(scope_map, str): -- Gitblit v1.9.1