From 41c64e4729ca359f7212534055239c8289b5e2f4 Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 21 九月 2023 16:15:15 +0800
Subject: [PATCH] Merge pull request #975 from alibaba-damo-academy/main

---
 funasr/runtime/onnxruntime/src/funasrruntime.cpp |   96 +++++++++++++++++++++++++++++++++++++++++-------
 1 files changed, 82 insertions(+), 14 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
index 207cf8b..0d4af5c 100644
--- a/funasr/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -217,7 +217,9 @@
 	}
 
 	// APIs for Offline-stream Infer
-	_FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, int sampling_rate, std::string wav_format)
+	_FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, 
+												   FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, 
+												   int sampling_rate, std::string wav_format, bool itn)
 	{
 		funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
 		if (!offline_stream)
@@ -283,11 +285,18 @@
 			string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
 			p_result->msg = punc_res;
 		}
+#if !defined(__APPLE__)
+		if(offline_stream->UseITN() && itn){
+			string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
+			p_result->msg = msg_itn;
+		}
+#endif
 
 		return p_result;
 	}
 
-	_FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, int sampling_rate)
+	_FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, 
+											 const std::vector<std::vector<float>> &hw_emb, int sampling_rate, bool itn)
 	{
 		funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
 		if (!offline_stream)
@@ -357,20 +366,46 @@
 			string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
 			p_result->msg = punc_res;
 		}
-	
+#if !defined(__APPLE__)
+		if(offline_stream->UseITN() && itn){
+			string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
+			p_result->msg = msg_itn;
+		}
+#endif
 		return p_result;
 	}
 
-	_FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords) {
-		funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
-    	std::vector<std::vector<float>> emb;
-		if (!offline_stream)
+#if !defined(__APPLE__) // this definition is compiled out when __APPLE__ is defined
+	_FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode) // forwards `hotwords` to the stream's asr_handle; `mode` selects which stream type `handle` is cast to
+	{
+		if (mode == ASR_OFFLINE){ // offline mode: handle is interpreted as a funasr::OfflineStream*
+			funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
+    		std::vector<std::vector<float>> emb; // empty result, returned only when the handle is null
+			if (!offline_stream)
+				return emb;
+			return (offline_stream->asr_handle)->CompileHotwordEmbedding(hotwords); // NOTE(review): asr_handle is dereferenced without a null check - confirm it is always set
+		}
+		else if (mode == ASR_TWO_PASS){ // two-pass mode: handle is interpreted as a funasr::TpassStream*
+			funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
+    		std::vector<std::vector<float>> emb; // empty result, returned only when the handle is null
+			if (!tpass_stream)
+				return emb;
+			return (tpass_stream->asr_handle)->CompileHotwordEmbedding(hotwords); // NOTE(review): asr_handle is dereferenced without a null check - confirm it is always set
+		}
+		else{ // any other mode: log an error and fall through to return an empty embedding
+			LOG(ERROR) << "Not implement: Online model does not support Hotword yet!";
+			std::vector<std::vector<float>> emb;
 			return emb;
-		return (offline_stream->asr_handle)->CompileHotwordEmbedding(hotwords);
+		}
+		
 	}
+#endif
 
 	// APIs for 2pass-stream Infer
-	_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, int sampling_rate, std::string wav_format, ASR_TYPE mode)
+	_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, 
+												 int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, 
+												 int sampling_rate, std::string wav_format, ASR_TYPE mode, 
+												 const std::vector<std::vector<float>> &hw_emb, bool itn)
 	{
 		funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
 		funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
@@ -408,9 +443,6 @@
 
 		funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
 		p_result->snippet_time = audio->GetTimeLen();
-		// if(p_result->snippet_time == 0){
-		// 	return p_result;
-		// }
 		
 		audio->Split(vad_online_handle, chunk_len, input_finished, mode);
 
@@ -423,6 +455,13 @@
 					string online_msg = ((funasr::ParaformerOnline*)asr_online_handle)->online_res;
 					string msg_punc = punc_online_handle->AddPunc(online_msg.c_str(), punc_cache[0]);
 					p_result->tpass_msg = msg_punc;
+#if !defined(__APPLE__)
+					// ITN
+					if(tpass_stream->UseITN() && itn){
+						string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
+						p_result->tpass_msg = msg_itn;
+					}
+#endif
 					((funasr::ParaformerOnline*)asr_online_handle)->online_res = "";
 					p_result->msg += msg;
 				}else{
@@ -437,13 +476,43 @@
 			}
 		}
 
+		// timestamp
+		std::string cur_stamp = "[";		
 		while(audio->FetchTpass(frame) > 0){
-			string msg = asr_handle->Forward(frame->data, frame->len, frame->is_final);
+			string msg = asr_handle->Forward(frame->data, frame->len, frame->is_final, hw_emb);
+
+			std::vector<std::string> msg_vec = funasr::split(msg, '|');  // split with timestamp
+			if(msg_vec.size()==0){
+				continue;
+			}
+			msg = msg_vec[0];
+			//timestamp
+			if(msg_vec.size() > 1){
+				std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
+				for(int i=0; i<msg_stamp.size()-1; i+=2){
+					float begin = std::stof(msg_stamp[i]) + float(frame->global_start)/1000.0;
+					float end = std::stof(msg_stamp[i+1]) + float(frame->global_start)/1000.0;
+					cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"],";
+				}
+			}
+
+			if(cur_stamp != "["){
+				cur_stamp.erase(cur_stamp.length() - 1);
+				p_result->stamp += cur_stamp + "]";
+			}
+
 			string msg_punc = punc_online_handle->AddPunc(msg.c_str(), punc_cache[1]);
 			if(input_finished){
 				msg_punc += "。";
 			}
 			p_result->tpass_msg = msg_punc;
+#if !defined(__APPLE__)
+			if(tpass_stream->UseITN() && itn){
+				string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
+				p_result->tpass_msg = msg_itn;
+			}
+#endif
+
 			if(frame != NULL){
 				delete frame;
 				frame = NULL;
@@ -461,7 +530,6 @@
 	{
 		if (!result)
 			return 0;
-
 		return 1;
 	}
 

--
Gitblit v1.9.1