From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 runtime/onnxruntime/src/funasrruntime.cpp |  219 ++++++++++++++++++++++++++++++++++++++++--------------
 1 files changed, 160 insertions(+), 59 deletions(-)

diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 21f7d82..93b89a5 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -33,9 +33,9 @@
 		return mm;
 	}
 
-	_FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num)
+	_FUNASRAPI FUNASR_HANDLE  FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num, bool use_gpu, int batch_size)
 	{
-		funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num);
+		funasr::OfflineStream* mm = funasr::CreateOfflineStream(model_path, thread_num, use_gpu, batch_size);
 		return mm;
 	}
 
@@ -74,16 +74,11 @@
 		if(p_result->snippet_time == 0){
 			return p_result;
 		}
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
+
 		while (audio.Fetch(buff, len, flag) > 0) {
 			string msg = recog_obj->Forward(buff, len, input_finished);
 			p_result->msg += msg;
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
-
 		return p_result;
 	}
 
@@ -109,8 +104,6 @@
 		float* buff;
 		int len;
 		int flag = 0;
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
 		funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
 		p_result->snippet_time = audio.GetTimeLen();
 		if(p_result->snippet_time == 0){
@@ -119,11 +112,7 @@
 		while (audio.Fetch(buff, len, flag) > 0) {
 			string msg = recog_obj->Forward(buff, len, true);
 			p_result->msg += msg;
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
-
 		return p_result;
 	}
 
@@ -146,6 +135,7 @@
 		funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
 		p_result->snippet_time = audio.GetTimeLen();
 		if(p_result->snippet_time == 0){
+			p_result->segments = new vector<std::vector<int>>();
             return p_result;
         }
 		
@@ -178,6 +168,7 @@
 		funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
 		p_result->snippet_time = audio.GetTimeLen();
 		if(p_result->snippet_time == 0){
+			p_result->segments = new vector<std::vector<int>>();
             return p_result;
         }
 		
@@ -242,26 +233,53 @@
 		if(p_result->snippet_time == 0){
             return p_result;
         }
+		std::vector<int> index_vector={0};
+		int msg_idx = 0;
 		if(offline_stream->UseVad()){
-			audio.Split(offline_stream);
+			audio.CutSplit(offline_stream, index_vector);
 		}
+		std::vector<string> msgs(index_vector.size());
+		std::vector<float> msg_stimes(index_vector.size());
 
-		float* buff;
-		int len;
-		int flag = 0;
+		float** buff;
+		int* len;
+		int* flag;
+		float* start_time;
+		int batch_size = offline_stream->asr_handle->GetBatchSize();
+		int batch_in = 0;
 
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
-		float start_time = 0.0;
 		std::string cur_stamp = "[";
 		std::string lang = (offline_stream->asr_handle)->GetLang();
-		while (audio.Fetch(buff, len, flag, start_time) > 0) {
+		while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
 			// dec reset
 			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
 			if (wfst_decoder){
 				wfst_decoder->StartUtterance();
 			}
-			string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
+			vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
+			for(int idx=0; idx<batch_in; idx++){
+				string msg = msg_batch[idx];
+				if(msg_idx < index_vector.size()){
+					msgs[index_vector[msg_idx]] = msg;
+					msg_stimes[index_vector[msg_idx]] = start_time[idx];
+					msg_idx++;
+				}else{
+					LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
+				}				
+			}
+
+			// release
+			delete[] buff;
+			buff = nullptr;
+			delete[] len;
+			len = nullptr;
+			delete[] flag;
+			flag = nullptr;
+			delete[] start_time;
+			start_time = nullptr;
+		}
+		for(int idx=0; idx<msgs.size(); idx++){
+			string msg = msgs[idx];
 			std::vector<std::string> msg_vec = funasr::split(msg, '|');
 			if(msg_vec.size()==0){
 				continue;
@@ -274,14 +292,11 @@
 			if(msg_vec.size() > 1){
 				std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
 				for(int i=0; i<msg_stamp.size()-1; i+=2){
-					float begin = std::stof(msg_stamp[i])+start_time;
-					float end = std::stof(msg_stamp[i+1])+start_time;
+					float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
+					float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
 					cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
 				}
 			}
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
 		if(cur_stamp != "["){
 			cur_stamp.erase(cur_stamp.length() - 1);
@@ -303,7 +318,9 @@
 			p_result->msg = msg_itn;
 		}
 #endif
-
+		if (!(p_result->stamp).empty()){
+			p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
+		}
 		return p_result;
 	}
 
@@ -338,25 +355,53 @@
 		if(p_result->snippet_time == 0){
             return p_result;
         }
+		std::vector<int> index_vector={0};
+		int msg_idx = 0;
 		if(offline_stream->UseVad()){
-			audio.Split(offline_stream);
+			audio.CutSplit(offline_stream, index_vector);
 		}
+		std::vector<string> msgs(index_vector.size());
+		std::vector<float> msg_stimes(index_vector.size());
 
-		float* buff;
-		int len;
-		int flag = 0;
-		int n_step = 0;
-		int n_total = audio.GetQueueSize();
-		float start_time = 0.0;
+		float** buff;
+		int* len;
+		int* flag;
+		float* start_time;
+		int batch_size = offline_stream->asr_handle->GetBatchSize();
+		int batch_in = 0;
+
 		std::string cur_stamp = "[";
 		std::string lang = (offline_stream->asr_handle)->GetLang();
-		while (audio.Fetch(buff, len, flag, start_time) > 0) {
+		while (audio.FetchDynamic(buff, len, flag, start_time, batch_size, batch_in) > 0) {
 			// dec reset
 			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
 			if (wfst_decoder){
 				wfst_decoder->StartUtterance();
 			}
-			string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle);
+			vector<string> msg_batch = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb, dec_handle, batch_in);
+			for(int idx=0; idx<batch_in; idx++){
+				string msg = msg_batch[idx];
+				if(msg_idx < index_vector.size()){
+					msgs[index_vector[msg_idx]] = msg;
+					msg_stimes[index_vector[msg_idx]] = start_time[idx];
+					msg_idx++;
+				}else{
+					LOG(ERROR) << "msg_idx: " << msg_idx <<" is out of range " << index_vector.size();
+				}				
+			}
+
+			// release
+			delete[] buff;
+			buff = nullptr;
+			delete[] len;
+			len = nullptr;
+			delete[] flag;
+			flag = nullptr;
+			delete[] start_time;
+			start_time = nullptr;
+		}
+		for(int idx=0; idx<msgs.size(); idx++){
+			string msg = msgs[idx];
 			std::vector<std::string> msg_vec = funasr::split(msg, '|');
 			if(msg_vec.size()==0){
 				continue;
@@ -369,15 +414,11 @@
 			if(msg_vec.size() > 1){
 				std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
 				for(int i=0; i<msg_stamp.size()-1; i+=2){
-					float begin = std::stof(msg_stamp[i])+start_time;
-					float end = std::stof(msg_stamp[i+1])+start_time;
+					float begin = std::stof(msg_stamp[i])+msg_stimes[idx];
+					float end = std::stof(msg_stamp[i+1])+msg_stimes[idx];
 					cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
 				}
 			}
-
-			n_step++;
-			if (fn_callback)
-				fn_callback(n_step, n_total);
 		}
 		if(cur_stamp != "["){
 			cur_stamp.erase(cur_stamp.length() - 1);
@@ -399,10 +440,13 @@
 			p_result->msg = msg_itn;
 		}
 #endif
+		if (!(p_result->stamp).empty()){
+			p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
+		}
 		return p_result;
 	}
 
-#if !defined(__APPLE__)
+//#if !defined(__APPLE__)
 	_FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode)
 	{
 		if (mode == ASR_OFFLINE){
@@ -426,13 +470,13 @@
 		}
 		
 	}
