From bdfd27b9e96bd55c449953bb577e1d4deeaf11c9 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期六, 13 一月 2024 23:43:17 +0800
Subject: [PATCH] funasr1.0
---
funasr/download/download_from_hub.py | 34 +++++++++++++++++++++++++---------
1 files changed, 25 insertions(+), 9 deletions(-)
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 27bd79d..9779050 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -22,8 +22,8 @@
config = os.path.join(model_or_path, "config.yaml")
if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
- cfg = OmegaConf.load(config)
- kwargs = OmegaConf.merge(cfg, kwargs)
+ config = OmegaConf.load(config)
+ kwargs = OmegaConf.merge(config, kwargs)
init_param = os.path.join(model_or_path, "model.pb")
kwargs["init_param"] = init_param
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
@@ -34,7 +34,7 @@
kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
if os.path.exists(os.path.join(model_or_path, "bpe.model")):
kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
- kwargs["model"] = cfg["model"]
+ kwargs["model"] = config["model"]
if os.path.exists(os.path.join(model_or_path, "am.mvn")):
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
@@ -43,14 +43,30 @@
assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
conf_json = json.load(f)
- config = os.path.join(model_or_path, conf_json["model_config"])
- cfg = OmegaConf.load(config)
- kwargs = OmegaConf.merge(cfg, kwargs)
- init_param = os.path.join(model_or_path, conf_json["model_file"])
- kwargs["init_param"] = init_param
- kwargs["model"] = cfg["model"]
+ cfg = {}
+ add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+ cfg.update(kwargs)
+ config = OmegaConf.load(cfg["config"])
+ kwargs = OmegaConf.merge(config, cfg)
+ kwargs["model"] = config["model"]
return OmegaConf.to_container(kwargs, resolve=True)
+def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = {}):
+
+ if isinstance(file_path_metas, dict):
+ for k, v in file_path_metas.items():
+ if isinstance(v, str):
+ p = os.path.join(model_or_path, v)
+ if os.path.exists(p):
+ cfg[k] = p
+ elif isinstance(v, dict):
+ if k not in cfg:
+ cfg[k] = {}
+ return add_file_root_path(model_or_path, v, cfg[k])
+
+ return cfg
+
+
def get_or_download_model_dir(
model,
model_revision=None,
--
Gitblit v1.9.1