From c7fc6149b3c5c2de3107c4f1d4983309882d1a1a Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期三, 07 六月 2023 14:57:49 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR

---
 funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp |   89 ++++++++++++++++++++++++++++++++++----------
 1 files changed, 68 insertions(+), 21 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
similarity index 62%
copy from funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
copy to funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
index 0f606c6..68e32e5 100644
--- a/funasr/runtime/onnxruntime/src/funasr-onnx-offline-vad.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
@@ -18,6 +18,7 @@
 #include "funasrruntime.h"
 #include "tclap/CmdLine.h"
 #include "com-define.h"
+#include "audio.h"
 
 using namespace std;
 
@@ -38,10 +39,16 @@
     }
 }
 
-void print_segs(vector<vector<int>>* vec) {
-    string seg_out="[";
+void print_segs(vector<vector<int>>* vec, string &wav_id) {
+    if((*vec).size() == 0){
+        return;
+    }    
+    string seg_out=wav_id + ": [";
     for (int i = 0; i < vec->size(); i++) {
         vector<int> inner_vec = (*vec)[i];
+        if(inner_vec.size() == 0){
+            continue;
+        }
         seg_out += "[";
         for (int j = 0; j < inner_vec.size(); j++) {
             seg_out += to_string(inner_vec[j]);
@@ -97,9 +104,12 @@
 
     // read wav_path
     vector<string> wav_list;
+    vector<string> wav_ids;
+    string default_id = "wav_default_id";
     string wav_path_ = model_path.at(WAV_PATH);
     if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
         wav_list.emplace_back(wav_path_);
+        wav_ids.emplace_back(default_id);
     }
     else if(is_target_file(wav_path_, "scp")){
         ifstream in(wav_path_);
@@ -113,39 +123,76 @@
             istringstream iss(line);
             string column1, column2;
             iss >> column1 >> column2;
-            wav_list.emplace_back(column2); 
+            wav_list.emplace_back(column2);
+            wav_ids.emplace_back(column1);
         }
         in.close();
     }else{
         LOG(ERROR)<<"Please check the wav extension!";
         exit(-1);
     }
-    
+    // init online features
+    FUNASR_HANDLE online_hanlde=FsmnVadOnlineInit(vad_hanlde);
     float snippet_time = 0.0f;
     long taking_micros = 0;
-    for(auto& wav_file : wav_list){
-        gettimeofday(&start, NULL);
-        FUNASR_RESULT result=FsmnVadInfer(vad_hanlde, wav_file.c_str(), FSMN_VAD_OFFLINE, NULL, 16000);
-        gettimeofday(&end, NULL);
-        seconds = (end.tv_sec - start.tv_sec);
-        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+    for (int i = 0; i < wav_list.size(); i++) {
+        auto& wav_file = wav_list[i];
+        auto& wav_id = wav_ids[i];
 
-        if (result)
-        {
-            vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
-            print_segs(vad_segments);
-            snippet_time += FsmnVadGetRetSnippetTime(result);
-            FsmnVadFreeResult(result);
-        }
-        else
-        {
-            LOG(ERROR) << ("No return data!\n");
+        int32_t sampling_rate_ = -1;
+        funasr::Audio audio(1);
+		if(is_target_file(wav_file.c_str(), "wav")){
+			int32_t sampling_rate_ = -1;
+			if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
+				LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+		}else if(is_target_file(wav_file.c_str(), "pcm")){
+			if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
+				LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+		}else{
+			LOG(ERROR)<<"Wrong wav extension";
+			exit(-1);
+		}
+        char* speech_buff = audio.GetSpeechChar();
+        int buff_len = audio.GetSpeechLen()*2;
+
+        int step = 3200;
+        bool is_final = false;
+
+        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+            if (sample_offset + step >= buff_len - 1) {
+                    step = buff_len - sample_offset;
+                    is_final = true;
+                } else {
+                    is_final = false;
+            }
+            gettimeofday(&start, NULL);
+            FUNASR_RESULT result = FsmnVadInferBuffer(online_hanlde, speech_buff+sample_offset, step, NULL, is_final, 16000);
+            gettimeofday(&end, NULL);
+            seconds = (end.tv_sec - start.tv_sec);
+            taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+
+            if (result)
+            {
+                vector<std::vector<int>>* vad_segments = FsmnVadGetResult(result, 0);
+                print_segs(vad_segments, wav_id);
+                snippet_time += FsmnVadGetRetSnippetTime(result);
+                FsmnVadFreeResult(result);
+            }
+            else
+            {
+                LOG(ERROR) << ("No return data!\n");
+            }
         }
     }
- 
+
     LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
     LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
     LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
+    FsmnVadUninit(online_hanlde);
     FsmnVadUninit(vad_hanlde);
     return 0;
 }

--
Gitblit v1.9.1