From f14f9f8d15037c7b81cbdc880d61d05e23382a8f Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 09 一月 2024 00:13:51 +0800
Subject: [PATCH] funasr1.0 infer url modelscope

---
 funasr/utils/load_utils.py |   29 ++++++++++++++++++++++++++---
 1 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 4fb27c0..c5c3ffc 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -9,7 +9,12 @@
 import time
 import logging
 from torch.nn.utils.rnn import pad_sequence
-
+try:
+	from urllib.parse import urlparse
+	from funasr.download.file import HTTPStorage
+	import tempfile
+except:
+	print("urllib is not installed, if you infer from url, please install it first.")
 # def load_audio(data_or_path_or_list, fs: int=16000, audio_fs: int=16000):
 #
 # 	if isinstance(data_or_path_or_list, (list, tuple)):
@@ -43,7 +48,8 @@
 			return data_or_path_or_list_ret
 		else:
 			return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs) for audio in data_or_path_or_list]
-	
+	if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):
+		data_or_path_or_list = download_from_url(data_or_path_or_list)
 	if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):
 		data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
 		data_or_path_or_list = data_or_path_or_list[0, :]
@@ -99,4 +105,21 @@
 	
 	if isinstance(data_len, (list, tuple)):
 		data_len = torch.tensor([data_len])
-	return data.to(torch.float32), data_len.to(torch.int32)
\ No newline at end of file
+	return data.to(torch.float32), data_len.to(torch.int32)
+
+def download_from_url(url):
+	
+	result = urlparse(url)
+	file_path = None
+	if result.scheme is not None and len(result.scheme) > 0:
+		storage = HTTPStorage()
+		# bytes
+		data = storage.read(url)
+		work_dir = tempfile.TemporaryDirectory().name
+		if not os.path.exists(work_dir):
+			os.makedirs(work_dir)
+		file_path = os.path.join(work_dir, os.path.basename(url))
+		with open(file_path, 'wb') as fb:
+			fb.write(data)
+	assert file_path is not None, f"failed to download: {url}"
+	return file_path
\ No newline at end of file

--
Gitblit v1.9.1