From ae3e2567602546e66c0f358463617e560fc70e20 Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: Thu, 20 Apr 2023 14:50:55 +0800
Subject: [PATCH] add offline vad for onnxruntime
---
funasr/runtime/onnxruntime/include/Model.h | 3
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp | 78 --
funasr/runtime/onnxruntime/tester/tester.cpp | 10
funasr/runtime/onnxruntime/src/FsmnVad.h | 57 ++
funasr/runtime/onnxruntime/include/libfunasrapi.h | 12
funasr/runtime/onnxruntime/src/FsmnVad.cc | 268 +++++++++++
funasr/runtime/onnxruntime/include/Audio.h | 5
funasr/runtime/onnxruntime/src/Model.cpp | 4
funasr/runtime/onnxruntime/tester/tester_rtf.cpp | 14
funasr/runtime/onnxruntime/src/e2e_vad.h | 782 ++++++++++++++++++++++++++++++++++
funasr/runtime/onnxruntime/src/Audio.cpp | 84 ---
funasr/runtime/onnxruntime/src/libfunasrapi.cpp | 33
funasr/runtime/onnxruntime/src/paraformer_onnx.h | 10
funasr/runtime/onnxruntime/src/precomp.h | 3
14 files changed, 1192 insertions(+), 171 deletions(-)
diff --git a/funasr/runtime/onnxruntime/include/Audio.h b/funasr/runtime/onnxruntime/include/Audio.h
index ec49a9f..c38c31a 100644
--- a/funasr/runtime/onnxruntime/include/Audio.h
+++ b/funasr/runtime/onnxruntime/include/Audio.h
@@ -5,6 +5,7 @@
#include <ComDefine.h>
#include <queue>
#include <stdint.h>
+#include "Model.h"
#ifndef model_sample_rate
#define model_sample_rate 16000
@@ -27,7 +28,7 @@
~AudioFrame();
int set_start(int val);
- int set_end(int val, int max_len);
+ int set_end(int val);
int get_start();
int get_len();
int disp();
@@ -57,7 +58,7 @@
int fetch_chunck(float *&dout, int len);
int fetch(float *&dout, int &len, int &flag);
void padding();
- void split();
+ void split(Model* pRecogObj);
float get_time_len();
int get_queue_size() { return (int)frame_queue.size(); }
diff --git a/funasr/runtime/onnxruntime/include/Model.h b/funasr/runtime/onnxruntime/include/Model.h
index 6f45c38..cd3b0a3 100644
--- a/funasr/runtime/onnxruntime/include/Model.h
+++ b/funasr/runtime/onnxruntime/include/Model.h
@@ -11,7 +11,8 @@
virtual std::string forward_chunk(float *din, int len, int flag) = 0;
virtual std::string forward(float *din, int len, int flag) = 0;
virtual std::string rescoring() = 0;
+ virtual std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data)=0;
};
-Model *create_model(const char *path,int nThread=0,bool quantize=false);
+Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false);
#endif
diff --git a/funasr/runtime/onnxruntime/include/libfunasrapi.h b/funasr/runtime/onnxruntime/include/libfunasrapi.h
index 9bc37e7..8d8ebd2 100644
--- a/funasr/runtime/onnxruntime/include/libfunasrapi.h
+++ b/funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -48,18 +48,18 @@
typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
-// APIs for qmasr
-_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThread, bool quantize);
+// APIs for funasr
+_FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThread, bool quantize=false, bool use_vad=false);
// if not give a fnCallback ,it should be NULL
-_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
-_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
-_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
-_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
+_FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT Result,int nIndex);
diff --git a/funasr/runtime/onnxruntime/src/Audio.cpp b/funasr/runtime/onnxruntime/src/Audio.cpp
index 38b6de8..72e90a2 100644
--- a/funasr/runtime/onnxruntime/src/Audio.cpp
+++ b/funasr/runtime/onnxruntime/src/Audio.cpp
@@ -134,19 +134,10 @@
return start;
};
-int AudioFrame::set_end(int val, int max_len)
+int AudioFrame::set_end(int val)
{
-
- float num_samples = val - start;
- float frame_length = 400;
- float frame_shift = 160;
- float num_new_samples =
- ceil((num_samples - frame_length) / frame_shift) * frame_shift + frame_length;
-
- end = start + num_new_samples;
- len = (int)num_new_samples;
- if (end > max_len)
- printf("frame end > max_len!!!!!!!\n");
+ end = val;
+ len = end - start;
return end;
};
@@ -473,7 +464,6 @@
void Audio::padding()
{
-
float num_samples = speech_len;
float frame_length = 400;
float frame_shift = 160;
@@ -509,71 +499,27 @@
delete frame;
}
-#define UNTRIGGERED 0
-#define TRIGGERED 1
-
-#define SPEECH_LEN_5S (16000 * 5)
-#define SPEECH_LEN_10S (16000 * 10)
-#define SPEECH_LEN_20S (16000 * 20)
-#define SPEECH_LEN_30S (16000 * 30)
-
-/*
-void Audio::split()
+void Audio::split(Model* pRecogObj)
{
- VadInst *handle = WebRtcVad_Create();
- WebRtcVad_Init(handle);
- WebRtcVad_set_mode(handle, 2);
- int window_size = 10;
- AudioWindow audiowindow(window_size);
- int status = UNTRIGGERED;
- int offset = 0;
- int fs = 16000;
- int step = 480;
-
AudioFrame *frame;
frame = frame_queue.front();
frame_queue.pop();
+ int sp_len = frame->get_len();
delete frame;
frame = NULL;
- while (offset < speech_len - step) {
- int n = WebRtcVad_Process(handle, fs, speech_buff + offset, step);
- if (status == UNTRIGGERED && audiowindow.put(n) >= window_size - 1) {
- frame = new AudioFrame();
- int start = offset - step * (window_size - 1);
- frame->set_start(start);
- status = TRIGGERED;
- } else if (status == TRIGGERED) {
- int win_weight = audiowindow.put(n);
- int voice_len = (offset - frame->get_start());
- int gap = 0;
- if (voice_len < SPEECH_LEN_5S) {
- offset += step;
- continue;
- } else if (voice_len < SPEECH_LEN_10S) {
- gap = 1;
- } else if (voice_len < SPEECH_LEN_20S) {
- gap = window_size / 5;
- } else {
- gap = window_size / 2;
- }
-
- if (win_weight < gap) {
- status = UNTRIGGERED;
- offset = frame->set_end(offset, speech_align_len);
- frame_queue.push(frame);
- frame = NULL;
- }
- }
- offset += step;
- }
-
- if (frame != NULL) {
- frame->set_end(speech_len, speech_align_len);
+ std::vector<float> pcm_data(speech_data, speech_data+sp_len);
+ vector<std::vector<int>> vad_segments = pRecogObj->vad_seg(pcm_data);
+ int seg_sample = model_sample_rate/1000;
+ for(vector<int> segment:vad_segments)
+ {
+ frame = new AudioFrame();
+ int start = segment[0]*seg_sample;
+ int end = segment[1]*seg_sample;
+ frame->set_start(start);
+ frame->set_end(end);
frame_queue.push(frame);
frame = NULL;
}
- WebRtcVad_Free(handle);
}
-*/
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/FsmnVad.cc b/funasr/runtime/onnxruntime/src/FsmnVad.cc
new file mode 100644
index 0000000..69c69ca
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/FsmnVad.cc
@@ -0,0 +1,268 @@
+//
+// Created by root on 4/9/23.
+//
+#include <fstream>
+#include "FsmnVad.h"
+#include "precomp.h"
+//#include "glog/logging.h"
+
+
+void FsmnVad::init_vad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
+ float vad_speech_noise_thres) {
+
+
+ session_options_.SetIntraOpNumThreads(1);
+ session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
+ session_options_.DisableCpuMemArena();
+ this->vad_sample_rate_ = vad_sample_rate;
+ this->vad_silence_duration_=vad_silence_duration;
+ this->vad_max_len_=vad_max_len;
+ this->vad_speech_noise_thres_=vad_speech_noise_thres;
+
+ read_model(vad_model);
+ load_cmvn(vad_cmvn.c_str());
+
+ fbank_opts.frame_opts.dither = 0;
+ fbank_opts.mel_opts.num_bins = 80;
+ fbank_opts.frame_opts.samp_freq = vad_sample_rate;
+ fbank_opts.frame_opts.window_type = "hamming";
+ fbank_opts.frame_opts.frame_shift_ms = 10;
+ fbank_opts.frame_opts.frame_length_ms = 25;
+ fbank_opts.energy_floor = 0;
+ fbank_opts.mel_opts.debug_mel = false;
+
+}
+
+
+void FsmnVad::read_model(const std::string &vad_model) {
+ try {
+ vad_session_ = std::make_shared<Ort::Session>(
+ env_, vad_model.c_str(), session_options_);
+ } catch (std::exception const &e) {
+ //LOG(ERROR) << "Error when load onnx model: " << e.what();
+ exit(0);
+ }
+ //LOG(INFO) << "vad onnx:";
+ GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_);
+}
+
+
+
+void FsmnVad::GetInputOutputInfo(
+ const std::shared_ptr<Ort::Session> &session,
+ std::vector<const char *> *in_names, std::vector<const char *> *out_names) {
+ Ort::AllocatorWithDefaultOptions allocator;
+ // Input info
+ int num_nodes = session->GetInputCount();
+ in_names->resize(num_nodes);
+ for (int i = 0; i < num_nodes; ++i) {
+ std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetInputNameAllocated(i, allocator);
+ Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
+ auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+ ONNXTensorElementDataType type = tensor_info.GetElementType();
+ std::vector<int64_t> node_dims = tensor_info.GetShape();
+ std::stringstream shape;
+ for (auto j: node_dims) {
+ shape << j;
+ shape << " ";
+ }
+ // LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type
+ // << " dims=" << shape.str();
+ (*in_names)[i] = name.get();
+ name.release();
+ }
+ // Output info
+ num_nodes = session->GetOutputCount();
+ out_names->resize(num_nodes);
+ for (int i = 0; i < num_nodes; ++i) {
+ std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetOutputNameAllocated(i, allocator);
+ Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
+ auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
+ ONNXTensorElementDataType type = tensor_info.GetElementType();
+ std::vector<int64_t> node_dims = tensor_info.GetShape();
+ std::stringstream shape;
+ for (auto j: node_dims) {
+ shape << j;
+ shape << " ";
+ }
+ // LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type
+ // << " dims=" << shape.str();
+ (*out_names)[i] = name.get();
+ name.release();
+ }
+}
+
+
+void FsmnVad::Forward(
+ const std::vector<std::vector<float>> &chunk_feats,
+ std::vector<std::vector<float>> *out_prob) {
+ Ort::MemoryInfo memory_info =
+ Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+
+ int num_frames = chunk_feats.size();
+ const int feature_dim = chunk_feats[0].size();
+
+ // 2. Generate input nodes tensor
+ // vad node { batch,frame number,feature dim }
+ const int64_t vad_feats_shape[3] = {1, num_frames, feature_dim};
+ std::vector<float> vad_feats;
+ for (const auto &chunk_feat: chunk_feats) {
+ vad_feats.insert(vad_feats.end(), chunk_feat.begin(), chunk_feat.end());
+ }
+ Ort::Value vad_feats_ort = Ort::Value::CreateTensor<float>(
+ memory_info, vad_feats.data(), vad_feats.size(), vad_feats_shape, 3);
+ // cache node {batch,128,19,1}
+ const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
+ std::vector<float> cache_feats(128 * 19 * 1, 0);
+ Ort::Value cache_feats_ort = Ort::Value::CreateTensor<float>(
+ memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4);
+
+ // 3. Put nodes into onnx input vector
+ std::vector<Ort::Value> vad_inputs;
+ vad_inputs.emplace_back(std::move(vad_feats_ort));
+ // 4 caches
+ for (int i = 0; i < 4; i++) {
+ vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
+ memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4)));
+ }
+ // 4. Onnx infer
+ std::vector<Ort::Value> vad_ort_outputs;
+ try {
+ // VLOG(3) << "Start infer";
+ vad_ort_outputs = vad_session_->Run(
+ Ort::RunOptions{nullptr}, vad_in_names_.data(), vad_inputs.data(),
+ vad_inputs.size(), vad_out_names_.data(), vad_out_names_.size());
+ } catch (std::exception const &e) {
+ // LOG(ERROR) << e.what();
+ return;
+ }
+
+ // 5. Change infer result to output shapes
+ float *logp_data = vad_ort_outputs[0].GetTensorMutableData<float>();
+ auto type_info = vad_ort_outputs[0].GetTensorTypeAndShapeInfo();
+
+ int num_outputs = type_info.GetShape()[1];
+ int output_dim = type_info.GetShape()[2];
+ out_prob->resize(num_outputs);
+ for (int i = 0; i < num_outputs; i++) {
+ (*out_prob)[i].resize(output_dim);
+ memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
+ sizeof(float) * output_dim);
+ }
+}
+
+
+void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
+ const std::vector<float> &waves) {
+ knf::OnlineFbank fbank(fbank_opts);
+
+ fbank.AcceptWaveform(sample_rate, &waves[0], waves.size());
+ int32_t frames = fbank.NumFramesReady();
+ for (int32_t i = 0; i != frames; ++i) {
+ const float *frame = fbank.GetFrame(i);
+ std::vector<float> frame_vector(frame, frame + fbank_opts.mel_opts.num_bins);
+ vad_feats.emplace_back(frame_vector);
+ }
+}
+
+void FsmnVad::load_cmvn(const char *filename)
+{
+ using namespace std;
+ ifstream cmvn_stream(filename);
+ string line;
+
+ while (getline(cmvn_stream, line)) {
+ istringstream iss(line);
+ vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
+ if (line_item[0] == "<AddShift>") {
+ getline(cmvn_stream, line);
+ istringstream means_lines_stream(line);
+ vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
+ if (means_lines[0] == "<LearnRateCoef>") {
+ for (int j = 3; j < means_lines.size() - 1; j++) {
+ means_list.push_back(stof(means_lines[j]));
+ }
+ continue;
+ }
+ }
+ else if (line_item[0] == "<Rescale>") {
+ getline(cmvn_stream, line);
+ istringstream vars_lines_stream(line);
+ vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
+ if (vars_lines[0] == "<LearnRateCoef>") {
+ for (int j = 3; j < vars_lines.size() - 1; j++) {
+ // vars_list.push_back(stof(vars_lines[j])*scale);
+ vars_list.push_back(stof(vars_lines[j]));
+ }
+ continue;
+ }
+ }
+ }
+}
+
+std::vector<std::vector<float>> &FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats, int lfr_m, int lfr_n) {
+
+ std::vector<std::vector<float>> out_feats;
+ int T = vad_feats.size();
+ int T_lrf = ceil(1.0 * T / lfr_n);
+
+ // Pad frames at start(copy first frame)
+ for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+ vad_feats.insert(vad_feats.begin(), vad_feats[0]);
+ }
+ // Merge lfr_m frames as one,lfr_n frames per window
+ T = T + (lfr_m - 1) / 2;
+ std::vector<float> p;
+ for (int i = 0; i < T_lrf; i++) {
+ if (lfr_m <= T - i * lfr_n) {
+ for (int j = 0; j < lfr_m; j++) {
+ p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
+ }
+ out_feats.emplace_back(p);
+ p.clear();
+ } else {
+ // Fill to lfr_m frames at last window if less than lfr_m frames (copy last frame)
+ int num_padding = lfr_m - (T - i * lfr_n);
+ for (int j = 0; j < (vad_feats.size() - i * lfr_n); j++) {
+ p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
+ }
+ for (int j = 0; j < num_padding; j++) {
+ p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end());
+ }
+ out_feats.emplace_back(p);
+ }
+ }
+ // Apply cmvn
+ for (auto &out_feat: out_feats) {
+ for (int j = 0; j < means_list.size(); j++) {
+ out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j];
+ }
+ }
+ vad_feats = out_feats;
+ return vad_feats;
+}
+
+std::vector<std::vector<int>>
+FsmnVad::infer(const std::vector<float> &waves) {
+ std::vector<std::vector<float>> vad_feats;
+ std::vector<std::vector<float>> vad_probs;
+ FbankKaldi(vad_sample_rate_, vad_feats, waves);
+ vad_feats = LfrCmvn(vad_feats, 5, 1);
+ Forward(vad_feats, &vad_probs);
+
+ E2EVadModel vad_scorer = E2EVadModel();
+ std::vector<std::vector<int>> vad_segments;
+ vad_segments = vad_scorer(vad_probs, waves, true, vad_silence_duration_, vad_max_len_,
+ vad_speech_noise_thres_, vad_sample_rate_);
+ return vad_segments;
+
+}
+
+
+void FsmnVad::test() {
+
+}
+
+FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} {
+
+}
diff --git a/funasr/runtime/onnxruntime/src/FsmnVad.h b/funasr/runtime/onnxruntime/src/FsmnVad.h
new file mode 100644
index 0000000..839f9ec
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/FsmnVad.h
@@ -0,0 +1,57 @@
+//
+// Created by zyf on 4/9/23.
+//
+
+#ifndef VAD_SERVER_FSMNVAD_H
+#define VAD_SERVER_FSMNVAD_H
+
+#include "e2e_vad.h"
+#include "onnxruntime_cxx_api.h"
+#include "kaldi-native-fbank/csrc/feature-fbank.h"
+#include "kaldi-native-fbank/csrc/online-feature.h"
+
+
+class FsmnVad {
+public:
+ FsmnVad();
+ void test();
+ void init_vad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
+ float vad_speech_noise_thres);
+
+ std::vector<std::vector<int>> infer(const std::vector<float> &waves);
+
+private:
+
+ void read_model(const std::string &vad_model);
+
+ static void GetInputOutputInfo(
+ const std::shared_ptr<Ort::Session> &session,
+ std::vector<const char *> *in_names, std::vector<const char *> *out_names);
+
+ void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
+ const std::vector<float> &waves);
+
+ std::vector<std::vector<float>> &LfrCmvn(std::vector<std::vector<float>> &vad_feats, int lfr_m, int lfr_n);
+
+ void Forward(
+ const std::vector<std::vector<float>> &chunk_feats,
+ std::vector<std::vector<float>> *out_prob);
+
+ void load_cmvn(const char *filename);
+
+ std::shared_ptr<Ort::Session> vad_session_ = nullptr;
+ Ort::Env env_;
+ Ort::SessionOptions session_options_;
+ std::vector<const char *> vad_in_names_;
+ std::vector<const char *> vad_out_names_;
+ knf::FbankOptions fbank_opts;
+ std::vector<float> means_list;
+ std::vector<float> vars_list;
+ int vad_sample_rate_ = 16000;
+ int vad_silence_duration_ = 800;
+ int vad_max_len_ = 15000;
+ double vad_speech_noise_thres_ = 0.9;
+};
+
+
+#endif //VAD_SERVER_FSMNVAD_H
diff --git a/funasr/runtime/onnxruntime/src/Model.cpp b/funasr/runtime/onnxruntime/src/Model.cpp
index 7ddb635..2f864a9 100644
--- a/funasr/runtime/onnxruntime/src/Model.cpp
+++ b/funasr/runtime/onnxruntime/src/Model.cpp
@@ -1,10 +1,10 @@
#include "precomp.h"
-Model *create_model(const char *path, int nThread, bool quantize)
+Model *create_model(const char *path, int nThread, bool quantize, bool use_vad)
{
Model *mm;
- mm = new paraformer::ModelImp(path, nThread, quantize);
+ mm = new paraformer::ModelImp(path, nThread, quantize, use_vad);
return mm;
}
diff --git a/funasr/runtime/onnxruntime/src/e2e_vad.h b/funasr/runtime/onnxruntime/src/e2e_vad.h
new file mode 100644
index 0000000..b1780a7
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/e2e_vad.h
@@ -0,0 +1,782 @@
+//
+// Created by root on 3/31/23.
+//
+
+#include <utility>
+#include <vector>
+#include <string>
+#include <map>
+#include <cmath>
+#include <algorithm>
+#include <iostream>
+#include <numeric>
+#include <cassert>
+
+
+enum class VadStateMachine {
+ kVadInStateStartPointNotDetected = 1,
+ kVadInStateInSpeechSegment = 2,
+ kVadInStateEndPointDetected = 3
+};
+
+enum class FrameState {
+ kFrameStateInvalid = -1,
+ kFrameStateSpeech = 1,
+ kFrameStateSil = 0
+};
+
+// final voice/unvoice state per frame
+enum class AudioChangeState {
+ kChangeStateSpeech2Speech = 0,
+ kChangeStateSpeech2Sil = 1,
+ kChangeStateSil2Sil = 2,
+ kChangeStateSil2Speech = 3,
+ kChangeStateNoBegin = 4,
+ kChangeStateInvalid = 5
+};
+
+enum class VadDetectMode {
+ kVadSingleUtteranceDetectMode = 0,
+ kVadMutipleUtteranceDetectMode = 1
+};
+
+class VADXOptions {
+public:
+ int sample_rate;
+ int detect_mode;
+ int snr_mode;
+ int max_end_silence_time;
+ int max_start_silence_time;
+ bool do_start_point_detection;
+ bool do_end_point_detection;
+ int window_size_ms;
+ int sil_to_speech_time_thres;
+ int speech_to_sil_time_thres;
+ float speech_2_noise_ratio;
+ int do_extend;
+ int lookback_time_start_point;
+ int lookahead_time_end_point;
+ int max_single_segment_time;
+ int nn_eval_block_size;
+ int dcd_block_size;
+ float snr_thres;
+ int noise_frame_num_used_for_snr;
+ float decibel_thres;
+ float speech_noise_thres;
+ float fe_prior_thres;
+ int silence_pdf_num;
+ std::vector<int> sil_pdf_ids;
+ float speech_noise_thresh_low;
+ float speech_noise_thresh_high;
+ bool output_frame_probs;
+ int frame_in_ms;
+ int frame_length_ms;
+
+ explicit VADXOptions(
+ int sr = 16000,
+ int dm = static_cast<int>(VadDetectMode::kVadMutipleUtteranceDetectMode),
+ int sm = 0,
+ int mset = 800,
+ int msst = 3000,
+ bool dspd = true,
+ bool depd = true,
+ int wsm = 200,
+ int ststh = 150,
+ int sttsh = 150,
+ float s2nr = 1.0,
+ int de = 1,
+ int lbtps = 200,
+ int latsp = 100,
+ int mss = 15000,
+ int nebs = 8,
+ int dbs = 4,
+ float st = -100.0,
+ int nfnus = 100,
+ float dt = -100.0,
+ float snt = 0.9,
+ float fept = 1e-4,
+ int spn = 1,
+ std::vector<int> spids = {0},
+ float sntl = -0.1,
+ float snth = 0.3,
+ bool ofp = false,
+ int fim = 10,
+ int flm = 25
+ ) :
+ sample_rate(sr),
+ detect_mode(dm),
+ snr_mode(sm),
+ max_end_silence_time(mset),
+ max_start_silence_time(msst),
+ do_start_point_detection(dspd),
+ do_end_point_detection(depd),
+ window_size_ms(wsm),
+ sil_to_speech_time_thres(ststh),
+ speech_to_sil_time_thres(sttsh),
+ speech_2_noise_ratio(s2nr),
+ do_extend(de),
+ lookback_time_start_point(lbtps),
+ lookahead_time_end_point(latsp),
+ max_single_segment_time(mss),
+ nn_eval_block_size(nebs),
+ dcd_block_size(dbs),
+ snr_thres(st),
+ noise_frame_num_used_for_snr(nfnus),
+ decibel_thres(dt),
+ speech_noise_thres(snt),
+ fe_prior_thres(fept),
+ silence_pdf_num(spn),
+ sil_pdf_ids(std::move(spids)),
+ speech_noise_thresh_low(sntl),
+ speech_noise_thresh_high(snth),
+ output_frame_probs(ofp),
+ frame_in_ms(fim),
+ frame_length_ms(flm) {}
+};
+
+class E2EVadSpeechBufWithDoa {
+public:
+ int start_ms;
+ int end_ms;
+ std::vector<float> buffer;
+ bool contain_seg_start_point;
+ bool contain_seg_end_point;
+ int doa;
+
+ E2EVadSpeechBufWithDoa() :
+ start_ms(0),
+ end_ms(0),
+ buffer(),
+ contain_seg_start_point(false),
+ contain_seg_end_point(false),
+ doa(0) {}
+
+ void Reset() {
+ start_ms = 0;
+ end_ms = 0;
+ buffer.clear();
+ contain_seg_start_point = false;
+ contain_seg_end_point = false;
+ doa = 0;
+ }
+};
+
+class E2EVadFrameProb {
+public:
+ double noise_prob;
+ double speech_prob;
+ double score;
+ int frame_id;
+ int frm_state;
+
+ E2EVadFrameProb() :
+ noise_prob(0.0),
+ speech_prob(0.0),
+ score(0.0),
+ frame_id(0),
+ frm_state(0) {}
+};
+
+class WindowDetector {
+public:
+ int window_size_ms;
+ int sil_to_speech_time;
+ int speech_to_sil_time;
+ int frame_size_ms;
+ int win_size_frame;
+ int win_sum;
+ std::vector<int> win_state;
+ int cur_win_pos;
+ FrameState pre_frame_state;
+ FrameState cur_frame_state;
+ int sil_to_speech_frmcnt_thres;
+ int speech_to_sil_frmcnt_thres;
+ int voice_last_frame_count;
+ int noise_last_frame_count;
+ int hydre_frame_count;
+
+ WindowDetector(int window_size_ms, int sil_to_speech_time, int speech_to_sil_time, int frame_size_ms) :
+ window_size_ms(window_size_ms),
+ sil_to_speech_time(sil_to_speech_time),
+ speech_to_sil_time(speech_to_sil_time),
+ frame_size_ms(frame_size_ms),
+ win_size_frame(window_size_ms / frame_size_ms),
+ win_sum(0),
+ win_state(std::vector<int>(win_size_frame, 0)),
+ cur_win_pos(0),
+ pre_frame_state(FrameState::kFrameStateSil),
+ cur_frame_state(FrameState::kFrameStateSil),
+ sil_to_speech_frmcnt_thres(sil_to_speech_time / frame_size_ms),
+ speech_to_sil_frmcnt_thres(speech_to_sil_time / frame_size_ms),
+ voice_last_frame_count(0),
+ noise_last_frame_count(0),
+ hydre_frame_count(0) {}
+
+ void Reset() {
+ cur_win_pos = 0;
+ win_sum = 0;
+ win_state = std::vector<int>(win_size_frame, 0);
+ pre_frame_state = FrameState::kFrameStateSil;
+ cur_frame_state = FrameState::kFrameStateSil;
+ voice_last_frame_count = 0;
+ noise_last_frame_count = 0;
+ hydre_frame_count = 0;
+ }
+
+ int GetWinSize() {
+ return win_size_frame;
+ }
+
+ AudioChangeState DetectOneFrame(FrameState frameState, int frame_count) {
+ int cur_frame_state = 0;
+ if (frameState == FrameState::kFrameStateSpeech) {
+ cur_frame_state = 1;
+ } else if (frameState == FrameState::kFrameStateSil) {
+ cur_frame_state = 0;
+ } else {
+ return AudioChangeState::kChangeStateInvalid;
+ }
+ win_sum -= win_state[cur_win_pos];
+ win_sum += cur_frame_state;
+ win_state[cur_win_pos] = cur_frame_state;
+ cur_win_pos = (cur_win_pos + 1) % win_size_frame;
+ if (pre_frame_state == FrameState::kFrameStateSil && win_sum >= sil_to_speech_frmcnt_thres) {
+ pre_frame_state = FrameState::kFrameStateSpeech;
+ return AudioChangeState::kChangeStateSil2Speech;
+ }
+ if (pre_frame_state == FrameState::kFrameStateSpeech && win_sum <= speech_to_sil_frmcnt_thres) {
+ pre_frame_state = FrameState::kFrameStateSil;
+ return AudioChangeState::kChangeStateSpeech2Sil;
+ }
+ if (pre_frame_state == FrameState::kFrameStateSil) {
+ return AudioChangeState::kChangeStateSil2Sil;
+ }
+ if (pre_frame_state == FrameState::kFrameStateSpeech) {
+ return AudioChangeState::kChangeStateSpeech2Speech;
+ }
+ return AudioChangeState::kChangeStateInvalid;
+ }
+
+ int FrameSizeMs() {
+ return frame_size_ms;
+ }
+};
+
+class E2EVadModel {
+public:
+ E2EVadModel() {
+ this->vad_opts = VADXOptions();
+// this->windows_detector = WindowDetector(200,150,150,10);
+ // this->encoder = encoder;
+ // init variables
+ this->is_final = false;
+ this->data_buf_start_frame = 0;
+ this->frm_cnt = 0;
+ this->latest_confirmed_speech_frame = 0;
+ this->lastest_confirmed_silence_frame = -1;
+ this->continous_silence_frame_count = 0;
+ this->vad_state_machine = VadStateMachine::kVadInStateStartPointNotDetected;
+ this->confirmed_start_frame = -1;
+ this->confirmed_end_frame = -1;
+ this->number_end_time_detected = 0;
+ this->sil_frame = 0;
+ this->sil_pdf_ids = this->vad_opts.sil_pdf_ids;
+ this->noise_average_decibel = -100.0;
+ this->pre_end_silence_detected = false;
+ this->next_seg = true;
+// this->output_data_buf = [];
+ this->output_data_buf_offset = 0;
+// this->frame_probs = [];
+ this->max_end_sil_frame_cnt_thresh =
+ this->vad_opts.max_end_silence_time - this->vad_opts.speech_to_sil_time_thres;
+ this->speech_noise_thres = this->vad_opts.speech_noise_thres;
+ this->max_time_out = false;
+// this->decibel = [];
+ this->ResetDetection();
+ }
+
+ std::vector<std::vector<int>>
+ operator()(const std::vector<std::vector<float>> &score, const std::vector<float> &waveform, bool is_final = false,
+ int max_end_sil = 800, int max_single_segment_time = 15000, float speech_noise_thres = 0.9,
+ int sample_rate = 16000) {
+ max_end_sil_frame_cnt_thresh = max_end_sil - vad_opts.speech_to_sil_time_thres;
+ this->waveform = waveform;
+ this->vad_opts.max_single_segment_time = max_single_segment_time;
+ this->vad_opts.speech_noise_thres = speech_noise_thres;
+ this->vad_opts.sample_rate = sample_rate;
+
+ ComputeDecibel();
+ ComputeScores(score);
+ if (!is_final) {
+ DetectCommonFrames();
+ } else {
+ DetectLastFrames();
+ }
+ // std::vector<std::vector<int>> segments;
+ // for (size_t batch_num = 0; batch_num < score.size(); batch_num++) {
+ std::vector<std::vector<int>> segment_batch;
+ if (output_data_buf.size() > 0) {
+ for (size_t i = output_data_buf_offset; i < output_data_buf.size(); i++) {
+ if (!output_data_buf[i].contain_seg_start_point) {
+ continue;
+ }
+ if (!next_seg && !output_data_buf[i].contain_seg_end_point) {
+ continue;
+ }
+ int start_ms = next_seg ? output_data_buf[i].start_ms : -1;
+ int end_ms;
+ if (output_data_buf[i].contain_seg_end_point) {
+ end_ms = output_data_buf[i].end_ms;
+ next_seg = true;
+ output_data_buf_offset += 1;
+ } else {
+ end_ms = -1;
+ next_seg = false;
+ }
+ std::vector<int> segment = {start_ms, end_ms};
+ segment_batch.push_back(segment);
+ }
+ }
+
+ // }
+ if (is_final) {
+ AllResetDetection();
+ }
+ return segment_batch;
+ }
+
+private:
+ VADXOptions vad_opts;
+ WindowDetector windows_detector = WindowDetector(200, 150, 150, 10);
+ bool is_final;
+ int data_buf_start_frame;
+ int frm_cnt;
+ int latest_confirmed_speech_frame;
+ int lastest_confirmed_silence_frame;
+ int continous_silence_frame_count;
+ VadStateMachine vad_state_machine;
+ int confirmed_start_frame;
+ int confirmed_end_frame;
+ int number_end_time_detected;
+ int sil_frame;
+ std::vector<int> sil_pdf_ids;
+ float noise_average_decibel;
+ bool pre_end_silence_detected;
+ bool next_seg;
+ std::vector<E2EVadSpeechBufWithDoa> output_data_buf;
+ int output_data_buf_offset;
+ std::vector<E2EVadFrameProb> frame_probs;
+ int max_end_sil_frame_cnt_thresh;
+ float speech_noise_thres;
+ std::vector<std::vector<float>> scores;
+ bool max_time_out;
+ std::vector<float> decibel;
+ std::vector<float> data_buf;
+ std::vector<float> data_buf_all;
+ std::vector<float> waveform;
+
+ void AllResetDetection() {
+ is_final = false;
+ data_buf_start_frame = 0;
+ frm_cnt = 0;
+ latest_confirmed_speech_frame = 0;
+ lastest_confirmed_silence_frame = -1;
+ continous_silence_frame_count = 0;
+ vad_state_machine = VadStateMachine::kVadInStateStartPointNotDetected;
+ confirmed_start_frame = -1;
+ confirmed_end_frame = -1;
+ number_end_time_detected = 0;
+ sil_frame = 0;
+ sil_pdf_ids = vad_opts.sil_pdf_ids;
+ noise_average_decibel = -100.0;
+ pre_end_silence_detected = false;
+ next_seg = true;
+ output_data_buf.clear();
+ output_data_buf_offset = 0;
+ frame_probs.clear();
+ max_end_sil_frame_cnt_thresh = vad_opts.max_end_silence_time - vad_opts.speech_to_sil_time_thres;
+ speech_noise_thres = vad_opts.speech_noise_thres;
+ scores.clear();
+ max_time_out = false;
+ decibel.clear();
+ data_buf.clear();
+ data_buf_all.clear();
+ waveform.clear();
+ ResetDetection();
+ }
+
+ void ResetDetection() {
+ continous_silence_frame_count = 0;
+ latest_confirmed_speech_frame = 0;
+ lastest_confirmed_silence_frame = -1;
+ confirmed_start_frame = -1;
+ confirmed_end_frame = -1;
+ vad_state_machine = VadStateMachine::kVadInStateStartPointNotDetected;
+ windows_detector.Reset();
+ sil_frame = 0;
+ frame_probs.clear();
+ }
+
+ void ComputeDecibel() {
+ int frame_sample_length = int(vad_opts.frame_length_ms * vad_opts.sample_rate / 1000);
+ int frame_shift_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
+ if (data_buf_all.empty()) {
+ data_buf_all = waveform;
+ data_buf = data_buf_all;
+ } else {
+ data_buf_all.insert(data_buf_all.end(), waveform.begin(), waveform.end());
+ }
+ for (int offset = 0; offset < waveform.size() - frame_sample_length + 1; offset += frame_shift_length) {
+ float sum = 0.0;
+ for (int i = 0; i < frame_sample_length; i++) {
+ sum += waveform[offset + i] * waveform[offset + i];
+ }
+// float decibel = 10 * log10(sum + 0.000001);
+ this->decibel.push_back(10 * log10(sum + 0.000001));
+ }
+ }
+
+ void ComputeScores(const std::vector<std::vector<float>> &scores) {
+ vad_opts.nn_eval_block_size = scores.size();
+ frm_cnt += scores.size();
+ if (this->scores.empty()) {
+ this->scores = scores; // the first calculation
+ } else {
+ this->scores.insert(this->scores.end(), scores.begin(), scores.end());
+ }
+ }
+
+ void PopDataBufTillFrame(int frame_idx) {
+ while (data_buf_start_frame < frame_idx) {
+ int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
+ if (data_buf.size() >= frame_sample_length) {
+ data_buf_start_frame += 1;
+ data_buf.erase(data_buf.begin(), data_buf.begin() + frame_sample_length);
+ } else {
+ break;
+ }
+ }
+ }
+
+    // Move `frm_cnt` frames of raw audio, starting at frame `start_frm`, from
+    // data_buf into the tail segment of output_data_buf. A new segment is opened
+    // when first_frm_is_start_point is set (or no segment exists yet).
+    // last_frm_is_end_point: include the analysis-window tail beyond the frame
+    // shift; end_point_is_sent_end: flush everything remaining in data_buf.
+    // NOTE(review): parameter `frm_cnt` shadows the member `frm_cnt` used by
+    // ComputeScores/Detect*Frames — apparently intended, but easy to misread.
+    void PopDataToOutputBuf(int start_frm, int frm_cnt, bool first_frm_is_start_point, bool last_frm_is_end_point,
+                            bool end_point_is_sent_end) {
+        PopDataBufTillFrame(start_frm);
+        int expected_sample_number = int(frm_cnt * vad_opts.sample_rate * vad_opts.frame_in_ms / 1000);
+        if (last_frm_is_end_point) {
+            // Extra samples = frame length minus frame shift (window tail).
+            int extra_sample = std::max(0, int(vad_opts.frame_length_ms * vad_opts.sample_rate / 1000 -
+                                               vad_opts.sample_rate * vad_opts.frame_in_ms / 1000));
+            expected_sample_number += int(extra_sample);
+        }
+        if (end_point_is_sent_end) {
+            expected_sample_number = std::max(expected_sample_number, int(data_buf.size()));
+        }
+        if (data_buf.size() < expected_sample_number) {
+            std::cout << "error in calling pop data_buf\n";
+        }
+        if (output_data_buf.size() == 0 || first_frm_is_start_point) {
+            // Open a fresh output segment starting (and, so far, ending) at start_frm.
+            output_data_buf.push_back(E2EVadSpeechBufWithDoa());
+            output_data_buf[output_data_buf.size() - 1].Reset();
+            output_data_buf[output_data_buf.size() - 1].start_ms = start_frm * vad_opts.frame_in_ms;
+            output_data_buf[output_data_buf.size() - 1].end_ms = output_data_buf[output_data_buf.size() - 1].start_ms;
+            output_data_buf[output_data_buf.size() - 1].doa = 0;
+        }
+        E2EVadSpeechBufWithDoa &cur_seg = output_data_buf.back();
+        if (cur_seg.end_ms != start_frm * vad_opts.frame_in_ms) {
+            std::cout << "warning\n";  // frames are expected to arrive contiguously
+        }
+        int out_pos = (int) cur_seg.buffer.size();
+        int data_to_pop;
+        if (end_point_is_sent_end) {
+            data_to_pop = expected_sample_number;
+        } else {
+            data_to_pop = int(frm_cnt * vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
+        }
+        if (data_to_pop > int(data_buf.size())) {
+            std::cout << "VAD data_to_pop is bigger than data_buf.size()!!!\n";
+            data_to_pop = (int) data_buf.size();
+            expected_sample_number = (int) data_buf.size();
+        }
+        cur_seg.doa = 0;
+        // NOTE(review): both loops below repeatedly append data_buf.back() without
+        // popping, so the LAST buffered sample is duplicated instead of copying the
+        // frame's samples in order — looks like a porting bug; verify against the
+        // Python reference implementation.
+        for (int sample_cpy_out = 0; sample_cpy_out < data_to_pop; sample_cpy_out++) {
+            cur_seg.buffer.push_back(data_buf.back());
+            out_pos++;
+        }
+        for (int sample_cpy_out = data_to_pop; sample_cpy_out < expected_sample_number; sample_cpy_out++) {
+            cur_seg.buffer.push_back(data_buf.back());
+            out_pos++;
+        }
+        if (cur_seg.end_ms != start_frm * vad_opts.frame_in_ms) {
+            std::cout << "Something wrong with the VAD algorithm\n";
+        }
+        // NOTE(review): data_buf_start_frame advances here but data_buf itself is
+        // not erased — confirm the bookkeeping stays consistent with
+        // PopDataBufTillFrame, which only erases while start_frame lags behind.
+        data_buf_start_frame += frm_cnt;
+        cur_seg.end_ms = (start_frm + frm_cnt) * vad_opts.frame_in_ms;
+        if (first_frm_is_start_point) {
+            cur_seg.contain_seg_start_point = true;
+        }
+        if (last_frm_is_end_point) {
+            cur_seg.contain_seg_end_point = true;
+        }
+    }
+
+    // Record a confirmed-silence frame; before any speech start has been found,
+    // the corresponding raw samples are discarded from data_buf.
+    void OnSilenceDetected(int valid_frame) {
+        lastest_confirmed_silence_frame = valid_frame;
+        if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
+            PopDataBufTillFrame(valid_frame);
+        }
+        // silence_detected_callback_
+        // pass
+    }
+
+    // Record a confirmed-speech frame and append its raw samples to the current
+    // output segment.
+    void OnVoiceDetected(int valid_frame) {
+        latest_confirmed_speech_frame = valid_frame;
+        PopDataToOutputBuf(valid_frame, 1, false, false, false);
+    }
+
+    // Confirm the start of a speech segment; unless fake_result, open a new
+    // output segment seeded with this frame's audio.
+    void OnVoiceStart(int start_frame, bool fake_result = false) {
+        if (vad_opts.do_start_point_detection) {
+            // pass
+        }
+        if (confirmed_start_frame != -1) {
+            std::cout << "not reset vad properly\n";  // a previous segment start was never cleared
+        } else {
+            confirmed_start_frame = start_frame;
+        }
+        if (!fake_result && vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
+            PopDataToOutputBuf(confirmed_start_frame, 1, true, false, false);
+        }
+    }
+
+
+    // Confirm the end of a speech segment at end_frame, first flushing any
+    // speech frames between the last confirmed one and the end point.
+    // is_last_frame: the end point is also the end of the sentence, so the
+    // remaining buffered audio is flushed with it.
+    void OnVoiceEnd(int end_frame, bool fake_result, bool is_last_frame) {
+        for (int t = latest_confirmed_speech_frame + 1; t < end_frame; t++) {
+            OnVoiceDetected(t);
+        }
+        if (vad_opts.do_end_point_detection) {
+            // pass
+        }
+        if (confirmed_end_frame != -1) {
+            std::cout << "not reset vad properly\n";  // a previous segment end was never cleared
+        } else {
+            confirmed_end_frame = end_frame;
+        }
+        if (!fake_result) {
+            sil_frame = 0;
+            PopDataToOutputBuf(confirmed_end_frame, 1, false, true, is_last_frame);
+        }
+        number_end_time_detected++;
+    }
+
+    // On the final frame of the utterance, force-close the current segment at
+    // cur_frm_idx and move the state machine to the terminal state.
+    void MaybeOnVoiceEndIfLastFrame(bool is_final_frame, int cur_frm_idx) {
+        if (is_final_frame) {
+            OnVoiceEnd(cur_frm_idx, false, true);
+            vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
+        }
+    }
+
+    // Detection latency at a segment start point, in milliseconds.
+    int GetLatency() {
+        return int(LatencyFrmNumAtStartPoint() * vad_opts.frame_in_ms);
+    }
+
+    // Number of look-back frames needed before a start point can be confirmed:
+    // the sliding-window size, plus the configured start-point look-back when
+    // segment extension is enabled.
+    int LatencyFrmNumAtStartPoint() {
+        int vad_latency = windows_detector.GetWinSize();
+        if (vad_opts.do_extend) {
+            vad_latency += int(vad_opts.lookback_time_start_point / vad_opts.frame_in_ms);
+        }
+        return vad_latency;
+    }
+
+    // Classify frame t as speech or silence using frame energy (decibel) plus the
+    // NN silence-pdf posteriors; also maintains the running noise-floor estimate.
+    FrameState GetFrameState(int t) {
+        FrameState frame_state = FrameState::kFrameStateInvalid;
+        float cur_decibel = decibel[t];
+        float cur_snr = cur_decibel - noise_average_decibel;
+        if (cur_decibel < vad_opts.decibel_thres) {
+            frame_state = FrameState::kFrameStateSil;
+            // NOTE(review): DetectOneFrame is invoked here AND again by the callers
+            // (DetectCommonFrames / DetectLastFrames) with the returned state, so a
+            // low-energy frame goes through the state machine twice — verify.
+            DetectOneFrame(frame_state, t, false);
+            return frame_state;
+        }
+        float sum_score = 0.0;
+        float noise_prob = 0.0;
+        assert(sil_pdf_ids.size() == vad_opts.silence_pdf_num);
+        if (sil_pdf_ids.size() > 0) {
+            // Sum the posterior mass of all silence pdfs; the remainder of the
+            // (assumed normalized) distribution is treated as speech mass.
+            std::vector<float> sil_pdf_scores;
+            for (auto sil_pdf_id: sil_pdf_ids) {
+                sil_pdf_scores.push_back(scores[t][sil_pdf_id]);
+            }
+            sum_score = accumulate(sil_pdf_scores.begin(), sil_pdf_scores.end(), 0.0);
+            noise_prob = log(sum_score) * vad_opts.speech_2_noise_ratio;  // ratio scales the log-prob
+            float total_score = 1.0;
+            sum_score = total_score - sum_score;  // sum_score is now the speech mass
+        }
+        float speech_prob = log(sum_score);
+        if (vad_opts.output_frame_probs) {
+            E2EVadFrameProb frame_prob;
+            frame_prob.noise_prob = noise_prob;
+            frame_prob.speech_prob = speech_prob;
+            frame_prob.score = sum_score;
+            frame_prob.frame_id = t;
+            frame_probs.push_back(frame_prob);
+        }
+        // Speech wins only if its probability beats noise by the configured margin
+        // AND the frame passes both the SNR and absolute-energy gates.
+        if (exp(speech_prob) >= exp(noise_prob) + speech_noise_thres) {
+            if (cur_snr >= vad_opts.snr_thres && cur_decibel >= vad_opts.decibel_thres) {
+                frame_state = FrameState::kFrameStateSpeech;
+            } else {
+                frame_state = FrameState::kFrameStateSil;
+            }
+        } else {
+            frame_state = FrameState::kFrameStateSil;
+            // Update the running noise floor (moving average over
+            // noise_frame_num_used_for_snr frames); -100 sentinel means "unset".
+            if (noise_average_decibel < -99.9) {
+                noise_average_decibel = cur_decibel;
+            } else {
+                noise_average_decibel =
+                        (cur_decibel + noise_average_decibel * (vad_opts.noise_frame_num_used_for_snr - 1)) /
+                        vad_opts.noise_frame_num_used_for_snr;
+            }
+        }
+        return frame_state;
+    }
+
+    // Run the state machine over the most recently scored block of frames
+    // (oldest first, none treated as final). No-op after an end point is found.
+    int DetectCommonFrames() {
+        if (vad_state_machine == VadStateMachine::kVadInStateEndPointDetected) {
+            return 0;
+        }
+        for (int i = vad_opts.nn_eval_block_size - 1; i >= 0; i--) {
+            FrameState frame_state = FrameState::kFrameStateInvalid;
+            frame_state = GetFrameState(frm_cnt - 1 - i);
+            DetectOneFrame(frame_state, frm_cnt - 1 - i, false);
+        }
+        return 0;
+    }
+
+    // Like DetectCommonFrames, but the newest frame (i == 0) is flagged as the
+    // final frame so any open segment is force-closed.
+    int DetectLastFrames() {
+        if (vad_state_machine == VadStateMachine::kVadInStateEndPointDetected) {
+            return 0;
+        }
+        for (int i = vad_opts.nn_eval_block_size - 1; i >= 0; i--) {
+            FrameState frame_state = FrameState::kFrameStateInvalid;
+            frame_state = GetFrameState(frm_cnt - 1 - i);
+            if (i != 0) {
+                DetectOneFrame(frame_state, frm_cnt - 1 - i, false);
+            } else {
+                DetectOneFrame(frame_state, frm_cnt - 1, true);  // last frame: is_final_frame = true
+            }
+        }
+        return 0;
+    }
+
+    // Core state-machine step: feed one frame's state through the sliding-window
+    // detector and act on the resulting transition (sil->speech, speech->sil,
+    // speech->speech, sil->sil), emitting segment start/end events, popping audio
+    // to the output buffer, and enforcing max-segment / max-silence limits.
+    void DetectOneFrame(FrameState cur_frm_state, int cur_frm_idx, bool is_final_frame) {
+        FrameState tmp_cur_frm_state = FrameState::kFrameStateInvalid;
+        if (cur_frm_state == FrameState::kFrameStateSpeech) {
+            // NOTE(review): std::fabs(1.0) is a constant, so this always takes the
+            // speech branch — presumably a placeholder for a per-frame score
+            // compared against fe_prior_thres; confirm against the reference code.
+            if (std::fabs(1.0) > vad_opts.fe_prior_thres) {
+                tmp_cur_frm_state = FrameState::kFrameStateSpeech;
+            } else {
+                tmp_cur_frm_state = FrameState::kFrameStateSil;
+            }
+        } else if (cur_frm_state == FrameState::kFrameStateSil) {
+            tmp_cur_frm_state = FrameState::kFrameStateSil;
+        }
+        AudioChangeState state_change = windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx);
+        int frm_shift_in_ms = vad_opts.frame_in_ms;
+        if (AudioChangeState::kChangeStateSil2Speech == state_change) {
+            int silence_frame_count = continous_silence_frame_count;  // NOTE(review): saved but never used
+            continous_silence_frame_count = 0;
+            pre_end_silence_detected = false;
+            int start_frame = 0;
+            if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
+                // New segment: back off by the detector latency, but not before
+                // the first frame still held in data_buf.
+                start_frame = std::max(data_buf_start_frame, cur_frm_idx - LatencyFrmNumAtStartPoint());
+                OnVoiceStart(start_frame);
+                vad_state_machine = VadStateMachine::kVadInStateInSpeechSegment;
+                for (int t = start_frame + 1; t <= cur_frm_idx; t++) {
+                    OnVoiceDetected(t);
+                }
+            } else if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
+                for (int t = latest_confirmed_speech_frame + 1; t < cur_frm_idx; t++) {
+                    OnVoiceDetected(t);
+                }
+                // Force an end point if the segment exceeds the max single-segment duration.
+                if (cur_frm_idx - confirmed_start_frame + 1 > vad_opts.max_single_segment_time / frm_shift_in_ms) {
+                    OnVoiceEnd(cur_frm_idx, false, false);
+                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
+                } else if (!is_final_frame) {
+                    OnVoiceDetected(cur_frm_idx);
+                } else {
+                    MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
+                }
+            }
+        } else if (AudioChangeState::kChangeStateSpeech2Sil == state_change) {
+            continous_silence_frame_count = 0;
+            if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
+                // do nothing
+            } else if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
+                if (cur_frm_idx - confirmed_start_frame + 1 >
+                    vad_opts.max_single_segment_time / frm_shift_in_ms) {
+                    OnVoiceEnd(cur_frm_idx, false, false);
+                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
+                } else if (!is_final_frame) {
+                    OnVoiceDetected(cur_frm_idx);
+                } else {
+                    MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
+                }
+            }
+        } else if (AudioChangeState::kChangeStateSpeech2Speech == state_change) {
+            continous_silence_frame_count = 0;
+            if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
+                if (cur_frm_idx - confirmed_start_frame + 1 >
+                    vad_opts.max_single_segment_time / frm_shift_in_ms) {
+                    max_time_out = true;  // segment hit the hard duration cap
+                    OnVoiceEnd(cur_frm_idx, false, false);
+                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
+                } else if (!is_final_frame) {
+                    OnVoiceDetected(cur_frm_idx);
+                } else {
+                    MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
+                }
+            }
+        } else if (AudioChangeState::kChangeStateSil2Sil == state_change) {
+            continous_silence_frame_count += 1;
+            if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
+                // Give up waiting for speech: leading silence exceeded the limit
+                // (single-utterance mode), or the input ended with no speech found.
+                if ((vad_opts.detect_mode == static_cast<int>(VadDetectMode::kVadSingleUtteranceDetectMode) &&
+                     (continous_silence_frame_count * frm_shift_in_ms > vad_opts.max_start_silence_time)) ||
+                    (is_final_frame && number_end_time_detected == 0)) {
+                    for (int t = lastest_confirmed_silence_frame + 1; t < cur_frm_idx; t++) {
+                        OnSilenceDetected(t);
+                    }
+                    OnVoiceStart(0, true);   // fake (empty) segment so downstream sees a result
+                    OnVoiceEnd(0, true, false);
+                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
+                } else {
+                    if (cur_frm_idx >= LatencyFrmNumAtStartPoint()) {
+                        OnSilenceDetected(cur_frm_idx - LatencyFrmNumAtStartPoint());
+                    }
+                }
+            } else if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
+                // Trailing silence long enough: close the segment, backing up so
+                // the silence itself is not included.
+                if (continous_silence_frame_count * frm_shift_in_ms >= max_end_sil_frame_cnt_thresh) {
+                    int lookback_frame = max_end_sil_frame_cnt_thresh / frm_shift_in_ms;
+                    if (vad_opts.do_extend) {
+                        lookback_frame -= vad_opts.lookahead_time_end_point / frm_shift_in_ms;
+                        lookback_frame -= 1;
+                        lookback_frame = std::max(0, lookback_frame);
+                    }
+                    OnVoiceEnd(cur_frm_idx - lookback_frame, false, false);
+                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
+                } else if (cur_frm_idx - confirmed_start_frame + 1 >
+                           vad_opts.max_single_segment_time / frm_shift_in_ms) {
+                    OnVoiceEnd(cur_frm_idx, false, false);
+                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
+                } else if (vad_opts.do_extend && !is_final_frame) {
+                    // Within the end-point lookahead window, keep emitting audio.
+                    if (continous_silence_frame_count <= vad_opts.lookahead_time_end_point / frm_shift_in_ms) {
+                        OnVoiceDetected(cur_frm_idx);
+                    }
+                } else {
+                    MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
+                }
+            }
+        }
+        // In multiple-utterance mode, immediately re-arm the detector for the
+        // next segment after each end point.
+        if (vad_state_machine == VadStateMachine::kVadInStateEndPointDetected &&
+            vad_opts.detect_mode == static_cast<int>(VadDetectMode::kVadMutipleUtteranceDetectMode)) {
+            ResetDetection();
+        }
+    }
+
+};
+
+
+
diff --git a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
index a4780b2..f15e86f 100644
--- a/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
+++ b/funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -4,14 +4,14 @@
extern "C" {
#endif
- // APIs for qmasr
- _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThreadNum, bool quantize)
+ // APIs for funasr
+ _FUNASRAPI FUNASR_HANDLE FunASRInit(const char* szModelDir, int nThreadNum, bool quantize, bool use_vad)
{
- Model* mm = create_model(szModelDir, nThreadNum, quantize);
+ Model* mm = create_model(szModelDir, nThreadNum, quantize, use_vad);
return mm;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
@@ -21,7 +21,9 @@
Audio audio(1);
if (!audio.loadwav(szBuf, nLen, &sampling_rate))
return nullptr;
- //audio.split();
+ if(use_vad){
+ audio.split(pRecogObj);
+ }
float* buff;
int len;
@@ -31,7 +33,6 @@
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
- //pRecogObj->reset();
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
@@ -42,7 +43,7 @@
return pResult;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
@@ -51,7 +52,9 @@
Audio audio(1);
if (!audio.loadpcmwav(szBuf, nLen, &sampling_rate))
return nullptr;
- //audio.split();
+ if(use_vad){
+ audio.split(pRecogObj);
+ }
float* buff;
int len;
@@ -61,7 +64,6 @@
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
- //pRecogObj->reset();
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
@@ -72,7 +74,7 @@
return pResult;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
@@ -81,7 +83,9 @@
Audio audio(1);
if (!audio.loadpcmwav(szFileName, &sampling_rate))
return nullptr;
- //audio.split();
+ if(use_vad){
+ audio.split(pRecogObj);
+ }
float* buff;
int len;
@@ -91,7 +95,6 @@
int nStep = 0;
int nTotal = audio.get_queue_size();
while (audio.fetch(buff, len, flag) > 0) {
- //pRecogObj->reset();
string msg = pRecogObj->forward(buff, len, flag);
pResult->msg += msg;
nStep++;
@@ -102,7 +105,7 @@
return pResult;
}
- _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
+ _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
{
Model* pRecogObj = (Model*)handle;
if (!pRecogObj)
@@ -112,7 +115,9 @@
Audio audio(1);
if(!audio.loadwav(szWavfile, &sampling_rate))
return nullptr;
- //audio.split();
+ if(use_vad){
+ audio.split(pRecogObj);
+ }
float* buff;
int len;
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
index a0dd6d4..16cf57a 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -3,11 +3,19 @@
using namespace std;
using namespace paraformer;
-ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
+ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad)
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
string model_path;
string cmvn_path;
string config_path;
+
+ // VAD model
+ if(use_vad){
+ string vad_path = pathAppend(path, "vad_model.onnx");
+ string mvn_path = pathAppend(path, "vad.mvn");
+ vadHandle = make_unique<FsmnVad>();
+ vadHandle->init_vad(vad_path, mvn_path, model_sample_rate, 800, 15000, 0.9);
+ }
if(quantize)
{
@@ -30,8 +38,10 @@
//fbank_ = std::make_unique<knf::OnlineFbank>(fbank_opts);
//sessionOptions.SetInterOpNumThreads(1);
+ //sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
sessionOptions.SetIntraOpNumThreads(nNumThread);
- sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
+ sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
+ sessionOptions.DisableCpuMemArena();
#ifdef _WIN32
wstring wstrPath = strToWstr(model_path);
@@ -67,6 +77,10 @@
void ModelImp::reset()
{
+}
+
+vector<std::vector<int>> ModelImp::vad_seg(std::vector<float>& pcm_data){
+ return vadHandle->infer(pcm_data);
}
vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) {
@@ -172,66 +186,6 @@
p += dim;
}
}
-
-// void ParaformerOnnxAsrModel::ForwardFunc(
-// const std::vector<std::vector<float>>& chunk_feats,
-// std::vector<std::vector<float>>* out_prob) {
-// Ort::MemoryInfo memory_info =
-// Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
-// // 1. Prepare onnx required data, splice cached_feature_ and chunk_feats
-// // chunk
-// // int num_frames = cached_feature_.size() + chunk_feats.size();
-// int num_frames = chunk_feats.size();
-// const int feature_dim = chunk_feats[0].size();
-
-// // 2. Generate 2 input nodes tensor
-// // speech node { batch,frame number,feature dim }
-// const int64_t paraformer_feats_shape[3] = {1, num_frames, feature_dim};
-// std::vector<float> paraformer_feats;
-// for (const auto & chunk_feat : chunk_feats) {
-// paraformer_feats.insert(paraformer_feats.end(), chunk_feat.begin(), chunk_feat.end());
-// }
-// Ort::Value paraformer_feats_ort = Ort::Value::CreateTensor<float>(
-// memory_info, paraformer_feats.data(), paraformer_feats.size(), paraformer_feats_shape, 3);
-// // speech_lengths node {speech length,}
-// const int64_t paraformer_length_shape[1] = {1};
-// std::vector<int32_t> paraformer_length;
-// paraformer_length.emplace_back(num_frames);
-// Ort::Value paraformer_length_ort = Ort::Value::CreateTensor<int32_t>(
-// memory_info, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1);
-
-// // 3. Put nodes into onnx input vector
-// std::vector<Ort::Value> paraformer_inputs;
-// paraformer_inputs.emplace_back(std::move(paraformer_feats_ort));
-// paraformer_inputs.emplace_back(std::move(paraformer_length_ort));
-
-// // 4. Onnx infer
-// std::vector<Ort::Value> paraformer_ort_outputs;
-// try{
-// VLOG(3) << "Start infer";
-// paraformer_ort_outputs = paraformer_session_->Run(
-// Ort::RunOptions{nullptr}, paraformer_in_names_.data(), paraformer_inputs.data(),
-// paraformer_inputs.size(), paraformer_out_names_.data(), paraformer_out_names_.size());
-// }catch (std::exception const& e) {
-// // Catch "Non-zero status code returned error",usually because there is no asr result.
-// // Need funasr to resolve.
-// LOG(ERROR) << e.what();
-// return;
-// }
-
-// // 5. Change infer result to output shapes
-// float* logp_data = paraformer_ort_outputs[0].GetTensorMutableData<float>();
-// auto type_info = paraformer_ort_outputs[0].GetTensorTypeAndShapeInfo();
-
-// int num_outputs = type_info.GetShape()[1];
-// int output_dim = type_info.GetShape()[2];
-// out_prob->resize(num_outputs);
-// for (int i = 0; i < num_outputs; i++) {
-// (*out_prob)[i].resize(output_dim);
-// memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
-// sizeof(float) * output_dim);
-// }
-// }
string ModelImp::forward(float* din, int len, int flag)
{
diff --git a/funasr/runtime/onnxruntime/src/paraformer_onnx.h b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
index 6442af3..b0712b4 100644
--- a/funasr/runtime/onnxruntime/src/paraformer_onnx.h
+++ b/funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -4,8 +4,7 @@
#ifndef PARAFORMER_MODELIMP_H
#define PARAFORMER_MODELIMP_H
-#include "kaldi-native-fbank/csrc/feature-fbank.h"
-#include "kaldi-native-fbank/csrc/online-feature.h"
+#include "precomp.h"
namespace paraformer {
@@ -13,6 +12,8 @@
private:
//std::unique_ptr<knf::OnlineFbank> fbank_;
knf::FbankOptions fbank_opts;
+
+ std::unique_ptr<FsmnVad> vadHandle;
Vocab* vocab;
vector<float> means_list;
@@ -27,7 +28,7 @@
string greedy_search( float* in, int nLen);
- std::unique_ptr<Ort::Session> m_session;
+ std::shared_ptr<Ort::Session> m_session;
Ort::Env env_;
Ort::SessionOptions sessionOptions;
@@ -36,13 +37,14 @@
vector<const char*> m_szOutputNames;
public:
- ModelImp(const char* path, int nNumThread=0, bool quantize=false);
+ ModelImp(const char* path, int nNumThread=0, bool quantize=false, bool use_vad=false);
~ModelImp();
void reset();
vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
string forward_chunk(float* din, int len, int flag);
string forward(float* din, int len, int flag);
string rescoring();
+ std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data);
};
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index 8155b67..c83efb8 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -26,6 +26,8 @@
#include <fftw3.h>
#include "onnxruntime_run_options_config_keys.h"
#include "onnxruntime_cxx_api.h"
+#include "kaldi-native-fbank/csrc/feature-fbank.h"
+#include "kaldi-native-fbank/csrc/online-feature.h"
// mine
@@ -33,6 +35,7 @@
#include "commonfunc.h"
#include <ComDefine.h>
#include "predefine_coe.h"
+#include "FsmnVad.h"
#include <ComDefine.h>
//#include "alignedmem.h"
diff --git a/funasr/runtime/onnxruntime/tester/tester.cpp b/funasr/runtime/onnxruntime/tester/tester.cpp
index f4a19de..4cb38df 100644
--- a/funasr/runtime/onnxruntime/tester/tester.cpp
+++ b/funasr/runtime/onnxruntime/tester/tester.cpp
@@ -14,9 +14,9 @@
int main(int argc, char *argv[])
{
- if (argc < 4)
+ if (argc < 5)
{
- printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) \n", argv[0]);
+ printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) use_vad(true or false) \n", argv[0]);
exit(-1);
}
struct timeval start, end;
@@ -24,8 +24,10 @@
int nThreadNum = 1;
// is quantize
bool quantize = false;
+ bool use_vad = false;
istringstream(argv[3]) >> boolalpha >> quantize;
- FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);
+ istringstream(argv[4]) >> boolalpha >> use_vad;
+ FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize, use_vad);
if (!AsrHanlde)
{
@@ -41,7 +43,7 @@
gettimeofday(&start, NULL);
float snippet_time = 0.0f;
- FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
+ FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL, use_vad);
gettimeofday(&end, NULL);
diff --git a/funasr/runtime/onnxruntime/tester/tester_rtf.cpp b/funasr/runtime/onnxruntime/tester/tester_rtf.cpp
index 067d5c0..7f2368e 100644
--- a/funasr/runtime/onnxruntime/tester/tester_rtf.cpp
+++ b/funasr/runtime/onnxruntime/tester/tester_rtf.cpp
@@ -19,7 +19,7 @@
std::atomic<int> index(0);
std::mutex mtx;
-void runReg(FUNASR_HANDLE AsrHanlde, vector<string> wav_list,
+void runReg(FUNASR_HANDLE AsrHandle, vector<string> wav_list,
float* total_length, long* total_time, int core_id) {
// cpu_set_t cpuset;
@@ -37,7 +37,7 @@
// warm up
for (size_t i = 0; i < 1; i++)
{
- FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
+ FUNASR_RESULT Result=FunASRRecogFile(AsrHandle, wav_list[0].c_str(), RASR_NONE, NULL);
}
while (true) {
@@ -48,7 +48,7 @@
}
gettimeofday(&start, NULL);
- FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
+ FUNASR_RESULT Result=FunASRRecogFile(AsrHandle, wav_list[i].c_str(), RASR_NONE, NULL);
gettimeofday(&end, NULL);
seconds = (end.tv_sec - start.tv_sec);
@@ -112,8 +112,8 @@
int nThreadNum = 1;
nThreadNum = atoi(argv[4]);
- FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], 1, quantize);
- if (!AsrHanlde)
+ FUNASR_HANDLE AsrHandle=FunASRInit(argv[1], 1, quantize);
+ if (!AsrHandle)
{
printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
exit(-1);
@@ -130,7 +130,7 @@
for (int i = 0; i < nThreadNum; i++)
{
- threads.emplace_back(thread(runReg, AsrHanlde, wav_list, &total_length, &total_time, i));
+ threads.emplace_back(thread(runReg, AsrHandle, wav_list, &total_length, &total_time, i));
}
for (auto& thread : threads)
@@ -142,6 +142,6 @@
printf("total_time_comput %ld ms.\n", total_time / 1000);
printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
- FunASRUninit(AsrHanlde);
+ FunASRUninit(AsrHandle);
return 0;
}
--
Gitblit v1.9.1