游雁
2024-01-08 fb176404cfeb40c053f4f42d01eb45c185d21ce2
funasr/download/download_from_hub.py
@@ -1,3 +1,4 @@
import json
import os
from omegaconf import OmegaConf
import torch
@@ -19,7 +20,9 @@
      model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"))
   
   config = os.path.join(model_or_path, "config.yaml")
   assert os.path.exists(config), "{} is not exist!".format(config)
   if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
      # config = os.path.join(model_or_path, "config.yaml")
      # assert os.path.exists(config), "{} is not exist!".format(config)
   cfg = OmegaConf.load(config)
   kwargs = OmegaConf.merge(cfg, kwargs)
   init_param = os.path.join(model_or_path, "model.pb")
@@ -35,7 +38,16 @@
   kwargs["model"] = cfg["model"]
   if os.path.exists(os.path.join(model_or_path, "am.mvn")):
      kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
   else:# configuration.json
      assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
      with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
         conf_json = json.load(f)
         config = os.path.join(model_or_path, conf_json["model"]["model_config"])
         cfg = OmegaConf.load(config)
         kwargs = OmegaConf.merge(cfg, kwargs)
         init_param = os.path.join(model_or_path, conf_json["model"]["model_name"])
         kwargs["init_param"] = init_param
      kwargs["model"] = cfg["model"]
   return OmegaConf.to_container(kwargs, resolve=True)
def get_or_download_model_dir(
@@ -60,12 +72,15 @@
   if os.path.exists(model):
      model_cache_dir = model if os.path.isdir(
         model) else os.path.dirname(model)
      try:
      check_local_model_is_latest(
         model_cache_dir,
         user_agent={
            Invoke.KEY: key,
            ThirdParty.KEY: "funasr"
         })
      except:
         print("could not check the latest version")
   else:
      model_cache_dir = snapshot_download(
         model,