From cad1979179a110b154568dd6281035ece9aaf0b8 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Fri, 29 Mar 2024 16:48:03 +0800
Subject: [PATCH] add batch for offline-stream

---
 runtime/onnxruntime/src/audio.cpp          |   26 ++++++++
 runtime/onnxruntime/src/funasrruntime.cpp  |  146 +++++++++++++++++++++++++++---------------------
 runtime/onnxruntime/src/offline-stream.cpp |    7 +-
 3 files changed, 112 insertions(+), 67 deletions(-)

diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp
index 0135ab4..ef5d5f3 100644
--- a/runtime/onnxruntime/src/audio.cpp
+++ b/runtime/onnxruntime/src/audio.cpp
@@ -1023,6 +1023,32 @@
     }
 }
 
+int Audio::Fetch(float**& dout, int*& len, int*& flag, float*& start_time, int batch_size, int &batch_in)
+{
+    batch_in = std::min((int)frame_queue.size(), batch_size);
+    if (batch_in == 0){
+        return 0;
+    } else{
+        // init
+        dout = new float*[batch_in];
+        len = new int[batch_in];
+        flag = new int[batch_in];
+        start_time = new float[batch_in];
+
+        for(int idx=0; idx < batch_in; idx++){
+            AudioFrame *frame = frame_queue.front();
+            frame_queue.pop();
+
+            start_time[idx] = (float)(frame->GetStart())/ dest_sample_rate;
+            dout[idx] = speech_data + frame->GetStart();
+            len[idx] = frame->GetLen();
+            delete frame;
+            flag[idx] = S_END;
+        }
+        return 1;
+    }
+}
+
 void Audio::Padding()
 {
     float num_samples = speech_len;
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 711eac7..e9a3f59 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -33,9 +33,9 @@
 		return mm;
 	}
 
-	_FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu)
+	_FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
 	{
-		funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu);
+		funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu, batch_size);
 		return mm;
 	}
 
@@ -74,16 +74,11 @@
 		if(p_result->snippet_time == 0){
 			return p_result;
 		}
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
+
 		while (audio.Fetch(buff, len, flag) > 0) {
 			string msg = recog_obj->Forward(buff, len, input_finished);
 			p_result->msg += msg;
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
-
 		return p_result;
 	}
 
@@ -109,8 +104,6 @@
 		float* buff;
 		int len;
 		int flag = 0;
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
 		funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
 		p_result->snippet_time = audio.GetTimeLen();
 		if(p_result->snippet_time == 0){
@@ -119,11 +112,7 @@
 		while (audio.Fetch(buff, len, flag) > 0) {
 			string msg = recog_obj->Forward(buff, len, true);
 			p_result->msg += msg;
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
-
 		return p_result;
 	}
 
@@ -248,37 +237,51 @@
 			audio.CutSplit(offline_stream);
 		}
 
-		float* buff;
-		int len;
-		int flag = 0;
+		float** buff;
+		int* len;
+		int* flag;
+		float* start_time;
+		int batch_size = offline_stream->asr_handle->GetBatchSize();
+		int batch_in = 0;
 
-		float start_time = 0.0;
 		std::string cur_stamp = "[";
 		std::string lang = (offline_stream->asr_handle)->GetLang();
