From fb176404cfeb40c053f4f42d01eb45c185d21ce2 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 08 一月 2024 16:20:45 +0800
Subject: [PATCH] funasr1.0 emotion2vec
---
funasr/download/download_from_hub.py | 61 +++++++++++++++++++-----------
1 files changed, 38 insertions(+), 23 deletions(-)
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 4f05b42..abf3ba0 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -1,3 +1,4 @@
+import json
import os
from omegaconf import OmegaConf
import torch
@@ -19,23 +20,34 @@
model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"))
config = os.path.join(model_or_path, "config.yaml")
- assert os.path.exists(config), "{} is not exist!".format(config)
- cfg = OmegaConf.load(config)
- kwargs = OmegaConf.merge(cfg, kwargs)
- init_param = os.path.join(model_or_path, "model.pb")
- kwargs["init_param"] = init_param
- if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
- kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
- if os.path.exists(os.path.join(model_or_path, "tokens.json")):
- kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
- if os.path.exists(os.path.join(model_or_path, "seg_dict")):
- kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
- if os.path.exists(os.path.join(model_or_path, "bpe.model")):
- kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
- kwargs["model"] = cfg["model"]
- if os.path.exists(os.path.join(model_or_path, "am.mvn")):
- kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
-
+ if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
+ # config = os.path.join(model_or_path, "config.yaml")
+ # assert os.path.exists(config), "{} is not exist!".format(config)
+ cfg = OmegaConf.load(config)
+ kwargs = OmegaConf.merge(cfg, kwargs)
+ init_param = os.path.join(model_or_path, "model.pb")
+ kwargs["init_param"] = init_param
+ if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+ if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+ kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
+ if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+ kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+ if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+ kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+ kwargs["model"] = cfg["model"]
+ if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+ kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+ else:# configuration.json
+ assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
+ with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
+ conf_json = json.load(f)
+ config = os.path.join(model_or_path, conf_json["model"]["model_config"])
+ cfg = OmegaConf.load(config)
+ kwargs = OmegaConf.merge(cfg, kwargs)
+ init_param = os.path.join(model_or_path, conf_json["model"]["model_name"])
+ kwargs["init_param"] = init_param
+ kwargs["model"] = cfg["model"]
return OmegaConf.to_container(kwargs, resolve=True)
def get_or_download_model_dir(
@@ -60,12 +72,15 @@
if os.path.exists(model):
model_cache_dir = model if os.path.isdir(
model) else os.path.dirname(model)
- check_local_model_is_latest(
- model_cache_dir,
- user_agent={
- Invoke.KEY: key,
- ThirdParty.KEY: "funasr"
- })
+ try:
+ check_local_model_is_latest(
+ model_cache_dir,
+ user_agent={
+ Invoke.KEY: key,
+ ThirdParty.KEY: "funasr"
+ })
+ except:
+ print("could not check the latest version")
else:
model_cache_dir = snapshot_download(
model,
--
Gitblit v1.9.1