From fb176404cfeb40c053f4f42d01eb45c185d21ce2 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 08 一月 2024 16:20:45 +0800
Subject: [PATCH] funasr1.0 emotion2vec

---
 funasr/download/download_from_hub.py |   61 +++++++++++++++++++-----------
 1 files changed, 38 insertions(+), 23 deletions(-)

diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 4f05b42..abf3ba0 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -1,3 +1,4 @@
+import json
 import os
 from omegaconf import OmegaConf
 import torch
@@ -19,23 +20,34 @@
 		model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"))
 	
 	config = os.path.join(model_or_path, "config.yaml")
-	assert os.path.exists(config), "{} is not exist!".format(config)
-	cfg = OmegaConf.load(config)
-	kwargs = OmegaConf.merge(cfg, kwargs)
-	init_param = os.path.join(model_or_path, "model.pb")
-	kwargs["init_param"] = init_param
-	if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
-		kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
-	if os.path.exists(os.path.join(model_or_path, "tokens.json")):
-		kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
-	if os.path.exists(os.path.join(model_or_path, "seg_dict")):
-		kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
-	if os.path.exists(os.path.join(model_or_path, "bpe.model")):
-		kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
-	kwargs["model"] = cfg["model"]
-	if os.path.exists(os.path.join(model_or_path, "am.mvn")):
-		kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
-	
+	if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
+		# config = os.path.join(model_or_path, "config.yaml")
+		# assert os.path.exists(config), "{} is not exist!".format(config)
+		cfg = OmegaConf.load(config)
+		kwargs = OmegaConf.merge(cfg, kwargs)
+		init_param = os.path.join(model_or_path, "model.pb")
+		kwargs["init_param"] = init_param
+		if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+			kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+		if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+			kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
+		if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+			kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+		if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+			kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+		kwargs["model"] = cfg["model"]
+		if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+			kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+	else:# configuration.json
+		assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
+		with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
+			conf_json = json.load(f)
+			config = os.path.join(model_or_path, conf_json["model"]["model_config"])
+			cfg = OmegaConf.load(config)
+			kwargs = OmegaConf.merge(cfg, kwargs)
+			init_param = os.path.join(model_or_path, conf_json["model"]["model_name"])
+			kwargs["init_param"] = init_param
+		kwargs["model"] = cfg["model"]
 	return OmegaConf.to_container(kwargs, resolve=True)
 
 def get_or_download_model_dir(
@@ -60,12 +72,15 @@
 	if os.path.exists(model):
 		model_cache_dir = model if os.path.isdir(
 			model) else os.path.dirname(model)
-		check_local_model_is_latest(
-			model_cache_dir,
-			user_agent={
-				Invoke.KEY: key,
-				ThirdParty.KEY: "funasr"
-			})
+		try:
+			check_local_model_is_latest(
+				model_cache_dir,
+				user_agent={
+					Invoke.KEY: key,
+					ThirdParty.KEY: "funasr"
+				})
+		except:
+			print("could not check the latest version")
 	else:
 		model_cache_dir = snapshot_download(
 			model,

--
Gitblit v1.9.1