-		while (audio.Fetch(buff, len, flag, start_time) > 0) {
+		while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
 			// dec reset
 			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
 			if (wfst_decoder){
 				wfst_decoder->StartUtterance();
 			}
-			string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
-			std::vector<std::string> msg_vec = funasr::split(msg, '|');
-			if(msg_vec.size()==0){
-				continue;
-			}
-			if(lang == "en-bpe" && p_result->msg != ""){
-				p_result->msg += " ";
-			}
-			p_result->msg += msg_vec[0];
-			//timestamp
-			if(msg_vec.size() > 1){
-				std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
-				for(int i=0; i<msg_stamp.size()-1; i+=2){
-					float begin = std::stof(msg_stamp[i])+start_time;
-					float end = std::stof(msg_stamp[i+1])+start_time;
-					cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
+			vector<string> msgs = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
+			for(int idx=0; idx<batch_in; idx++){
+				string msg = msgs[idx];
+				std::vector<std::string> msg_vec = funasr::split(msg, '|');
+				if(msg_vec.size()==0){
+					continue;
+				}
+				if(lang == "en-bpe" && p_result->msg != ""){
+					p_result->msg += " ";
+				}
+				p_result->msg += msg_vec[0];
+				//timestamp
+				if(msg_vec.size() > 1){
+					std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
+					for(int i=0; i<msg_stamp.size()-1; i+=2){
+						float begin = std::stof(msg_stamp[i])+start_time[idx];
+						float end = std::stof(msg_stamp[i+1])+start_time[idx];
+						cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
+					}
 				}
 			}
+			// release
+			delete[] buff;
+			buff = nullptr;
+			delete[] len;
+			len = nullptr;
+			delete[] flag;
+			flag = nullptr;
+			delete[] start_time;
+			start_time = nullptr;
 		}
 		if(cur_stamp != "["){
 			cur_stamp.erase(cur_stamp.length() - 1);
@@ -341,42 +344,51 @@
 			audio.CutSplit(offline_stream);
 		}
 
-		float* buff;
-		int len;
-		int flag = 0;
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
-		float start_time = 0.0;
+		float** buff;
+		int* len;
+		int* flag;
+		float* start_time;
+		int batch_size = offline_stream->asr_handle->GetBatchSize();
+		int batch_in = 0;
+
 		std::string cur_stamp = "[";
 		std::string lang = (offline_stream->asr_handle)->GetLang();
-		while (audio.Fetch(buff, len, flag, start_time) > 0) {
+		while (audio.Fetch(buff, len, flag, start_time, batch_size, batch_in) > 0) {
 			// dec reset
 			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
 			if (wfst_decoder){
 				wfst_decoder->StartUtterance();
 			}
-			string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
-			std::vector<std::string> msg_vec = funasr::split(msg, '|');
-			if(msg_vec.size()==0){
-				continue;
-			}
-			if(lang == "en-bpe" && p_result->msg != ""){
-				p_result->msg += " ";
-			}
-			p_result->msg += msg_vec[0];
-			//timestamp
-			if(msg_vec.size() > 1){
-				std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
-				for(int i=0; i<msg_stamp.size()-1; i+=2){
-					float begin = std::stof(msg_stamp[i])+start_time;
-					float end = std::stof(msg_stamp[i+1])+start_time;
-					cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
+			vector<string> msgs = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
+			for(int idx=0; idx<batch_in; idx++){
+				string msg = msgs[idx];
+				std::vector<std::string> msg_vec = funasr::split(msg, '|');
+				if(msg_vec.size()==0){
+					continue;
+				}
+				if(lang == "en-bpe" && p_result->msg != ""){
+					p_result->msg += " ";
+				}
+				p_result->msg += msg_vec[0];
+				//timestamp
+				if(msg_vec.size() > 1){
+					std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
+					for(int i=0; i<msg_stamp.size()-1; i+=2){
+						float begin = std::stof(msg_stamp[i])+start_time[idx];
+						float end = std::stof(msg_stamp[i+1])+start_time[idx];
+						cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
+					}
 				}
 			}
-
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
+			// release
+			delete[] buff;
+			buff = nullptr;
+			delete[] len;
+			len = nullptr;
+			delete[] flag;
+			flag = nullptr;
+			delete[] start_time;
+			start_time = nullptr;
 		}
 		if(cur_stamp != "["){
 			cur_stamp.erase(cur_stamp.length() - 1);
@@ -513,8 +525,14 @@
 			if (wfst_decoder){
 				wfst_decoder->StartUtterance();
 			}
-			string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb, dec_handle);
-
+			float** buff;
+			int* len;
+			buff = new float*[1];
+			len = new int[1];
+			buff[0] = frame->data;
+			len[0] = frame->len;
+			vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
+			string msg = msgs.size()>0?msgs[0]:"";
 			std::vector<std::string> msg_vec = funasr::split(msg, '|');  // split with timestamp
 			if(msg_vec.size()==0){
 				continue;
diff --git a/runtime/onnxruntime/src/offline-stream.cpp b/runtime/onnxruntime/src/offline-stream.cpp
index 69befc6..3f914aa 100644
--- a/runtime/onnxruntime/src/offline-stream.cpp
+++ b/runtime/onnxruntime/src/offline-stream.cpp
@@ -1,7 +1,7 @@
 #include "precomp.h"
 
 namespace funasr {
-OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu)
+OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
 {
     // VAD model
     if(model_path.find(VAD_DIR) != model_path.end()){
@@ -38,6 +38,7 @@
         if(use_gpu){
             #ifdef USE_GPU
             asr_handle = make_unique<ParaformerTorch>();
+            asr_handle->SetBatchSize(batch_size);
             #else
             LOG(ERROR) <<"GPU is not supported! CPU will be used! If you want to use GPU, please add -DGPU=ON when cmake";
             asr_handle = make_unique<Paraformer>();
@@ -135,10 +136,10 @@
 #endif
 }
 
-OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu)
+OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
 {
     OfflineStream *mm;
-    mm = new OfflineStream(model_path, thread_num, use_gpu);
+    mm = new OfflineStream(model_path, thread_num, use_gpu, batch_size);
     return mm;
 }
 

--
Gitblit v1.9.1