From cad1979179a110b154568dd6281035ece9aaf0b8 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 29 三月 2024 16:48:03 +0800
Subject: [PATCH] add batch for offline-stream

---
 runtime/onnxruntime/src/funasrruntime.cpp |  146 +++++++++++++++++++++++++++---------------------
 1 files changed, 82 insertions(+), 64 deletions(-)

diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 711eac7..e9a3f59 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -33,9 +33,9 @@
 		return mm;
 	}
 
-	_FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu)
+	_FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
 	{
-		funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu);
+		funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu, batch_size);
 		return mm;
 	}
 
@@ -74,16 +74,11 @@
 		if(p_result->snippet_time == 0){
 			return p_result;
 		}
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
+
 		while (audio.Fetch(buff, len, flag) > 0) {
 			string msg = recog_obj->Forward(buff, len, input_finished);
 			p_result->msg += msg;
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
-
 		return p_result;
 	}
 
@@ -109,8 +104,6 @@
 		float* buff;
 		int len;
 		int flag = 0;
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
 		funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
 		p_result->snippet_time = audio.GetTimeLen();
 		if(p_result->snippet_time == 0){
@@ -119,11 +112,7 @@
 		while (audio.Fetch(buff, len, flag) > 0) {
 			string msg = recog_obj->Forward(buff, len, true);
 			p_result->msg += msg;
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
-
 		return p_result;
 	}
 
@@ -248,37 +237,51 @@
 			audio.CutSplit(offline_stream);
 		}
 
-		float* buff;
-		int len;
-		int flag = 0;
+		float** buff;
+		int* len;
+		int* flag;
+		float* start_time;
+		int batch_size = offline_stream->asr_handle->GetBatchSize();
+		int batch_in = 0;
 
-		float start_time = 0.0;
 		std::string cur_stamp = "[";
 		std::string lang = (offline_stream->asr_handle)->GetLang();
-		while (audio.Fetch(buff, len, flag, start_time) > 0) {
+		while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
 			// dec reset
 			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
 			if (wfst_decoder){
 				wfst_decoder->StartUtterance();
 			}
-			string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
-			std::vector<std::string> msg_vec = funasr::split(msg, '|');
-			if(msg_vec.size()==0){
-				continue;
-			}
-			if(lang == "en-bpe" && p_result->msg != ""){
-				p_result->msg += " ";
-			}
-			p_result->msg += msg_vec[0];
-			//timestamp
-			if(msg_vec.size() > 1){
-				std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
-				for(int i=0; i<msg_stamp.size()-1; i+=2){
-					float begin = std::stof(msg_stamp[i])+start_time;
-					float end = std::stof(msg_stamp[i+1])+start_time;
-					cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
+			vector<string> msgs = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
+			for(int idx=0; idx<batch_in; idx++){
+				string msg = msgs[idx];
+				std::vector<std::string> msg_vec = funasr::split(msg, '|');
+				if(msg_vec.size()==0){
+					continue;
+				}
+				if(lang == "en-bpe" && p_result->msg != ""){
+					p_result->msg += " ";
+				}
+				p_result->msg += msg_vec[0];
+				//timestamp
+				if(msg_vec.size() > 1){
+					std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
+					for(int i=0; i<msg_stamp.size()-1; i+=2){
+						float begin = std::stof(msg_stamp[i])+start_time[idx];
+						float end = std::stof(msg_stamp[i+1])+start_time[idx];
+						cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
+					}
 				}
 			}
+			// release
+			delete[] buff;
+			buff = nullptr;
+			delete[] len;
+			len = nullptr;
+			delete[] flag;
+			flag = nullptr;
+			delete[] start_time;
+			start_time = nullptr;
 		}
 		if(cur_stamp != "["){
 			cur_stamp.erase(cur_stamp.length() - 1);
@@ -341,42 +344,51 @@
 			audio.CutSplit(offline_stream);
 		}
 
-		float* buff;
-		int len;
-		int flag = 0;
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
-		float start_time = 0.0;
+		float** buff;
+		int* len;
+		int* flag;
+		float* start_time;
+		int batch_size = offline_stream->asr_handle->GetBatchSize();
+		int batch_in = 0;
+
 		std::string cur_stamp = "[";
 		std::string lang = (offline_stream->asr_handle)->GetLang();
-		while (audio.Fetch(buff, len, flag, start_time) > 0) {
+		while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
 			// dec reset
 			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
 			if (wfst_decoder){
 				wfst_decoder->StartUtterance();
 			}
-			string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
-			std::vector<std::string> msg_vec = funasr::split(msg, '|');
-			if(msg_vec.size()==0){
-				continue;
-			}
-			if(lang == "en-bpe" && p_result->msg != ""){
-				p_result->msg += " ";
-			}
-			p_result->msg += msg_vec[0];
-			//timestamp
-			if(msg_vec.size() > 1){
-				std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
-				for(int i=0; i<msg_stamp.size()-1; i+=2){
-					float begin = std::stof(msg_stamp[i])+start_time;
-					float end = std::stof(msg_stamp[i+1])+start_time;
-					cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
+			vector<string> msgs = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
+			for(int idx=0; idx<batch_in; idx++){
+				string msg = msgs[idx];
+				std::vector<std::string> msg_vec = funasr::split(msg, '|');
+				if(msg_vec.size()==0){
+					continue;
+				}
+				if(lang == "en-bpe" && p_result->msg != ""){
+					p_result->msg += " ";
+				}
+				p_result->msg += msg_vec[0];
+				//timestamp
+				if(msg_vec.size() > 1){
+					std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
+					for(int i=0; i<msg_stamp.size()-1; i+=2){
+						float begin = std::stof(msg_stamp[i])+start_time[idx];
+						float end = std::stof(msg_stamp[i+1])+start_time[idx];
+						cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
+					}
 				}
 			}
-
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
+			// release
+			delete[] buff;
+			buff = nullptr;
+			delete[] len;
+			len = nullptr;
+			delete[] flag;
+			flag = nullptr;
+			delete[] start_time;
+			start_time = nullptr;
 		}
 		if(cur_stamp != "["){
 			cur_stamp.erase(cur_stamp.length() - 1);
@@ -513,8 +525,14 @@
 			if (wfst_decoder){
 				wfst_decoder->StartUtterance();
 			}
-			string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
-
+			float** buff;
+			int* len;
+			buff = new float*[1];
+        	len = new int[1];
+			buff[0] = frame->data;
+			len[0] = frame->len;
+			vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
+			string msg = msgs.size()>0?msgs[0]:"";
 			std::vector<std::string> msg_vec = funasr::split(msg, '|');  // split with timestamp
 			if(msg_vec.size()==0){
 				continue;

--
Gitblit v1.9.1