From ec1c46280be9ccf295962ee1abef5b26f7464095 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 23 一月 2024 23:52:14 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 runtime/docs/SDK_tutorial_online_zh.md            |    2 
 runtime/websocket/bin/funasr-wss-client-2pass.cpp |    5 --
 runtime/onnxruntime/src/util.cpp                  |    8 ++--
 runtime/websocket/bin/funasr-wss-client.cpp       |    5 --
 runtime/websocket/CMakeLists.txt                  |    6 +-
 runtime/websocket/bin/websocket-server-2pass.cpp  |    2 
 runtime/onnxruntime/src/audio.cpp                 |   54 +++++++++++++++++++++++++++
 runtime/onnxruntime/src/funasrruntime.cpp         |    2 
 runtime/docs/SDK_tutorial_online.md               |    4 +-
 runtime/onnxruntime/include/audio.h               |    1 
 10 files changed, 69 insertions(+), 20 deletions(-)

diff --git a/runtime/docs/SDK_tutorial_online.md b/runtime/docs/SDK_tutorial_online.md
index 4683761..c8e78be 100644
--- a/runtime/docs/SDK_tutorial_online.md
+++ b/runtime/docs/SDK_tutorial_online.md
@@ -59,7 +59,7 @@
 If you want to run the client directly for testing, you can refer to the following simple instructions, using the Python version as an example:
 
 ```shell
-python3 funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode offline --audio_in "../audio/asr_example.wav"
+python3 funasr_wss_client.py --host "127.0.0.1" --port 10095 --mode offline --audio_in "../audio/asr_example.pcm"
 ```
 
 Command parameter instructions:
@@ -79,7 +79,7 @@
 
 After entering the samples/cpp directory, you can test it with CPP. The command is as follows:
 ```shell
-./funasr-wss-client-2pass --server-ip 127.0.0.1 --port 10095 --wav-path ../audio/asr_example.wav
+./funasr-wss-client-2pass --server-ip 127.0.0.1 --port 10095 --wav-path ../audio/asr_example.pcm
 ```
 
 Command parameter description:
diff --git a/runtime/docs/SDK_tutorial_online_zh.md b/runtime/docs/SDK_tutorial_online_zh.md
index e6705de..ab48ec7 100644
--- a/runtime/docs/SDK_tutorial_online_zh.md
+++ b/runtime/docs/SDK_tutorial_online_zh.md
@@ -84,7 +84,7 @@
 杩涘叆samples/cpp鐩綍鍚庯紝鍙互鐢╟pp杩涜娴嬭瘯锛屾寚浠ゅ涓嬶細
 ```shell
 ./funasr-wss-client-2pass --server-ip 127.0.0.1 --port 10095 --mode 2pass \
-   --wav-path ../audio/asr_example.wav
+   --wav-path ../audio/asr_example.pcm
 ```
 
 鍛戒护鍙傛暟璇存槑锛�
diff --git a/runtime/onnxruntime/include/audio.h b/runtime/onnxruntime/include/audio.h
index 5194aa2..98f2169 100644
--- a/runtime/onnxruntime/include/audio.h
+++ b/runtime/onnxruntime/include/audio.h
@@ -76,6 +76,7 @@
     int Fetch(float *&dout, int &len, int &flag, float &start_time);
     void Padding();
     void Split(OfflineStream* offline_streamj);
+    void CutSplit(OfflineStream* offline_streamj);
     void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
     void Split(VadModel* vad_obj, int chunk_len, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
     float GetTimeLen();
diff --git a/runtime/onnxruntime/src/audio.cpp b/runtime/onnxruntime/src/audio.cpp
index 6f829cc..9b93dc8 100644
--- a/runtime/onnxruntime/src/audio.cpp
+++ b/runtime/onnxruntime/src/audio.cpp
@@ -1085,6 +1085,60 @@
     }
 }
 
