From b454a1054fadbff0ee963944ff42f66b98317582 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期二, 08 八月 2023 11:17:43 +0800
Subject: [PATCH] update online runtime, including vad-online, paraformer-online, punc-online,2pass (#815)

---
 funasr/runtime/onnxruntime/src/funasrruntime.cpp |  149 ++++++++++++++++++++++++++++++++++++++++++++++---
 1 files changed, 140 insertions(+), 9 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
index a1829fd..2e6a079 100644
--- a/funasr/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -5,9 +5,15 @@
 #endif
 
 	// APIs for Init
-	_FUNASRAPI FUNASR_HANDLE  FunASRInit(std::map<std::string, std::string>& model_path, int thread_num)
+	_FUNASRAPI FUNASR_HANDLE  FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type)
 	{
-		funasr::Model* mm = funasr::CreateModel(model_path, thread_num);
+		funasr::Model* mm = funasr::CreateModel(model_path, thread_num, type);
+		return mm;
+	}
+
+	_FUNASRAPI FUNASR_HANDLE  FunASROnlineInit(FUNASR_HANDLE asr_hanlde, std::vector<int> chunk_size)
+	{
+		funasr::Model* mm = funasr::CreateModel(asr_hanlde, chunk_size);
 		return mm;
 	}
 
@@ -35,8 +41,19 @@
 		return mm;
 	}
 
+	_FUNASRAPI FUNASR_HANDLE  FunTpassInit(std::map<std::string, std::string>& model_path, int thread_num)
+	{
+		funasr::TpassStream* mm = funasr::CreateTpassStream(model_path, thread_num);
+		return mm;
+	}
+
+	_FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size)
+	{
+		return funasr::CreateTpassOnlineStream(tpass_handle, chunk_size);
+	}
+
 	// APIs for ASR Infer
-	_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate, std::string wav_format)
+	_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate, std::string wav_format)
 	{
 		funasr::Model* recog_obj = (funasr::Model*)handle;
 		if (!recog_obj)
@@ -57,12 +74,12 @@
 		funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
 		p_result->snippet_time = audio.GetTimeLen();
 		if(p_result->snippet_time == 0){
-            return p_result;
-        }
+			return p_result;
+		}
 		int n_step = 0;
 		int n_total = audio.GetQueueSize();
 		while (audio.Fetch(buff, len, flag) > 0) {
-			string msg = recog_obj->Forward(buff, len, flag);
+			string msg = recog_obj->Forward(buff, len, input_finished);
 			p_result->msg += msg;
 			n_step++;
 			if (fn_callback)
@@ -102,7 +119,7 @@
             return p_result;
         }
 		while (audio.Fetch(buff, len, flag) > 0) {
-			string msg = recog_obj->Forward(buff, len, flag);
+			string msg = recog_obj->Forward(buff, len, true);
 			p_result->msg += msg;
 			n_step++;
 			if (fn_callback)
@@ -230,7 +247,7 @@
 		int n_step = 0;
 		int n_total = audio.GetQueueSize();
 		while (audio.Fetch(buff, len, flag) > 0) {
-			string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
+			string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
 			p_result->msg += msg;
 			n_step++;
 			if (fn_callback)
@@ -277,7 +294,7 @@
 		int n_step = 0;
 		int n_total = audio.GetQueueSize();
 		while (audio.Fetch(buff, len, flag) > 0) {
-			string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
+			string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
 			p_result->msg+= msg;
 			n_step++;
 			if (fn_callback)
@@ -288,6 +305,91 @@
 			p_result->msg = punc_res;
 		}
 	
+		return p_result;
+	}
+
+	// APIs for 2pass-stream Infer
+	_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, int sampling_rate, std::string wav_format, ASR_TYPE mode)
+	{
+		funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
+		funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
+		if (!tpass_stream || !tpass_online_stream)
+			return nullptr;
+		
+		funasr::VadModel* vad_online_handle = (tpass_online_stream->vad_online_handle).get();
+		if (!vad_online_handle)
+			return nullptr;
+
+		funasr::Audio* audio = ((funasr::FsmnVadOnline*)vad_online_handle)->audio_handle.get();
+
+		funasr::Model* asr_online_handle = (tpass_online_stream->asr_online_handle).get();
+		if (!asr_online_handle)
+			return nullptr;
+		int chunk_len = ((funasr::ParaformerOnline*)asr_online_handle)->chunk_len;
+		
+		funasr::Model* asr_handle = (tpass_stream->asr_handle).get();
+		if (!asr_handle)
+			return nullptr;
+
+		funasr::PuncModel* punc_online_handle = (tpass_stream->punc_online_handle).get();
+		if (!punc_online_handle)
+			return nullptr;
+
+		if(wav_format == "pcm" || wav_format == "PCM"){
+			if (!audio->LoadPcmwavOnline(sz_buf, n_len, &sampling_rate))
+				return nullptr;
+		}else{
+			// if (!audio->FfmpegLoad(sz_buf, n_len))
+			// 	return nullptr;
+			LOG(ERROR) <<"Wrong wav_format: " << wav_format ;
+			exit(-1);
+		}
+
+		funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
+		p_result->snippet_time = audio->GetTimeLen();
+		if(p_result->snippet_time == 0){
+			return p_result;
+		}
+		
+		audio->Split(vad_online_handle, chunk_len, input_finished, mode);
+
+		funasr::AudioFrame* frame = NULL;
+		while(audio->FetchChunck(frame) > 0){
+			string msg = asr_online_handle->Forward(frame->data, frame->len, frame->is_final);
+			if(mode == ASR_ONLINE){
+				((funasr::ParaformerOnline*)asr_online_handle)->online_res += msg;
+				if(frame->is_final){
+					string online_msg = ((funasr::ParaformerOnline*)asr_online_handle)->online_res;
+					string msg_punc = punc_online_handle->AddPunc(online_msg.c_str(), punc_cache[0]);
+					p_result->tpass_msg = msg_punc;
+					((funasr::ParaformerOnline*)asr_online_handle)->online_res = "";
+					p_result->msg += msg;
+				}else{
+					p_result->msg += msg;
+				}
+			}else if(mode == ASR_TWO_PASS){
+				p_result->msg += msg;
+			}
+			if(frame != NULL){
+				delete frame;
+				frame = NULL;
+			}
+		}
+
+		while(audio->FetchTpass(frame) > 0){
+			string msg = asr_handle->Forward(frame->data, frame->len, frame->is_final);
+			string msg_punc = punc_online_handle->AddPunc(msg.c_str(), punc_cache[1]);
+			p_result->tpass_msg = msg_punc;
+			if(frame != NULL){
+				delete frame;
+				frame = NULL;
+			}
+		}
+
+		if(input_finished){
+			audio->ResetIndex();
+		}
+
 		return p_result;
 	}
 
@@ -324,6 +426,15 @@
 			return nullptr;
 
 		return p_result->msg.c_str();
+	}
+
+	_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index)
+	{
+		funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
+		if(!p_result)
+			return nullptr;
+
+		return p_result->tpass_msg.c_str();
 	}
 
 	_FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index)
@@ -414,6 +525,26 @@
 		delete offline_stream;
 	}
 
+	_FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle)
+	{
+		funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
+
+		if (!tpass_stream)
+			return;
+
+		delete tpass_stream;
+	}
+
+	_FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle)
+	{
+		funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)handle;
+
+		if (!tpass_online_stream)
+			return;
+
+		delete tpass_online_stream;
+	}
+
 #ifdef __cplusplus 
 
 }

--
Gitblit v1.9.1