From d20c030e5b75306dd67e8fe9924d5d94eac1bf30 Mon Sep 17 00:00:00 2001
From: wusong <63332221+wusong1128@users.noreply.github.com>
Date: 星期三, 25 九月 2024 15:11:50 +0800
Subject: [PATCH] 解决python ws服务针对尾部非人声录音无结束标识返回的问题 (#2102)

---
 runtime/onnxruntime/src/util.cpp |   37 +++++++++++++++++++++++++------------
 1 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/runtime/onnxruntime/src/util.cpp b/runtime/onnxruntime/src/util.cpp
index 7723e5f..483795e 100644
--- a/runtime/onnxruntime/src/util.cpp
+++ b/runtime/onnxruntime/src/util.cpp
@@ -365,9 +365,13 @@
   }
 }
 
-std::string VectorToString(const std::vector<std::vector<int>>& vec) {
+std::string VectorToString(const std::vector<std::vector<int>>& vec, bool out_empty) {
     if(vec.size() == 0){
-        return "";
+        if(out_empty){
+            return "";
+        }else{
+            return "[]";
+        }
     }
     std::ostringstream out;
     out << "[";
@@ -584,11 +588,11 @@
                 }
             }
             // format
-            ts_sent += "{'text_seg':\"" + text_seg + "\",";
-            ts_sent += "'punc':'" + characters[idx_str] + "',";
-            ts_sent += "'start':'" + to_string(start) + "',";
-            ts_sent += "'end':'" + to_string(end) + "',";
-            ts_sent += "'ts_list':" + VectorToString(ts_seg) + "}";
+            ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
+            ts_sent += "\"punc\":\"" + characters[idx_str] + "\",";
+            ts_sent += "\"start\":" + to_string(start) + ",";
+            ts_sent += "\"end\":" + to_string(end) + ",";
+            ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
             
             if (idx_str == characters.size()-1){
                 ts_sentences += ts_sent;
@@ -621,11 +625,11 @@
             end = ts_seg[ts_seg.size()-1][1];
         }
         // format
-        ts_sent += "{'text_seg':\"" + text_seg + "\",";
-        ts_sent += "'punc':'',";
-        ts_sent += "'start':'" + to_string(start) + "',";
-        ts_sent += "'end':'" + to_string(end) + "',";
-        ts_sent += "'ts_list':" + VectorToString(ts_seg) + "}";
+        ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
+        ts_sent += "\"punc\":\"\",";
+        ts_sent += "\"start\":" + to_string(start) + ",";
+        ts_sent += "\"end\":" + to_string(end) + ",";
+        ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
         ts_sentences += ts_sent;
     }
 
@@ -866,6 +870,15 @@
                 sum -=(1.0 - 1e-4);
             }            
         }
+        // fix case: sum is still >= 1 - 1e-4 here; walk cif_peak backward and assign the leftover to entries below 1 - 1e-4
+        int cif_idx = cif_peak.size()-1;
+        while(sum>=1.0 - 1e-4 && cif_idx >= 0 ){
+            if(cif_peak[cif_idx] < 1.0 - 1e-4){
+                cif_peak[cif_idx] = sum;
+                sum -=(1.0 - 1e-4);
+            }
+            cif_idx--;
+        }
 
         fire_place.clear();
         for (int i = 0; i < num_frames; i++) {

--
Gitblit v1.9.1