From d8b586e02cd14f7eed6b330bd4f110cb1e7f24ad Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 09 一月 2024 20:33:12 +0800
Subject: [PATCH] funasr1.0  modelscope

---
 funasr/bin/inference.py |   17 +++++++++--------
 1 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index 5b58907..dedaf7d 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -20,6 +20,7 @@
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 from funasr.utils.vad_utils import slice_padding_audio_samples
 from funasr.utils.timestamp_tools import time_stamp_sentence
+from funasr.download.file import download_from_url
 
 def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
 	"""
@@ -35,7 +36,8 @@
 	filelist = [".scp", ".txt", ".json", ".jsonl"]
 	
 	chars = string.ascii_letters + string.digits
-	
+	if isinstance(data_in, str) and data_in.startswith('http'): # url
+		data_in = download_from_url(data_in)
 	if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
 		_, file_extension = os.path.splitext(data_in)
 		file_extension = file_extension.lower()
@@ -59,7 +61,7 @@
 			data_list = [data_in]
 			key_list = [key]
 	elif isinstance(data_in, (list, tuple)):
-		if data_type is not None and isinstance(data_type, (list, tuple)):
+		if data_type is not None and isinstance(data_type, (list, tuple)): # mutiple inputs
 			data_list_tmp = []
 			for data_in_i, data_type_i in zip(data_in, data_type):
 				key_list, data_list_i = prepare_data_iterator(data_in=data_in_i, data_type=data_type_i)
@@ -68,7 +70,7 @@
 			for item in zip(*data_list_tmp):
 				data_list.append(item)
 		else:
-			# [audio sample point, fbank]
+			# [audio sample point, fbank, text]
 			data_list = data_in
 			key_list = ["rand_key_" + ''.join(random.choice(chars) for _ in range(13)) for _ in range(len(data_in))]
 	else: # raw text; audio sample point, fbank; bytes
@@ -198,13 +200,12 @@
 		kwargs = self.kwargs if kwargs is None else kwargs
 		kwargs.update(cfg)
 		model = self.model if model is None else model
-		
-		data_type = kwargs.get("data_type", "sound")
+
 		batch_size = kwargs.get("batch_size", 1)
 		# if kwargs.get("device", "cpu") == "cpu":
 		# 	batch_size = 1
 		
-		key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=data_type, key=key)
+		key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
 		
 		speed_stats = {}
 		asr_result_list = []
@@ -268,8 +269,8 @@
 		batch_size = int(kwargs.get("batch_size_s", 300))*1000
 		batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
 		kwargs["batch_size"] = batch_size
-		data_type = kwargs.get("data_type", "sound")
-		key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=data_type)
+		
+		key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None))
 		results_ret_list = []
 		time_speech_total_all_samples = 0.0
 

--
Gitblit v1.9.1