-#endif
+//#endif
 
 	// APIs for 2pass-stream Infer
 	_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, 
 												 int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, 
 												 int sampling_rate, std::string wav_format, ASR_TYPE mode, 
-												 const std::vector<std::vector<float>> &hw_emb, bool itn)
+												 const std::vector<std::vector<float>> &hw_emb, bool itn, FUNASR_DEC_HANDLE dec_handle)
 	{
 		funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
 		funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
@@ -473,7 +517,7 @@
 		
 		audio->Split(vad_online_handle, chunk_len, input_finished, mode);
 
-		funasr::AudioFrame* frame = NULL;
+		funasr::AudioFrame* frame = nullptr;
 		while(audio->FetchChunck(frame) > 0){
 			string msg = ((funasr::ParaformerOnline*)asr_online_handle)->Forward(frame->data, frame->len, frame->is_final);
 			if(mode == ASR_ONLINE){
@@ -497,17 +541,28 @@
 			}else if(mode == ASR_TWO_PASS){
 				p_result->msg += msg;
 			}
-			if(frame != NULL){
+			if(frame != nullptr){
 				delete frame;
-				frame = NULL;
+				frame = nullptr;
 			}
 		}
 
 		// timestamp
 		std::string cur_stamp = "[";		
 		while(audio->FetchTpass(frame) > 0){
-			string msg = ((funasr::Paraformer*)asr_handle)->Forward(frame->data, frame->len, frame->is_final, hw_emb);
-
+			// dec reset
+			funasr::WfstDecoder* wfst_decoder = (funasr::WfstDecoder*)dec_handle;
+			if (wfst_decoder){
+				wfst_decoder->StartUtterance();
+			}
+			float** buff;
+			int* len;
+			buff = new float*[1];
+        	len = new int[1];
+			buff[0] = frame->data;
+			len[0] = frame->len;
+			vector<string> msgs = ((funasr::Paraformer*)asr_handle)->Forward(buff, len, frame->is_final, hw_emb, dec_handle);
+			string msg = msgs.size()>0?msgs[0]:"";
 			std::vector<std::string> msg_vec = funasr::split(msg, '|');  // split with timestamp
 			if(msg_vec.size()==0){
 				continue;
@@ -546,10 +601,12 @@
 				p_result->tpass_msg = msg_itn;
 			}
 #endif
-
-			if(frame != NULL){
+			if (!(p_result->stamp).empty()){
+				p_result->stamp_sents = funasr::TimestampSentence(p_result->tpass_msg, p_result->stamp);
+			}
+			if(frame != nullptr){
 				delete frame;
-				frame = NULL;
+				frame = nullptr;
 			}
 		}
 
@@ -601,6 +658,15 @@
 			return nullptr;
 
 		return p_result->stamp.c_str();
+	}
+
+		_FUNASRAPI const char* FunASRGetStampSents(FUNASR_RESULT result)
+	{
+		funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
+		if(!p_result)
+			return nullptr;
+
+		return p_result->stamp_sents.c_str();
 	}
 
 	_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index)
@@ -744,10 +810,45 @@
 		funasr::WfstDecoder* mm = nullptr;
 		if (asr_type == ASR_OFFLINE) {
 			funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
-			funasr::Paraformer* paraformer = (funasr::Paraformer*)offline_stream->asr_handle.get();
-			if (paraformer->lm_)
-				mm = new funasr::WfstDecoder(paraformer->lm_.get(), 
-					paraformer->GetPhoneSet(), paraformer->GetVocab(), glob_beam, lat_beam, am_scale);
+			auto paraformer = dynamic_cast<funasr::Paraformer*>(offline_stream->asr_handle.get());
+			if(paraformer !=nullptr){
+				if (paraformer->lm_){
+					mm = new funasr::WfstDecoder(paraformer->lm_.get(),
+						paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+				}
+				return mm;
+			}
+			#ifdef USE_GPU
+			auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(offline_stream->asr_handle.get());
+			if(paraformer_torch !=nullptr){
+				if (paraformer_torch->lm_){
+					mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
+						paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
+				}
+				return mm;
+			}
+			#endif
+
+		} else if (asr_type == ASR_TWO_PASS){
+			funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
+			auto paraformer = dynamic_cast<funasr::Paraformer*>(tpass_stream->asr_handle.get());
+			if(paraformer !=nullptr){
+				if (paraformer->lm_){
+					mm = new funasr::WfstDecoder(paraformer->lm_.get(),
+						paraformer->GetPhoneSet(), paraformer->GetLmVocab(), glob_beam, lat_beam, am_scale);
+				}
+				return mm;
+			}
+			#ifdef USE_GPU
+			auto paraformer_torch = dynamic_cast<funasr::ParaformerTorch*>(tpass_stream->asr_handle.get());
+			if(paraformer_torch !=nullptr){
+				if (paraformer_torch->lm_){
+					mm = new funasr::WfstDecoder(paraformer_torch->lm_.get(),
+						paraformer_torch->GetPhoneSet(), paraformer_torch->GetLmVocab(), glob_beam, lat_beam, am_scale);
+				}
+				return mm;
+			}
+			#endif
 		}
 		return mm;
 	}

--
Gitblit v1.9.1