From cc41a9ee88a8dca027a34c37cc1c67f8198958b9 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 01 三月 2024 14:58:36 +0800
Subject: [PATCH] whisper

---
 funasr/train_utils/load_pretrained_model.py |    6 +++---
 1 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 84c6320..ea23725 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -68,9 +68,9 @@
 	else:
 		buffer = BytesIO(oss_bucket.get_object(path).read())
 		src_state = torch.load(buffer, map_location=map_location)
-	if "state_dict" in src_state:
-		src_state = src_state["state_dict"]
-	
+		
+	src_state = src_state["state_dict"] if "state_dict" in src_state else src_state
+	src_state = src_state["model_state_dict"] if "model_state_dict" in src_state else src_state
 	src_state = src_state["model"] if "model" in src_state else src_state
 	
 	if isinstance(scope_map, str):

--
Gitblit v1.9.1