游雁
2024-01-08 fb176404cfeb40c053f4f42d01eb45c185d21ce2
funasr/download/download_from_hub.py
@@ -1,3 +1,4 @@
import json
import os
from omegaconf import OmegaConf
import torch
@@ -19,23 +20,34 @@
      model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"))
   
   config = os.path.join(model_or_path, "config.yaml")
   assert os.path.exists(config), "{} is not exist!".format(config)
   cfg = OmegaConf.load(config)
   kwargs = OmegaConf.merge(cfg, kwargs)
   init_param = os.path.join(model_or_path, "model.pb")
   kwargs["init_param"] = init_param
   if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
      kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
   if os.path.exists(os.path.join(model_or_path, "tokens.json")):
      kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
   if os.path.exists(os.path.join(model_or_path, "seg_dict")):
      kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
   if os.path.exists(os.path.join(model_or_path, "bpe.model")):
      kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
   kwargs["model"] = cfg["model"]
   if os.path.exists(os.path.join(model_or_path, "am.mvn")):
      kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
   if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
      # config = os.path.join(model_or_path, "config.yaml")
      # assert os.path.exists(config), "{} is not exist!".format(config)
      cfg = OmegaConf.load(config)
      kwargs = OmegaConf.merge(cfg, kwargs)
      init_param = os.path.join(model_or_path, "model.pb")
      kwargs["init_param"] = init_param
      if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
         kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
      if os.path.exists(os.path.join(model_or_path, "tokens.json")):
         kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
      if os.path.exists(os.path.join(model_or_path, "seg_dict")):
         kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
      if os.path.exists(os.path.join(model_or_path, "bpe.model")):
         kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
      kwargs["model"] = cfg["model"]
      if os.path.exists(os.path.join(model_or_path, "am.mvn")):
         kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
   else:# configuration.json
      assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
      with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
         conf_json = json.load(f)
         config = os.path.join(model_or_path, conf_json["model"]["model_config"])
         cfg = OmegaConf.load(config)
         kwargs = OmegaConf.merge(cfg, kwargs)
         init_param = os.path.join(model_or_path, conf_json["model"]["model_name"])
         kwargs["init_param"] = init_param
      kwargs["model"] = cfg["model"]
   return OmegaConf.to_container(kwargs, resolve=True)
def get_or_download_model_dir(
@@ -60,12 +72,15 @@
   if os.path.exists(model):
      model_cache_dir = model if os.path.isdir(
         model) else os.path.dirname(model)
      check_local_model_is_latest(
         model_cache_dir,
         user_agent={
            Invoke.KEY: key,
            ThirdParty.KEY: "funasr"
         })
      try:
         check_local_model_is_latest(
            model_cache_dir,
            user_agent={
               Invoke.KEY: key,
               ThirdParty.KEY: "funasr"
            })
      except:
         print("could not check the latest version")
   else:
      model_cache_dir = snapshot_download(
         model,