From f1c1cb0773fca5e9d1ee595ef6ca2ff4bad9f2a4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 30 一月 2024 14:47:26 +0800
Subject: [PATCH] funasr1.0.4

---
 funasr/download/download_from_hub.py |  193 ++++++++++++++++++++++++------------------------
 1 files changed, 97 insertions(+), 96 deletions(-)

diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index 9779050..c102549 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -1,110 +1,111 @@
-import json
 import os
+import json
 from omegaconf import OmegaConf
-import torch
+
 from funasr.download.name_maps_from_hub import name_maps_ms, name_maps_hf
 
+
 def download_model(**kwargs):
-	model_hub = kwargs.get("model_hub", "ms")
-	if model_hub == "ms":
-		kwargs = download_from_ms(**kwargs)
-	
-	return kwargs
+    model_hub = kwargs.get("model_hub", "ms")
+    if model_hub == "ms":
+        kwargs = download_from_ms(**kwargs)
+    
+    return kwargs
 
 def download_from_ms(**kwargs):
-	model_or_path = kwargs.get("model")
-	if model_or_path in name_maps_ms:
-		model_or_path = name_maps_ms[model_or_path]
-	model_revision = kwargs.get("model_revision")
-	if not os.path.exists(model_or_path):
-		model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"), check_latest=kwargs.get("kwargs", True))
-	kwargs["model_path"] = model_or_path
-	
-	config = os.path.join(model_or_path, "config.yaml")
-	if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
-		config = OmegaConf.load(config)
-		kwargs = OmegaConf.merge(config, kwargs)
-		init_param = os.path.join(model_or_path, "model.pb")
-		kwargs["init_param"] = init_param
-		if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
-			kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
-		if os.path.exists(os.path.join(model_or_path, "tokens.json")):
-			kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
-		if os.path.exists(os.path.join(model_or_path, "seg_dict")):
-			kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
-		if os.path.exists(os.path.join(model_or_path, "bpe.model")):
-			kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
-		kwargs["model"] = config["model"]
-		if os.path.exists(os.path.join(model_or_path, "am.mvn")):
-			kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
-		if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
-			kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
-	else:# configuration.json
-		assert os.path.exists(os.path.join(model_or_path, "configuration.json"))
-		with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
-			conf_json = json.load(f)
-			cfg = {}
-			add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
-			cfg.update(kwargs)
-			config = OmegaConf.load(cfg["config"])
-			kwargs = OmegaConf.merge(config, cfg)
-		kwargs["model"] = config["model"]
-	return OmegaConf.to_container(kwargs, resolve=True)
+    model_or_path = kwargs.get("model")
+    if model_or_path in name_maps_ms:
+        model_or_path = name_maps_ms[model_or_path]
+    model_revision = kwargs.get("model_revision")
+    if not os.path.exists(model_or_path):
+        model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training"), check_latest=kwargs.get("kwargs", True))
+    kwargs["model_path"] = model_or_path
+    
+    config = os.path.join(model_or_path, "config.yaml")
+    if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")):
+        
+        config = OmegaConf.load(config)
+        kwargs = OmegaConf.merge(config, kwargs)
+        init_param = os.path.join(model_or_path, "model.pb")
+        kwargs["init_param"] = init_param
+        if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
+        if os.path.exists(os.path.join(model_or_path, "tokens.json")):
+            kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json")
+        if os.path.exists(os.path.join(model_or_path, "seg_dict")):
+            kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
+        if os.path.exists(os.path.join(model_or_path, "bpe.model")):
+            kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
+        kwargs["model"] = config["model"]
+        if os.path.exists(os.path.join(model_or_path, "am.mvn")):
+            kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
+        if os.path.exists(os.path.join(model_or_path, "jieba_usr_dict")):
+            kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
+    elif os.path.exists(os.path.join(model_or_path, "configuration.json")):
+        with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f:
+            conf_json = json.load(f)
+            cfg = {}
+            add_file_root_path(model_or_path, conf_json["file_path_metas"], cfg)
+            cfg.update(kwargs)
+            config = OmegaConf.load(cfg["config"])
+            kwargs = OmegaConf.merge(config, cfg)
+        kwargs["model"] = config["model"]
+    return OmegaConf.to_container(kwargs, resolve=True)
 
 def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = {}):
-	
-	if isinstance(file_path_metas, dict):
-		for k, v in file_path_metas.items():
-			if isinstance(v, str):
-				p = os.path.join(model_or_path, v)
-				if os.path.exists(p):
-					cfg[k] = p
-			elif isinstance(v, dict):
-				if k not in cfg:
-					cfg[k] = {}
-				return add_file_root_path(model_or_path, v, cfg[k])
-	
-	return cfg
+    
+    if isinstance(file_path_metas, dict):
+        for k, v in file_path_metas.items():
+            if isinstance(v, str):
+                p = os.path.join(model_or_path, v)
+                if os.path.exists(p):
+                    cfg[k] = p
+            elif isinstance(v, dict):
+                if k not in cfg:
+                    cfg[k] = {}
+                add_file_root_path(model_or_path, v, cfg[k])
+    
+    return cfg
 
 
 def get_or_download_model_dir(
-		model,
-		model_revision=None,
-		is_training=False,
-		check_latest=True,
-	):
-	""" Get local model directory or download model if necessary.
+        model,
+        model_revision=None,
+        is_training=False,
+        check_latest=True,
+    ):
+    """ Get local model directory or download model if necessary.
 
-	Args:
-		model (str): model id or path to local model directory.
-		model_revision  (str, optional): model version number.
-		:param is_training:
-	"""
-	from modelscope.hub.check_model import check_local_model_is_latest
-	from modelscope.hub.snapshot_download import snapshot_download
+    Args:
+        model (str): model id or path to local model directory.
+        model_revision  (str, optional): model version number.
+        :param is_training:
+    """
+    from modelscope.hub.check_model import check_local_model_is_latest
+    from modelscope.hub.snapshot_download import snapshot_download
 
-	from modelscope.utils.constant import Invoke, ThirdParty
-	
-	key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
-	
-	if os.path.exists(model) and check_latest:
-		model_cache_dir = model if os.path.isdir(
-			model) else os.path.dirname(model)
-		try:
-			check_local_model_is_latest(
-				model_cache_dir,
-				user_agent={
-					Invoke.KEY: key,
-					ThirdParty.KEY: "funasr"
-				})
-		except:
-			print("could not check the latest version")
-	else:
-		model_cache_dir = snapshot_download(
-			model,
-			revision=model_revision,
-			user_agent={
-				Invoke.KEY: key,
-				ThirdParty.KEY: "funasr"
-			})
-	return model_cache_dir
\ No newline at end of file
+    from modelscope.utils.constant import Invoke, ThirdParty
+    
+    key = Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE
+    
+    if os.path.exists(model) and check_latest:
+        model_cache_dir = model if os.path.isdir(
+            model) else os.path.dirname(model)
+        try:
+            check_local_model_is_latest(
+                model_cache_dir,
+                user_agent={
+                    Invoke.KEY: key,
+                    ThirdParty.KEY: "funasr"
+                })
+        except:
+            print("could not check the latest version")
+    else:
+        model_cache_dir = snapshot_download(
+            model,
+            revision=model_revision,
+            user_agent={
+                Invoke.KEY: key,
+                ThirdParty.KEY: "funasr"
+            })
+    return model_cache_dir
\ No newline at end of file

--
Gitblit v1.9.1