From 4e7f5d075f98eb4f837f21be1dc498912ae31830 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 01 三月 2024 16:24:04 +0800
Subject: [PATCH] update deploy tools

---
 funasr/train_utils/load_pretrained_model.py |    6 +++---
 1 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 84c6320..ea23725 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -68,9 +68,9 @@
 	else:
 		buffer = BytesIO(oss_bucket.get_object(path).read())
 		src_state = torch.load(buffer, map_location=map_location)
-	if "state_dict" in src_state:
-		src_state = src_state["state_dict"]
-	
+		
+	src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
+	src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
 	src_state = src_state["model"] if "model" in src_state else src_state
 	
 	if isinstance(scope_map, str):

--
Gitblit v1.9.1