From fb45c9a6ef4c5f94d8b36abafca072f62aff9b5f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 15 May 2024 17:32:07 +0800
Subject: [PATCH] hf hub
---
funasr/download/download_from_hub.py | 82 ++++++++++++++++++++++++++++++++++++++++
funasr/download/name_maps_from_hub.py | 4 +
2 files changed, 84 insertions(+), 2 deletions(-)
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 43f5b67..075b131 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -10,7 +10,7 @@
if hub == "ms":
kwargs = download_from_ms(**kwargs)
elif hub == "hf":
- pass
+ kwargs = download_from_hf(**kwargs)
elif hub == "openai":
model_or_path = kwargs.get("model")
if os.path.exists(model_or_path):
@@ -34,6 +34,67 @@
if not os.path.exists(model_or_path) and "model_path" not in kwargs:
try:
model_or_path = get_or_download_model_dir(
+ model_or_path,
+ model_revision,
+ is_training=kwargs.get("is_training"),
+ check_latest=kwargs.get("check_latest", True),
+ )
+ except Exception as e:
+ print(f"Download: {model_or_path} failed!: {e}")
+
+ kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]
+
+ if os.path.exists(os.path.join(model_or_path, "configuration.json")):
+ with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
+ conf_json = json.load(f)
+
+ cfg = {}
+ if "file_path_metas" in conf_json:
+ add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+ cfg.update(kwargs)
+ if "config" in cfg:
+ config = OmegaConf.load(cfg["config"])
+ kwargs = OmegaConf.merge(config, cfg)
+ kwargs["model"] = config["model"]
+ elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
+ os.path.join(model_or_path, "model.pt")
+ ):
+ config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
+ kwargs = OmegaConf.merge(config, kwargs)
+ init_param = os.path.join(model_or_path, "model.pt")
+ kwargs["init_param"] = init_param
+ if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+ if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
+ if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+ kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+ if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+ kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+ kwargs["model"] = config["model"]
+ if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+ kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+ if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+ kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
+ if isinstance(kwargs, DictConfig):
+ kwargs = OmegaConf.to_container(kwargs, resolve=True)
+ if os.path.exists(os.path.join(model_or_path, "requirements.txt")):
+ requirements = os.path.join(model_or_path, "requirements.txt")
+ print(f"Detect model requirements, begin to install it: {requirements}")
+ from funasr.utils.install_model_requirements import install_requirements
+
+ install_requirements(requirements)
+ return kwargs
+
+
+def download_from_hf(**kwargs):
+ model_or_path = kwargs.get("model")
+ if model_or_path in name_maps_hf:
+ model_or_path = name_maps_hf[model_or_path]
+ model_revision = kwargs.get("model_revision", "main")
+ if not os.path.exists(model_or_path) and "model_path" not in kwargs:
+ try:
+ model_or_path = get_or_download_model_dir_hf(
model_or_path,
model_revision,
is_training=kwargs.get("is_training"),
@@ -136,3 +197,22 @@
model, revision=model_revision, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"}
)
return model_cache_dir
+
+
+def get_or_download_model_dir_hf(
+ model,
+ model_revision=None,
+ is_training=False,
+ check_latest=True,
+):
+ """Get local model directory or download model if necessary.
+
+ Args:
+ model (str): model id or path to local model directory.
+ model_revision (str, optional): model version number.
+ is_training (bool, optional): unused here; accepted for interface parity.
+ """
+ from huggingface_hub import snapshot_download
+
+ model_cache_dir = snapshot_download(model)
+ return model_cache_dir
diff --git a/funasr/download/name_maps_from_hub.py b/funasr/download/name_maps_from_hub.py
index 87a89fc..3bb25a7 100644
--- a/funasr/download/name_maps_from_hub.py
+++ b/funasr/download/name_maps_from_hub.py
@@ -14,7 +14,9 @@
"Qwen-Audio": "Qwen/Qwen-Audio",
}
-name_maps_hf = {}
+name_maps_hf = {
+ "": "",
+}
name_maps_openai = {
"Whisper-tiny.en": "tiny.en",
--
Gitblit v1.9.1