+void Audio::CutSplit(OfflineStream* offline_stream)
+{
+    std::unique_ptr<VadModel> vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(offline_stream->vad_handle).get());
+    AudioFrame *frame;
+
+    frame = frame_queue.front();
+    frame_queue.pop();
+    int sp_len = frame->GetLen();
+    delete frame;
+    frame = nullptr;
+
+    int step = dest_sample_rate*10;
+    bool is_final=false;
+    vector<std::vector<int>> vad_segments;
+    for (int sample_offset = 0; sample_offset < speech_len; sample_offset += std::min(step, speech_len - sample_offset)) {
+        if (sample_offset + step >= speech_len - 1) {
+                step = speech_len - sample_offset;
+                is_final = true;
+            } else {
+                is_final = false;
+        }
+        std::vector<float> pcm_data(speech_data+sample_offset, speech_data+sample_offset+step);
+        vector<std::vector<int>> cut_segments = vad_online_handle->Infer(pcm_data, is_final);
+        vad_segments.insert(vad_segments.end(), cut_segments.begin(), cut_segments.end());
+    }    
+
+    int speech_start_i = -1, speech_end_i =-1;
+    for(vector<int> vad_segment:vad_segments)
+    {
+        if(vad_segment.size() != 2){
+            LOG(ERROR) << "Size of vad_segment is not 2.";
+            break;
+        }
+        if(vad_segment[0] != -1){
+            speech_start_i = vad_segment[0];
+        }
+        if(vad_segment[1] != -1){
+            speech_end_i = vad_segment[1];
+        }
+
+        if(speech_start_i!=-1 && speech_end_i!=-1){
+            frame = new AudioFrame();
+            int start = speech_start_i*seg_sample;
+            int end = speech_end_i*seg_sample;
+            frame->SetStart(start);
+            frame->SetEnd(end);
+            frame_queue.push(frame);
+            frame = nullptr;
+            speech_start_i=-1;
+            speech_end_i=-1;
+        }
+    }
+}
+
 void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished)
 {
     AudioFrame *frame;
diff --git a/runtime/onnxruntime/src/funasrruntime.cpp b/runtime/onnxruntime/src/funasrruntime.cpp
index 0ca4ded..68a9f09 100644
--- a/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/runtime/onnxruntime/src/funasrruntime.cpp
@@ -245,7 +245,7 @@
             return p_result;
         }
 		if(offline_stream->UseVad()){
-			audio.Split(offline_stream);
+			audio.CutSplit(offline_stream);
 		}
 
 		float* buff;
diff --git a/runtime/onnxruntime/src/util.cpp b/runtime/onnxruntime/src/util.cpp
index 039fa90..a12570b 100644
--- a/runtime/onnxruntime/src/util.cpp
+++ b/runtime/onnxruntime/src/util.cpp
@@ -590,8 +590,8 @@
             // format
             ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
             ts_sent += "\"punc\":\"" + characters[idx_str] + "\",";
-            ts_sent += "\"start\":\"" + to_string(start) + "\",";
-            ts_sent += "\"end\":\"" + to_string(end) + "\",";
+            ts_sent += "\"start\":" + to_string(start) + ",";
+            ts_sent += "\"end\":" + to_string(end) + ",";
             ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
             
             if (idx_str == characters.size()-1){
@@ -627,8 +627,8 @@
         // format
         ts_sent += "{\"text_seg\":\"" + text_seg + "\",";
         ts_sent += "\"punc\":\"\",";
-        ts_sent += "\"start\":\"" + to_string(start) + "\",";
-        ts_sent += "\"end\":\"" + to_string(end) + "\",";
+        ts_sent += "\"start\":" + to_string(start) + ",";
+        ts_sent += "\"end\":" + to_string(end) + ",";
         ts_sent += "\"ts_list\":" + VectorToString(ts_seg, false) + "}";
         ts_sentences += ts_sent;
     }
diff --git a/runtime/websocket/CMakeLists.txt b/runtime/websocket/CMakeLists.txt
index 13472da..ba6497a 100644
--- a/runtime/websocket/CMakeLists.txt
+++ b/runtime/websocket/CMakeLists.txt
@@ -31,7 +31,7 @@
   # cmake_policy(SET CMP0135 NEW)
   include(FetchContent)
 
