From f6b611de44c3a535befa96da552d07b0ed1b073c Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 27 十二月 2023 15:52:16 +0800
Subject: [PATCH] funasr1.0

---
 funasr/bin/inference.py |   21 ++++++++++++++++++++-
 1 files changed, 20 insertions(+), 1 deletions(-)

diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index fda7abe..16ad0e2 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -18,6 +18,7 @@
 from funasr.register import tables
 from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio
 from funasr.utils.vad_utils import slice_padding_audio_samples
+from funasr.utils.timestamp_tools import time_stamp_sentence
 
 def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
 	"""
@@ -46,7 +47,7 @@
 						data = lines["source"]
 						key = data["key"] if "key" in data else key
 					else: # filelist, wav.scp, text.txt: id \t data or data
-						lines = line.strip().split()
+						lines = line.strip().split(maxsplit=1)
 						data = lines[1] if len(lines)>1 else lines[0]
 						key = lines[0] if len(lines)>1 else key
 					
@@ -227,6 +228,7 @@
 		# step.1: compute the vad model
 		model = self.vad_model
 		kwargs = self.vad_kwargs
+		kwargs.update(cfg)
 		beg_vad = time.time()
 		res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
 		end_vad = time.time()
@@ -322,6 +324,23 @@
 			result["key"] = key
 			results_ret_list.append(result)
 			pbar_total.update(1)
+		
+		# step.3: compute the punc model
+		model = self.punc_model
+		kwargs = self.punc_kwargs
+		kwargs.update(cfg)
+
+		for i, result in enumerate(results_ret_list):
+			beg_punc = time.time()
+			res = self.generate(result["text"], model=model, kwargs=kwargs, **cfg)
+			end_punc = time.time()
+			print(f"time punc: {end_punc - beg_punc:0.3f}")
+			
+			# sentences = time_stamp_sentence(model.punc_list, model.sentence_end_id, results_ret_list[i]["timestamp"], res[i]["text"])
+			# results_ret_list[i]["time_stamp"] = res[0]["text_postprocessed_punc"]
+			# results_ret_list[i]["sentences"] = sentences
+			# results_ret_list[i]["text_with_punc"] = res[i]["text"]
+		
 		pbar_total.update(1)
 		end_total = time.time()
 		time_escape_total_all_samples = end_total - beg_total

--
Gitblit v1.9.1