From 96e4ff1870656b6b9d10de5f1a994959b286b909 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 20 二月 2024 18:38:51 +0800
Subject: [PATCH] train finetune
---
funasr/download/download_from_hub.py | 24 +++++++++++-------------
1 files changed, 11 insertions(+), 13 deletions(-)
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index c102549..4a8e57a 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -21,10 +21,17 @@
model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"), check_latest=kwargs.get("kwargs", True))
kwargs["model_path"] = model_or_path
- config = os.path.join(model_or_path, "config.yaml")
- if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
-
- config = OmegaConf.load(config)
+ if os.path.exists(os.path.join(model_or_path, "configuration.json")):
+ with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
+ conf_json = json.load(f)
+ cfg = {}
+ add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+ cfg.update(kwargs)
+ config = OmegaConf.load(cfg["config"])
+ kwargs = OmegaConf.merge(config, cfg)
+ kwargs["model"] = config["model"]
+ elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(os.path.join(model_or_path, "model.pt")):
+ config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
kwargs = OmegaConf.merge(config, kwargs)
init_param = os.path.join(model_or_path, "model.pb")
kwargs["init_param"] = init_param
@@ -41,15 +48,6 @@
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
- elif os.path.exists(os.path.join(model_or_path, "configuration.json")):
- with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
- conf_json = json.load(f)
- cfg = {}
- add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
- cfg.update(kwargs)
- config = OmegaConf.load(cfg["config"])
- kwargs = OmegaConf.merge(config, cfg)
- kwargs["model"] = config["model"]
return OmegaConf.to_container(kwargs, resolve=True)
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = {}):
--
Gitblit v1.9.1