From eb3d5c78bf764799f98ba5b19307831efd62285d Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 25 三月 2024 09:59:20 +0800
Subject: [PATCH] install requirements automatically
---
funasr/download/download_from_hub.py | 3 +++
funasr/utils/install_model_requirements.py | 35 +++++++++++++++++++++++++++++++++++
2 files changed, 38 insertions(+), 0 deletions(-)
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index ef2832f..f23be0d 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -72,6 +72,9 @@
kwargs["jieba_usr_dict"] = os.path.join(model_or_path, "jieba_usr_dict")
if isinstance(kwargs, DictConfig):
kwargs = OmegaConf.to_container(kwargs, resolve=True)
+ if os.path.exists(os.path.join(model_or_path, "requirements.txt")):
+ from funasr.utils.install_model_requirements import install_requirements
+ install_requirements(os.path.join(model_or_path, "requirements.txt"))
return kwargs
def add_file_root_path(model_or_path: str, file_path_metas: dict, cfg = {}):
diff --git a/funasr/utils/install_model_requirements.py b/funasr/utils/install_model_requirements.py
new file mode 100644
index 0000000..b67345d
--- /dev/null
+++ b/funasr/utils/install_model_requirements.py
@@ -0,0 +1,35 @@
+import subprocess
+
+def install_requirements(requirements_path):
+ try:
+ result = subprocess.run(
+ ['pip', 'install', '-r', requirements_path],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True
+ )
+
+ # check status
+ if result.returncode == 0:
+ print("install model requirements successfully")
+ return True
+ else:
+ print("fail to install model requirements! ")
+ print("error", result.stderr)
+ return False
+ except Exception as e:
+ result = subprocess.run(
+ ['pip', 'install', '-r', requirements_path],
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True
+ )
+
+ # check status
+ if result.returncode == 0:
+ print("install model requirements successfully")
+ return True
+ else:
+ print("fail to install model requirements! ")
+ print("error", result.stderr)
+ return False
--
Gitblit v1.9.1