-  if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/websocket )
+  if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/websocket/websocketpp )
     FetchContent_Declare(websocketpp
     GIT_REPOSITORY https://github.com/zaphoyd/websocketpp.git
       GIT_TAG 0.8.2
@@ -42,7 +42,7 @@
   endif()
   include_directories(${PROJECT_SOURCE_DIR}/third_party/websocket)
    
-  if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/asio )
+  if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/asio/asio )
     FetchContent_Declare(asio
       URL   https://github.com/chriskohlhoff/asio/archive/refs/tags/asio-1-24-0.tar.gz
     SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/asio
@@ -52,7 +52,7 @@
   endif()
   include_directories(${PROJECT_SOURCE_DIR}/third_party/asio/asio/include)
  
-  if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/json )
+  if(NOT EXISTS ${PROJECT_SOURCE_DIR}/third_party/json/ChangeLog.md )
     FetchContent_Declare(json
       URL   https://github.com/nlohmann/json/archive/refs/tags/v3.11.2.tar.gz
     SOURCE_DIR ${PROJECT_SOURCE_DIR}/third_party/json
diff --git a/runtime/websocket/bin/funasr-wss-client-2pass.cpp b/runtime/websocket/bin/funasr-wss-client-2pass.cpp
index 0cbd10e..6533dd5 100644
--- a/runtime/websocket/bin/funasr-wss-client-2pass.cpp
+++ b/runtime/websocket/bin/funasr-wss-client-2pass.cpp
@@ -192,10 +192,7 @@
     funasr::Audio audio(1);
     int32_t sampling_rate = audio_fs;
     std::string wav_format = "pcm";
-    if (funasr::IsTargetFile(wav_path.c_str(), "wav")) {
-      if (!audio.LoadWav(wav_path.c_str(), &sampling_rate, false)) 
-        return;
-    } else if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) {
+    if (funasr::IsTargetFile(wav_path.c_str(), "pcm")) {
       if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate, false)) return;
     } else {
       wav_format = "others";
diff --git a/runtime/websocket/bin/funasr-wss-client.cpp b/runtime/websocket/bin/funasr-wss-client.cpp
index 1dc9e3e..7af3fbb 100644
--- a/runtime/websocket/bin/funasr-wss-client.cpp
+++ b/runtime/websocket/bin/funasr-wss-client.cpp
@@ -193,10 +193,7 @@
 		funasr::Audio audio(1);
         int32_t sampling_rate = audio_fs;
         std::string wav_format = "pcm";
-        if (funasr::IsTargetFile(wav_path.c_str(), "wav")) {
-            if (!audio.LoadWav(wav_path.c_str(), &sampling_rate, false)) 
-                return;
-        } else if(funasr::IsTargetFile(wav_path.c_str(), "pcm")){
+        if(funasr::IsTargetFile(wav_path.c_str(), "pcm")){
 			if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate, false))
 				return ;
 		}else{
diff --git a/runtime/websocket/bin/websocket-server-2pass.cpp b/runtime/websocket/bin/websocket-server-2pass.cpp
index 954ffae..8c8cab4 100644
--- a/runtime/websocket/bin/websocket-server-2pass.cpp
+++ b/runtime/websocket/bin/websocket-server-2pass.cpp
@@ -211,7 +211,7 @@
         if(wav_format != "pcm" && wav_format != "PCM"){
           websocketpp::lib::error_code ec;
           nlohmann::json jsonresult;
-          jsonresult["text"] = "ERROR. Real-time transcription service ONLY SUPPORT wav_format pcm.";
+          jsonresult["text"] = "ERROR. Real-time transcription service ONLY SUPPORT PCM stream.";
           jsonresult["wav_name"] = wav_name;
           jsonresult["is_final"] = true;
           if (is_ssl) {

--
Gitblit v1.9.1