From fb45c9a6ef4c5f94d8b36abafca072f62aff9b5f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 15 五月 2024 17:32:07 +0800
Subject: [PATCH] hf hub

---
 funasr/download/download_from_hub.py  |   82 ++++++++++++++++++++++++++++++++++++++++
 funasr/download/name_maps_from_hub.py |    4 +
 2 files changed, 84 insertions(+), 2 deletions(-)

diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 43f5b67..075b131 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -10,7 +10,7 @@
     if hub == "ms":
         kwargs = download_from_ms(**kwargs)
     elif hub == "hf":
-        pass
+        kwargs = download_from_hf(**kwargs)
     elif hub == "openai":
         model_or_path = kwargs.get("model")
         if os.path.exists(model_or_path):
@@ -34,6 +34,67 @@
     if not os.path.exists(model_or_path) and "model_path" not in kwargs:
         try:
             model_or_path = get_or_download_model_dir(
+                model_or_path,
+                model_revision,
+                is_training=kwargs.get("is_training"),
+                check_latest=kwargs.get("check_latest", True),
+            )
+        except Exception as e:
+            print(f"Download: {model_or_path} failed!: {e}")
+
+    kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]
+
+    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
+        with open(os.path.join(model_or_path, "configuration.json"), "r", encoding="utf-8") as f:
+            conf_json = json.load(f)
+
+            cfg = {}
+            if "file_path_metas" in conf_json:
+                add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+            cfg.update(kwargs)
+            if "config" in cfg:
+                config = OmegaConf.load(cfg["config"])
+                kwargs = OmegaConf.merge(config, cfg)
+                kwargs["model"] = config["model"]
+    elif os.path.exists(os.path.join(model_or_path, "config.yaml")) and os.path.exists(
+        os.path.join(model_or_path, "model.pt")
+    ):
+        config = OmegaConf.load(os.path.join(model_or_path, "config.yaml"))
+        kwargs = OmegaConf.merge(config, kwargs)
+        init_param = os.path.join(model_or_path, "model.pb")
+        kwargs["init_param"] = init_param
+        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
+        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+        kwargs["model"] = config["model"]
+        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
+    if isinstance(kwargs, DictConfig):
+        kwargs = OmegaConf.to_container(kwargs, resolve=True)
+    if os.path.exists(os.path.join(model_or_path, "requirements.txt")):
+        requirements = os.path.join(model_or_path, "requirements.txt")
+        print(f"Detect model requirements, begin to install it: {requirements}")
+        from funasr.utils.install_model_requirements import install_requirements
+
+        install_requirements(requirements)
+    return kwargs
+
+
+def download_from_hf(**kwargs):
+    model_or_path = kwargs.get("model")
+    if model_or_path in name_maps_hf:
+        model_or_path = name_maps_hf[model_or_path]
+    model_revision = kwargs.get("model_revision", "master")
+    if not os.path.exists(model_or_path) and "model_path" not in kwargs:
+        try:
+            model_or_path = get_or_download_model_dir_hf(
                 model_or_path,
                 model_revision,
                 is_training=kwargs.get("is_training"),
@@ -136,3 +197,22 @@
             model, revision=model_revision, user_agent={Invoke.KEY: key, ThirdParty.KEY: "funasr"}
         )
     return model_cache_dir
+
+
+def get_or_download_model_dir_hf(
+    model,
+    model_revision=None,
+    is_training=False,
+    check_latest=True,
+):
+    """Fetch ``model`` with ``huggingface_hub.snapshot_download`` and return its result.
+
+    Args:
+        model (str): passed unchanged as the only argument to ``snapshot_download``.
+        model_revision, is_training, check_latest: accepted but not read by this body.
+    Returns: whatever ``snapshot_download`` returns (presumably the local snapshot path).
+    """
+    from huggingface_hub import snapshot_download
+
+    model_cache_dir = snapshot_download(model)
+    return model_cache_dir
diff --git a/funasr/download/name_maps_from_hub.py b/funasr/download/name_maps_from_hub.py
index 87a89fc..3bb25a7 100644
--- a/funasr/download/name_maps_from_hub.py
+++ b/funasr/download/name_maps_from_hub.py
@@ -14,7 +14,9 @@
     "Qwen-Audio": "Qwen/Qwen-Audio",
 }
 
-name_maps_hf = {}
+name_maps_hf = {
+    "": "",  # NOTE(review): the only entry maps "" to "" - looks like a placeholder; confirm
+}
 
 name_maps_openai = {
     "Whisper-tiny.en": "tiny.en",

--
Gitblit v1.9.1