From bdfd27b9e96bd55c449953bb577e1d4deeaf11c9 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期六, 13 一月 2024 23:43:17 +0800
Subject: [PATCH] funasr1.0

---
 funasr/download/download_from_hub.py |   34 +++++++++++++++++++++++++---------
 1 files changed, 25 insertions(+), 9 deletions(-)

diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 27bd79d..9779050 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -22,8 +22,8 @@
 	
 	config = os.path.join(model_or_path, "config.yaml")
 	if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
-		cfg = OmegaConf.load(config)
-		kwargs = OmegaConf.merge(cfg, kwargs)
+		config = OmegaConf.load(config)
+		kwargs = OmegaConf.merge(config, kwargs)
 		init_param = os.path.join(model_or_path, "model.pb")
 		kwargs["init_param"] = init_param
 		if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
@@ -34,7 +34,7 @@
 			kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
 		if os.path.exists(os.path.join(model_or_path, "bpe.model")):
 			kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
-		kwargs["model"] = cfg["model"]
+		kwargs["model"] = config["model"]
 		if os.path.exists(os.path.join(model_or_path, "am.mvn")):
 			kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
 		if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
@@ -43,14 +43,30 @@
 		assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
 		with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
 			conf_json = json.load(f)
-			config = os.path.join(model_or_path, conf_json["model_config"])
-			cfg = OmegaConf.load(config)
-			kwargs = OmegaConf.merge(cfg, kwargs)
-			init_param = os.path.join(model_or_path, conf_json["model_file"])
-			kwargs["init_param"] = init_param
-		kwargs["model"] = cfg["model"]
+			cfg = {}
+			add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+			cfg.update(kwargs)
+			config = OmegaConf.load(cfg["config"])
+			kwargs = OmegaConf.merge(config, cfg)
+		kwargs["model"] = config["model"]
 	return OmegaConf.to_container(kwargs, resolve=True)
 
+def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = {}):
+	
+	if isinstance(file_path_metas, dict):
+		for k, v in file_path_metas.items():
+			if isinstance(v, str):
+				p = os.path.join(model_or_path, v)
+				if os.path.exists(p):
+					cfg[k] = p
+			elif isinstance(v, dict):
+				if k not in cfg:
+					cfg[k] = {}
+				return add_file_root_path(model_or_path, v, cfg[k])
+	
+	return cfg
+
+
 def get_or_download_model_dir(
 		model,
 		model_revision=None,

--
Gitblit v1.9.1