From f72914003a8c4ab7ae72d52dbd7c5f70ea22313a Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Mon, 18 Dec 2023 17:33:24 +0800
Subject: [PATCH] add sentence timestamp
---
runtime/onnxruntime/include/funasrruntime.h | 1
runtime/onnxruntime/bin/funasr-onnx-offline.cpp | 4 +
runtime/onnxruntime/src/util.cpp | 73 ++++++++++++++++++++++++
runtime/websocket/bin/websocket-server.cpp | 7 ++
runtime/websocket/bin/websocket-server.h | 1
runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp | 4 +
runtime/websocket/bin/websocket-server-2pass.cpp | 8 ++
runtime/onnxruntime/src/commonfunc.h | 1
runtime/onnxruntime/src/funasrruntime.cpp | 20 ++++++
runtime/onnxruntime/src/util.h | 2
10 files changed, 115 insertions(+), 6 deletions(-)
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index 39928b4..b1a7c87 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -83,6 +83,10 @@
if(stamp !=""){
LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << stamp;
}
+ string stamp_sents = FunASRGetStampSents(result);
+ if(stamp_sents !=""){
+ LOG(INFO)<< wav_ids[i] <<" : "<<stamp_sents;
+ }
float snippet_time = FunASRGetRetSnippetTime(result);
n_total_length += snippet_time;
FunASRFreeResult(result);
diff --git a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
index eb908d8..55eda93 100644
--- a/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
+++ b/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -172,6 +172,10 @@
if(stamp !=""){
LOG(INFO)<< wav_id <<" : "<<stamp;
}
+ string stamp_sents = FunASRGetStampSents(result);
+ if(stamp_sents !=""){
+ LOG(INFO)<< wav_id <<" : "<<stamp_sents;
+ }
snippet_time += FunASRGetRetSnippetTime(result);
FunASRFreeResult(result);
}
diff --git a/runtime/onnxruntime/include/funasrruntime.h b/runtime/onnxruntime/include/funasrruntime.h
index 3b52f38..27ee6c6 100644
--- a/runtime/onnxruntime/include/funasrruntime.h
+++ b/runtime/onnxruntime/include/funasrruntime.h
@@ -68,6 +68,7 @@
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI const char* FunASRGetStamp(FUNASR_RESULT result);
+_FUNASRAPI const char* FunASRGetStampSents(FUNASR_RESULT result);
_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result);
diff --git a/runtime/onnxruntime/src/commonfunc.h b/runtime/onnxruntime/src/commonfunc.h
index 9bd2a00..3449ebc 100644
--- a/runtime/onnxruntime/src/commonfunc.h
+++ b/runtime/onnxruntime/src/commonfunc.h
@@ -9,6 +9,7 @@
{
std::string msg;
std::string stamp;
+ std::string stamp_sents;
std::string tpass_msg;
float snippet_time;
}FUNASR_RECOG_RESULT;
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 21f7d82..ccd0412 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -303,7 +303,9 @@
p_result->msg = msg_itn;
}
#endif
-
+ if (!(p_result->stamp).empty()){
+ p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
+ }
return p_result;
}
@@ -399,6 +401,9 @@
p_result->msg = msg_itn;
}
#endif
+ if (!(p_result->stamp).empty()){
+ p_result->stamp_sents = funasr::TimestampSentence(p_result->msg, p_result->stamp);
+ }
return p_result;
}
@@ -546,7 +551,9 @@
p_result->tpass_msg = msg_itn;
}
#endif
-
+ if (!(p_result->stamp).empty()){
+ p_result->stamp_sents = funasr::TimestampSentence(p_result->tpass_msg, p_result->stamp);
+ }
if(frame != NULL){
delete frame;
frame = NULL;
@@ -603,6 +610,15 @@
return p_result->stamp.c_str();
}
+ // Accessor for the sentence-level timestamp string carried by `result`.
+ // Yields nullptr for a null `result`; otherwise the pointer comes from
+ // stamp_sents.c_str() on the result object itself.
+ _FUNASRAPI const char* FunASRGetStampSents(FUNASR_RESULT result)
+ {
+     const funasr::FUNASR_RECOG_RESULT* recog = (funasr::FUNASR_RECOG_RESULT*)result;
+     return recog ? recog->stamp_sents.c_str() : nullptr;
+ }
+
_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index)
{
funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
diff --git a/runtime/onnxruntime/src/util.cpp b/runtime/onnxruntime/src/util.cpp
index 2738d35..ac793f5 100644
--- a/runtime/onnxruntime/src/util.cpp
+++ b/runtime/onnxruntime/src/util.cpp
@@ -255,7 +255,8 @@
}
bool TimestampIsPunctuation(const std::string& str) {
- const std::string punctuation = u8"，。？、,.?";
+ const std::string punctuation = u8"，。？、,?";
+ // const std::string punctuation = u8"，。？、,.?";
for (char ch : str) {
if (punctuation.find(ch) == std::string::npos) {
return false;
@@ -557,6 +558,76 @@
return timestamps_str;
}
+/**
+ * Groups per-token timestamps into punctuation-delimited sentences.
+ *
+ * @param text      text that is split with TimestampSplitChiEngCharacters.
+ * @param str_time  serialised timestamps, decoded with ParseTimestamps.
+ *
+ * Non-punctuation tokens consume timestamps in order; every punctuation token
+ * closes the sentence collected so far (the punctuation itself is not emitted).
+ *
+ * @return "[{'text':'..','start':'..','end':'..','ts_list':..},...]"; start/end
+ *         are -1 whenever a sentence has no usable [begin, end] pair.
+ * NOTE(review): single quotes and unescaped text mean this is not strict JSON.
+ */
+std::string TimestampSentence(std::string &text, std::string &str_time){
+    std::vector<std::string> characters;
+    funasr::TimestampSplitChiEngCharacters(text, characters);
+    const std::vector<std::vector<int>> timestamps = funasr::ParseTimestamps(str_time);
+
+    std::string ts_sentences;
+    std::string text_seg;
+    std::vector<std::vector<int>> ts_seg;
+
+    // Serialises the pending sentence, appends it and resets the accumulators.
+    auto flush_sentence = [&]() {
+        int start = -1, end = -1;
+        if (!ts_seg.empty()) {
+            if (ts_seg.front().size() == 2) {
+                start = ts_seg.front()[0];
+            }
+            if (ts_seg.back().size() == 2) {
+                end = ts_seg.back()[1];
+            }
+        }
+        // Separate entries on demand so no dangling comma can be produced.
+        if (!ts_sentences.empty()) {
+            ts_sentences += ",";
+        }
+        ts_sentences += "{'text':'" + text_seg + "',";
+        ts_sentences += "'start':'" + std::to_string(start) + "',";
+        ts_sentences += "'end':'" + std::to_string(end) + "',";
+        ts_sentences += "'ts_list':" + VectorToString(ts_seg) + "}";
+        text_seg.clear();
+        ts_seg.clear();
+    };
+
+    std::size_t idx_ts = 0;
+    for (const std::string &token : characters) {
+        if (TimestampIsPunctuation(token)) {
+            flush_sentence();
+            continue;
+        }
+        if (!text_seg.empty()) {
+            text_seg += " ";
+        }
+        text_seg += token;
+        // Always advance to the next token: text that is longer than the
+        // timestamp list keeps its words but gets no ts_list entry for them.
+        if (idx_ts < timestamps.size()) {
+            ts_seg.push_back(timestamps[idx_ts]);
+            ++idx_ts;
+        }
+    }
+    // Trailing tokens that were not closed by punctuation.
+    if (!text_seg.empty() || !ts_seg.empty()) {
+        flush_sentence();
+    }
+
+    return "[" + ts_sentences + "]";
+}
+
std::vector<std::string> split(const std::string &s, char delim) {
std::vector<std::string> elems;
std::stringstream ss(s);
diff --git a/runtime/onnxruntime/src/util.h b/runtime/onnxruntime/src/util.h
index 46d24b3..eda9b49 100644
--- a/runtime/onnxruntime/src/util.h
+++ b/runtime/onnxruntime/src/util.h
@@ -47,7 +47,7 @@
std::vector<std::string> &characters);
std::string VectorToString(const std::vector<std::vector<int>>& vec);
std::string TimestampSmooth(std::string &text, std::string &text_itn, std::string &str_time);
-
+std::string TimestampSentence(std::string &text, std::string &str_time);
std::vector<std::string> split(const std::string &s, char delim);
template<typename T>
diff --git a/runtime/websocket/bin/websocket-server-2pass.cpp b/runtime/websocket/bin/websocket-server-2pass.cpp
index 499c950..44dd82e 100644
--- a/runtime/websocket/bin/websocket-server-2pass.cpp
+++ b/runtime/websocket/bin/websocket-server-2pass.cpp
@@ -80,6 +80,12 @@
jsonresult["timestamp"] = tmp_stamp_msg;
}
+ std::string tmp_stamp_sents = FunASRGetStampSents(result);
+ if (tmp_stamp_sents != "") {
+ LOG(INFO) << "offline stamp_sents : " << tmp_stamp_sents;
+ jsonresult["stamp_sents"] = tmp_stamp_sents;
+ }
+
return jsonresult;
}
// feed buffer to asr engine for decoder
@@ -318,7 +324,7 @@
data_msg->msg["is_eof"]=true;
guard_decoder.unlock();
to_remove.push_back(hdl);
- LOG(INFO)<<"connection is closed: "<<e.what();
+ LOG(INFO)<<"connection is closed.";
}
iter++;
diff --git a/runtime/websocket/bin/websocket-server.cpp b/runtime/websocket/bin/websocket-server.cpp
index f1cd38b..42bc60a 100644
--- a/runtime/websocket/bin/websocket-server.cpp
+++ b/runtime/websocket/bin/websocket-server.cpp
@@ -74,6 +74,7 @@
if (!buffer.empty() && hotwords_embedding.size() > 0) {
std::string asr_result;
std::string stamp_res;
+ std::string stamp_sents;
try{
FUNASR_RESULT Result = FunOfflineInferBuffer(
asr_handle, buffer.data(), buffer.size(), RASR_NONE, NULL,
@@ -81,6 +82,7 @@
asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg; // get decode result
stamp_res = ((FUNASR_RECOG_RESULT*)Result)->stamp;
+ stamp_sents = ((FUNASR_RECOG_RESULT*)Result)->stamp_sents;
FunASRFreeResult(Result);
}catch (std::exception const& e) {
LOG(ERROR) << e.what();
@@ -94,6 +96,9 @@
jsonresult["is_final"] = false;
if(stamp_res != ""){
jsonresult["timestamp"] = stamp_res;
+ }
+ if(stamp_sents != ""){
+ jsonresult["stamp_sents"] = stamp_sents;
}
jsonresult["wav_name"] = wav_name;
@@ -227,7 +232,7 @@
data_msg->msg["is_eof"]=true;
guard_decoder.unlock();
to_remove.push_back(hdl);
- LOG(INFO)<<"connection is closed: "<<e.what();
+ LOG(INFO)<<"connection is closed.";
}
iter++;
diff --git a/runtime/websocket/bin/websocket-server.h b/runtime/websocket/bin/websocket-server.h
index d511071..cacf12d 100644
--- a/runtime/websocket/bin/websocket-server.h
+++ b/runtime/websocket/bin/websocket-server.h
@@ -50,6 +50,7 @@
typedef struct {
std::string msg="";
std::string stamp="";
+ std::string stamp_sents; // NOTE(review): sits between stamp and tpass_msg, same as in commonfunc.h; websocket-server.cpp appears to cast the runtime result to this type, so keep the order in sync (confirm)
std::string tpass_msg="";
float snippet_time=0;
} FUNASR_RECOG_RESULT;
--
Gitblit v1.9.1