From 78c78c39a90c62b7c552019043a970e9f85bf378 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期二, 10 十月 2023 17:11:15 +0800
Subject: [PATCH] big fix for speaker pipeline

---
 funasr/utils/modelscope_utils.py |   26 +++++++++++++++++++++++++-
 1 files changed, 25 insertions(+), 1 deletions(-)

diff --git a/funasr/utils/modelscope_utils.py b/funasr/utils/modelscope_utils.py
index 9712e09..4179885 100644
--- a/funasr/utils/modelscope_utils.py
+++ b/funasr/utils/modelscope_utils.py
@@ -1,5 +1,6 @@
 import os
 from modelscope.hub.snapshot_download import snapshot_download
+from pathlib import Path
 
 
 def check_model_dir(model_dir, model_name: str = "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"):
@@ -13,4 +14,27 @@
 	if not os.path.exists(dst):
 		os.symlink(model_dir, dst)
 	
-	model_dir = snapshot_download(model_name, cache_dir=dst_dir_root)
\ No newline at end of file
+	model_dir = snapshot_download(model_name, cache_dir=dst_dir_root)
+
+def get_default_cache_dir():
+    """
+    default base dir: '~/.cache/modelscope'
+    """
+    default_cache_dir = Path.home().joinpath('.cache', 'modelscope')
+    return default_cache_dir
+
+def get_cache_dir(model_id):
+    """cache dir precedence:
+        function parameter > environment > ~/.cache/modelscope/hub
+
+    Args:
+        model_id (str, optional): The model id.
+
+    Returns:
+        str: the model_id dir if model_id not None, otherwise cache root dir.
+    """
+    default_cache_dir = get_default_cache_dir()
+    base_path = os.getenv('MODELSCOPE_CACHE',
+                          os.path.join(default_cache_dir, 'hub'))
+    return base_path if model_id is None else os.path.join(
+        base_path, model_id + '/')
\ No newline at end of file

--
Gitblit v1.9.1