lyblsgo
2023-04-20 ae3e2567602546e66c0f358463617e560fc70e20
add offline vad for onnxruntime
11个文件已修改
3个文件已添加
1363 ■■■■ 已修改文件
funasr/runtime/onnxruntime/include/Audio.h 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/Model.h 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/libfunasrapi.h 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/Audio.cpp 84 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/FsmnVad.cc 268 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/FsmnVad.h 57 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/Model.cpp 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/e2e_vad.h 782 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/libfunasrapi.cpp 33 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp 78 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer_onnx.h 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/precomp.h 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/tester/tester.cpp 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/tester/tester_rtf.cpp 14 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/Audio.h
@@ -5,6 +5,7 @@
#include <ComDefine.h>
#include <queue>
#include <stdint.h>
#include "Model.h"
#ifndef model_sample_rate
#define model_sample_rate 16000
@@ -27,7 +28,7 @@
    ~AudioFrame();
    int set_start(int val);
    int set_end(int val, int max_len);
    int set_end(int val);
    int get_start();
    int get_len();
    int disp();
@@ -57,7 +58,7 @@
    int fetch_chunck(float *&dout, int len);
    int fetch(float *&dout, int &len, int &flag);
    void padding();
    void split();
    void split(Model* pRecogObj);
    float get_time_len();
    int get_queue_size() { return (int)frame_queue.size(); }
funasr/runtime/onnxruntime/include/Model.h
@@ -11,7 +11,8 @@
    virtual std::string forward_chunk(float *din, int len, int flag) = 0;
    virtual std::string forward(float *din, int len, int flag) = 0;
    virtual std::string rescoring() = 0;
    virtual std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data)=0;
};
Model *create_model(const char *path,int nThread=0,bool quantize=false);
Model *create_model(const char *path,int nThread=0,bool quantize=false, bool use_vad=false);
#endif
funasr/runtime/onnxruntime/include/libfunasrapi.h
@@ -48,18 +48,18 @@
typedef void (* QM_CALLBACK)(int nCurStep, int nTotal); // nTotal: total steps; nCurStep: Current Step.
    
// APIs for qmasr
_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThread, bool quantize);
// APIs for funasr
_FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThread, bool quantize=false, bool use_vad=false);
// if not give a fnCallback ,it should be NULL 
_FUNASRAPI FUNASR_RESULT    FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
_FUNASRAPI FUNASR_RESULT    FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
_FUNASRAPI FUNASR_RESULT    FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
_FUNASRAPI FUNASR_RESULT    FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
_FUNASRAPI FUNASR_RESULT    FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
_FUNASRAPI FUNASR_RESULT    FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
_FUNASRAPI FUNASR_RESULT    FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback);
_FUNASRAPI FUNASR_RESULT    FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad=false);
_FUNASRAPI const char*    FunASRGetResult(FUNASR_RESULT Result,int nIndex);
funasr/runtime/onnxruntime/src/Audio.cpp
@@ -134,19 +134,10 @@
    return start;
};
int AudioFrame::set_end(int val, int max_len)
int AudioFrame::set_end(int val)
{
    float num_samples = val - start;
    float frame_length = 400;
    float frame_shift = 160;
    float num_new_samples =
        ceil((num_samples - frame_length) / frame_shift) * frame_shift + frame_length;
    end = start + num_new_samples;
    len = (int)num_new_samples;
    if (end > max_len)
        printf("frame end > max_len!!!!!!!\n");
    end = val;
    len = end - start;
    return end;
};
@@ -473,7 +464,6 @@
void Audio::padding()
{
    float num_samples = speech_len;
    float frame_length = 400;
    float frame_shift = 160;
@@ -509,71 +499,27 @@
    delete frame;
}
#define UNTRIGGERED 0
#define TRIGGERED   1
#define SPEECH_LEN_5S  (16000 * 5)
#define SPEECH_LEN_10S (16000 * 10)
#define SPEECH_LEN_20S (16000 * 20)
#define SPEECH_LEN_30S (16000 * 30)
/*
void Audio::split()
void Audio::split(Model* pRecogObj)
{
    VadInst *handle = WebRtcVad_Create();
    WebRtcVad_Init(handle);
    WebRtcVad_set_mode(handle, 2);
    int window_size = 10;
    AudioWindow audiowindow(window_size);
    int status = UNTRIGGERED;
    int offset = 0;
    int fs = 16000;
    int step = 480;
    AudioFrame *frame;
    frame = frame_queue.front();
    frame_queue.pop();
    int sp_len = frame->get_len();
    delete frame;
    frame = NULL;
    while (offset < speech_len - step) {
        int n = WebRtcVad_Process(handle, fs, speech_buff + offset, step);
        if (status == UNTRIGGERED && audiowindow.put(n) >= window_size - 1) {
            frame = new AudioFrame();
            int start = offset - step * (window_size - 1);
            frame->set_start(start);
            status = TRIGGERED;
        } else if (status == TRIGGERED) {
            int win_weight = audiowindow.put(n);
            int voice_len = (offset - frame->get_start());
            int gap = 0;
            if (voice_len < SPEECH_LEN_5S) {
                offset += step;
                continue;
            } else if (voice_len < SPEECH_LEN_10S) {
                gap = 1;
            } else if (voice_len < SPEECH_LEN_20S) {
                gap = window_size / 5;
            } else {
                gap = window_size / 2;
            }
            if (win_weight < gap) {
                status = UNTRIGGERED;
                offset = frame->set_end(offset, speech_align_len);
                frame_queue.push(frame);
                frame = NULL;
            }
        }
        offset += step;
    }
    if (frame != NULL) {
        frame->set_end(speech_len, speech_align_len);
    std::vector<float> pcm_data(speech_data, speech_data+sp_len);
    vector<std::vector<int>> vad_segments = pRecogObj->vad_seg(pcm_data);
    int seg_sample = model_sample_rate/1000;
    for(vector<int> segment:vad_segments)
    {
        frame = new AudioFrame();
        int start = segment[0]*seg_sample;
        int end = segment[1]*seg_sample;
        frame->set_start(start);
        frame->set_end(end);
        frame_queue.push(frame);
        frame = NULL;
    }
    WebRtcVad_Free(handle);
}
*/
funasr/runtime/onnxruntime/src/FsmnVad.cc
New file
@@ -0,0 +1,268 @@
//
// Created by root on 4/9/23.
//
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include "FsmnVad.h"
#include "precomp.h"
//#include "glog/logging.h"
void FsmnVad::init_vad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
                       float vad_speech_noise_thres) {
    session_options_.SetIntraOpNumThreads(1);
    session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    session_options_.DisableCpuMemArena();
    this->vad_sample_rate_ = vad_sample_rate;
    this->vad_silence_duration_=vad_silence_duration;
    this->vad_max_len_=vad_max_len;
    this->vad_speech_noise_thres_=vad_speech_noise_thres;
    read_model(vad_model);
    load_cmvn(vad_cmvn.c_str());
    fbank_opts.frame_opts.dither = 0;
    fbank_opts.mel_opts.num_bins = 80;
    fbank_opts.frame_opts.samp_freq = vad_sample_rate;
    fbank_opts.frame_opts.window_type = "hamming";
    fbank_opts.frame_opts.frame_shift_ms = 10;
    fbank_opts.frame_opts.frame_length_ms = 25;
    fbank_opts.energy_floor = 0;
    fbank_opts.mel_opts.debug_mel = false;
}
void FsmnVad::read_model(const std::string &vad_model) {
    try {
        vad_session_ = std::make_shared<Ort::Session>(
                env_, vad_model.c_str(), session_options_);
    } catch (std::exception const &e) {
        //LOG(ERROR) << "Error when load onnx model: " << e.what();
        exit(0);
    }
    //LOG(INFO) << "vad onnx:";
    GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_);
}
// Query the session for its input/output node names (in node order) and
// store raw C-string pointers into *in_names / *out_names for Run().
// NOTE(review): name.release() deliberately abandons ownership so the
// char* stored in the vector stays valid; the allocation is never freed.
// This is a one-time-per-session leak — acceptable while sessions live for
// the whole process, but worth knowing about.
void FsmnVad::GetInputOutputInfo(
        const std::shared_ptr<Ort::Session> &session,
        std::vector<const char *> *in_names, std::vector<const char *> *out_names) {
    Ort::AllocatorWithDefaultOptions allocator;
    // Input info
    int num_nodes = session->GetInputCount();
    in_names->resize(num_nodes);
    for (int i = 0; i < num_nodes; ++i) {
        std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetInputNameAllocated(i, allocator);
        Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
        ONNXTensorElementDataType type = tensor_info.GetElementType();
        std::vector<int64_t> node_dims = tensor_info.GetShape();
        // Shape string is only consumed by the (currently disabled) logging.
        std::stringstream shape;
        for (auto j: node_dims) {
            shape << j;
            shape << " ";
        }
        // LOG(INFO) << "\tInput " << i << " : name=" << name.get() << " type=" << type
        //           << " dims=" << shape.str();
        (*in_names)[i] = name.get();
        name.release();  // intentionally leak so the pointer stays valid
    }
    // Output info
    num_nodes = session->GetOutputCount();
    out_names->resize(num_nodes);
    for (int i = 0; i < num_nodes; ++i) {
        std::unique_ptr<char, Ort::detail::AllocatedFree> name = session->GetOutputNameAllocated(i, allocator);
        Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
        auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
        ONNXTensorElementDataType type = tensor_info.GetElementType();
        std::vector<int64_t> node_dims = tensor_info.GetShape();
        std::stringstream shape;
        for (auto j: node_dims) {
            shape << j;
            shape << " ";
        }
        // LOG(INFO) << "\tOutput " << i << " : name=" << name.get() << " type=" << type
        //           << " dims=" << shape.str();
        (*out_names)[i] = name.get();
        name.release();  // intentionally leak so the pointer stays valid
    }
}
void FsmnVad::Forward(
        const std::vector<std::vector<float>> &chunk_feats,
        std::vector<std::vector<float>> *out_prob) {
    Ort::MemoryInfo memory_info =
            Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
    int num_frames = chunk_feats.size();
    const int feature_dim = chunk_feats[0].size();
    //  2. Generate input nodes tensor
    // vad node { batch,frame number,feature dim }
    const int64_t vad_feats_shape[3] = {1, num_frames, feature_dim};
    std::vector<float> vad_feats;
    for (const auto &chunk_feat: chunk_feats) {
        vad_feats.insert(vad_feats.end(), chunk_feat.begin(), chunk_feat.end());
    }
    Ort::Value vad_feats_ort = Ort::Value::CreateTensor<float>(
            memory_info, vad_feats.data(), vad_feats.size(), vad_feats_shape, 3);
    // cache node {batch,128,19,1}
    const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
    std::vector<float> cache_feats(128 * 19 * 1, 0);
    Ort::Value cache_feats_ort = Ort::Value::CreateTensor<float>(
            memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4);
    // 3. Put nodes into onnx input vector
    std::vector<Ort::Value> vad_inputs;
    vad_inputs.emplace_back(std::move(vad_feats_ort));
    // 4 caches
    for (int i = 0; i < 4; i++) {
        vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
                memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4)));
    }
    // 4. Onnx infer
    std::vector<Ort::Value> vad_ort_outputs;
    try {
        // VLOG(3) << "Start infer";
        vad_ort_outputs = vad_session_->Run(
                Ort::RunOptions{nullptr}, vad_in_names_.data(), vad_inputs.data(),
                vad_inputs.size(), vad_out_names_.data(), vad_out_names_.size());
    } catch (std::exception const &e) {
        // LOG(ERROR) << e.what();
        return;
    }
    // 5. Change infer result to output shapes
    float *logp_data = vad_ort_outputs[0].GetTensorMutableData<float>();
    auto type_info = vad_ort_outputs[0].GetTensorTypeAndShapeInfo();
    int num_outputs = type_info.GetShape()[1];
    int output_dim = type_info.GetShape()[2];
    out_prob->resize(num_outputs);
    for (int i = 0; i < num_outputs; i++) {
        (*out_prob)[i].resize(output_dim);
        memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
               sizeof(float) * output_dim);
    }
}
void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
                         const std::vector<float> &waves) {
    knf::OnlineFbank fbank(fbank_opts);
    fbank.AcceptWaveform(sample_rate, &waves[0], waves.size());
    int32_t frames = fbank.NumFramesReady();
    for (int32_t i = 0; i != frames; ++i) {
        const float *frame = fbank.GetFrame(i);
        std::vector<float> frame_vector(frame, frame + fbank_opts.mel_opts.num_bins);
        vad_feats.emplace_back(frame_vector);
    }
}
void FsmnVad::load_cmvn(const char *filename)
{
    using namespace std;
    ifstream cmvn_stream(filename);
    string line;
    while (getline(cmvn_stream, line)) {
        istringstream iss(line);
        vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
        if (line_item[0] == "<AddShift>") {
            getline(cmvn_stream, line);
            istringstream means_lines_stream(line);
            vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
            if (means_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < means_lines.size() - 1; j++) {
                    means_list.push_back(stof(means_lines[j]));
                }
                continue;
            }
        }
        else if (line_item[0] == "<Rescale>") {
            getline(cmvn_stream, line);
            istringstream vars_lines_stream(line);
            vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
            if (vars_lines[0] == "<LearnRateCoef>") {
                for (int j = 3; j < vars_lines.size() - 1; j++) {
                    // vars_list.push_back(stof(vars_lines[j])*scale);
                    vars_list.push_back(stof(vars_lines[j]));
                }
                continue;
            }
        }
    }
}
std::vector<std::vector<float>> &FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats, int lfr_m, int lfr_n) {
    std::vector<std::vector<float>> out_feats;
    int T = vad_feats.size();
    int T_lrf = ceil(1.0 * T / lfr_n);
    // Pad frames at start(copy first frame)
    for (int i = 0; i < (lfr_m - 1) / 2; i++) {
        vad_feats.insert(vad_feats.begin(), vad_feats[0]);
    }
    // Merge lfr_m frames as one,lfr_n frames per window
    T = T + (lfr_m - 1) / 2;
    std::vector<float> p;
    for (int i = 0; i < T_lrf; i++) {
        if (lfr_m <= T - i * lfr_n) {
            for (int j = 0; j < lfr_m; j++) {
                p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
            }
            out_feats.emplace_back(p);
            p.clear();
        } else {
            // Fill to lfr_m frames at last window if less than lfr_m frames  (copy last frame)
            int num_padding = lfr_m - (T - i * lfr_n);
            for (int j = 0; j < (vad_feats.size() - i * lfr_n); j++) {
                p.insert(p.end(), vad_feats[i * lfr_n + j].begin(), vad_feats[i * lfr_n + j].end());
            }
            for (int j = 0; j < num_padding; j++) {
                p.insert(p.end(), vad_feats[vad_feats.size() - 1].begin(), vad_feats[vad_feats.size() - 1].end());
            }
            out_feats.emplace_back(p);
        }
    }
    // Apply cmvn
    for (auto &out_feat: out_feats) {
        for (int j = 0; j < means_list.size(); j++) {
            out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j];
        }
    }
    vad_feats = out_feats;
    return vad_feats;
}
// Offline VAD entry point: mono waveform in, list of {start_ms, end_ms}
// speech segments out. Pipeline: fbank -> LFR+CMVN -> FSMN forward ->
// E2E frame scoring/segmentation.
std::vector<std::vector<int>>
FsmnVad::infer(const std::vector<float> &waves) {
    std::vector<std::vector<float>> vad_feats;
    std::vector<std::vector<float>> vad_probs;
    FbankKaldi(vad_sample_rate_, vad_feats, waves);
    // lfr_m=5, lfr_n=1: stack 5 frames per output, advancing 1 frame.
    vad_feats = LfrCmvn(vad_feats, 5, 1);
    Forward(vad_feats, &vad_probs);
    // Fresh scorer per call; is_final=true runs single-shot offline decoding.
    E2EVadModel vad_scorer = E2EVadModel();
    std::vector<std::vector<int>> vad_segments;
    vad_segments = vad_scorer(vad_probs, waves, true, vad_silence_duration_, vad_max_len_,
                              vad_speech_noise_thres_, vad_sample_rate_);
    return vad_segments;
}
// Placeholder for ad-hoc debugging; intentionally empty.
void FsmnVad::test() {
}

// Construct the ORT environment with error-level logging only; the session
// itself is created later by init_vad()/read_model().
FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} {
}
funasr/runtime/onnxruntime/src/FsmnVad.h
New file
@@ -0,0 +1,57 @@
//
// Created by zyf on 4/9/23.
//
#ifndef VAD_SERVER_FSMNVAD_H
#define VAD_SERVER_FSMNVAD_H
#include "e2e_vad.h"
#include "onnxruntime_cxx_api.h"
#include "kaldi-native-fbank/csrc/feature-fbank.h"
#include "kaldi-native-fbank/csrc/online-feature.h"
// Offline FSMN voice-activity detector backed by onnxruntime.
// Usage: construct, call init_vad() once, then infer() per utterance to get
// {start_ms, end_ms} speech segments.
class FsmnVad {
public:
    FsmnVad();
    void test();
    // Load model + CMVN stats and configure feature extraction/thresholds.
    void init_vad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
                  float vad_speech_noise_thres);
    // Run VAD on a mono waveform; returns segments in milliseconds.
    std::vector<std::vector<int>> infer(const std::vector<float> &waves);
private:
    void read_model(const std::string &vad_model);
    static void GetInputOutputInfo(
            const std::shared_ptr<Ort::Session> &session,
            std::vector<const char *> *in_names, std::vector<const char *> *out_names);
    void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
                    const std::vector<float> &waves);
    std::vector<std::vector<float>> &LfrCmvn(std::vector<std::vector<float>> &vad_feats, int lfr_m, int lfr_n);
    void Forward(
            const std::vector<std::vector<float>> &chunk_feats,
            std::vector<std::vector<float>> *out_prob);
    void load_cmvn(const char *filename);
    std::shared_ptr<Ort::Session> vad_session_ = nullptr;
    Ort::Env env_;
    Ort::SessionOptions session_options_;
    // Node-name pointers; the strings are intentionally leaked once per
    // session by GetInputOutputInfo.
    std::vector<const char *> vad_in_names_;
    std::vector<const char *> vad_out_names_;
    knf::FbankOptions fbank_opts;
    std::vector<float> means_list;  // CMVN additive means
    std::vector<float> vars_list;   // CMVN multiplicative scales
    int vad_sample_rate_ = 16000;
    int vad_silence_duration_ = 800;  // passed to scorer as max_end_sil (ms)
    int vad_max_len_ = 15000;         // passed as max_single_segment_time (ms)
    // NOTE(review): declared double while init_vad() takes float — harmless
    // widening, but the types could be unified.
    double vad_speech_noise_thres_ = 0.9;
};
#endif //VAD_SERVER_FSMNVAD_H
funasr/runtime/onnxruntime/src/Model.cpp
@@ -1,10 +1,10 @@
#include "precomp.h"
Model *create_model(const char *path, int nThread, bool quantize)
Model *create_model(const char *path, int nThread, bool quantize, bool use_vad)
{
    Model *mm;
    mm = new paraformer::ModelImp(path, nThread, quantize);
    mm = new paraformer::ModelImp(path, nThread, quantize, use_vad);
    return mm;
}
funasr/runtime/onnxruntime/src/e2e_vad.h
New file
@@ -0,0 +1,782 @@
//
// Created by root on 3/31/23.
//
#include <utility>
#include <vector>
#include <string>
#include <map>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <numeric>
#include <cassert>
// State of the utterance-level detection state machine.
enum class VadStateMachine {
    kVadInStateStartPointNotDetected = 1,
    kVadInStateInSpeechSegment = 2,
    kVadInStateEndPointDetected = 3
};
// Per-frame speech/silence classification.
enum class FrameState {
    kFrameStateInvalid = -1,
    kFrameStateSpeech = 1,
    kFrameStateSil = 0
};
// Transition of the smoothed voice/unvoice state between frames.
enum class AudioChangeState {
    kChangeStateSpeech2Speech = 0,
    kChangeStateSpeech2Sil = 1,
    kChangeStateSil2Sil = 2,
    kChangeStateSil2Speech = 3,
    kChangeStateNoBegin = 4,
    kChangeStateInvalid = 5
};
// Single-utterance vs. continuous multi-utterance detection.
enum class VadDetectMode {
    kVadSingleUtteranceDetectMode = 0,
    kVadMutipleUtteranceDetectMode = 1
};
// Tunable parameters of the E2E FSMN VAD decision logic. Time-valued
// members are in milliseconds unless noted. The constructor's abbreviated
// parameters map positionally onto the members in declaration order.
class VADXOptions {
public:
    int sample_rate;                  // audio sample rate in Hz
    int detect_mode;                  // a VadDetectMode value
    int snr_mode;
    int max_end_silence_time;         // trailing silence (ms) that closes a segment
    int max_start_silence_time;
    bool do_start_point_detection;
    bool do_end_point_detection;
    int window_size_ms;               // smoothing window for WindowDetector
    int sil_to_speech_time_thres;     // ms of speech to trigger sil->speech
    int speech_to_sil_time_thres;     // ms of silence to trigger speech->sil
    float speech_2_noise_ratio;
    int do_extend;
    int lookback_time_start_point;
    int lookahead_time_end_point;
    int max_single_segment_time;      // hard cap on one segment's length (ms)
    int nn_eval_block_size;
    int dcd_block_size;
    float snr_thres;
    int noise_frame_num_used_for_snr;
    float decibel_thres;
    float speech_noise_thres;         // speech/noise score threshold
    float fe_prior_thres;
    int silence_pdf_num;
    std::vector<int> sil_pdf_ids;     // model output indices treated as silence
    float speech_noise_thresh_low;
    float speech_noise_thresh_high;
    bool output_frame_probs;
    int frame_in_ms;                  // frame shift (ms) — see ComputeDecibel
    int frame_length_ms;              // frame length (ms)
    // Parameters are positional abbreviations of the members above
    // (sr=sample_rate, dm=detect_mode, ... flm=frame_length_ms).
    explicit VADXOptions(
            int sr = 16000,
            int dm = static_cast<int>(VadDetectMode::kVadMutipleUtteranceDetectMode),
            int sm = 0,
            int mset = 800,
            int msst = 3000,
            bool dspd = true,
            bool depd = true,
            int wsm = 200,
            int ststh = 150,
            int sttsh = 150,
            float s2nr = 1.0,
            int de = 1,
            int lbtps = 200,
            int latsp = 100,
            int mss = 15000,
            int nebs = 8,
            int dbs = 4,
            float st = -100.0,
            int nfnus = 100,
            float dt = -100.0,
            float snt = 0.9,
            float fept = 1e-4,
            int spn = 1,
            std::vector<int> spids = {0},
            float sntl = -0.1,
            float snth = 0.3,
            bool ofp = false,
            int fim = 10,
            int flm = 25
    ) :
            sample_rate(sr),
            detect_mode(dm),
            snr_mode(sm),
            max_end_silence_time(mset),
            max_start_silence_time(msst),
            do_start_point_detection(dspd),
            do_end_point_detection(depd),
            window_size_ms(wsm),
            sil_to_speech_time_thres(ststh),
            speech_to_sil_time_thres(sttsh),
            speech_2_noise_ratio(s2nr),
            do_extend(de),
            lookback_time_start_point(lbtps),
            lookahead_time_end_point(latsp),
            max_single_segment_time(mss),
            nn_eval_block_size(nebs),
            dcd_block_size(dbs),
            snr_thres(st),
            noise_frame_num_used_for_snr(nfnus),
            decibel_thres(dt),
            speech_noise_thres(snt),
            fe_prior_thres(fept),
            silence_pdf_num(spn),
            sil_pdf_ids(std::move(spids)),
            speech_noise_thresh_low(sntl),
            speech_noise_thresh_high(snth),
            output_frame_probs(ofp),
            frame_in_ms(fim),
            frame_length_ms(flm) {}
};
// One detected speech segment: its time span in milliseconds, the buffered
// samples, flags saying whether the segment's start/end points fall inside
// this buffer, and a direction-of-arrival slot.
class E2EVadSpeechBufWithDoa {
public:
    int start_ms = 0;
    int end_ms = 0;
    std::vector<float> buffer;
    bool contain_seg_start_point = false;
    bool contain_seg_end_point = false;
    int doa = 0;

    E2EVadSpeechBufWithDoa() = default;

    // Restore the freshly-constructed state.
    void Reset() {
        *this = E2EVadSpeechBufWithDoa();
    }
};
// Per-frame VAD diagnostics: noise/speech posteriors, the combined score,
// the frame index, and the decoded frame state.
class E2EVadFrameProb {
public:
    double noise_prob = 0.0;
    double speech_prob = 0.0;
    double score = 0.0;
    int frame_id = 0;
    int frm_state = 0;

    E2EVadFrameProb() = default;
};
class WindowDetector {
public:
    int window_size_ms;
    int sil_to_speech_time;
    int speech_to_sil_time;
    int frame_size_ms;
    int win_size_frame;
    int win_sum;
    std::vector<int> win_state;
    int cur_win_pos;
    FrameState pre_frame_state;
    FrameState cur_frame_state;
    int sil_to_speech_frmcnt_thres;
    int speech_to_sil_frmcnt_thres;
    int voice_last_frame_count;
    int noise_last_frame_count;
    int hydre_frame_count;
    WindowDetector(int window_size_ms, int sil_to_speech_time, int speech_to_sil_time, int frame_size_ms) :
            window_size_ms(window_size_ms),
            sil_to_speech_time(sil_to_speech_time),
            speech_to_sil_time(speech_to_sil_time),
            frame_size_ms(frame_size_ms),
            win_size_frame(window_size_ms / frame_size_ms),
            win_sum(0),
            win_state(std::vector<int>(win_size_frame, 0)),
            cur_win_pos(0),
            pre_frame_state(FrameState::kFrameStateSil),
            cur_frame_state(FrameState::kFrameStateSil),
            sil_to_speech_frmcnt_thres(sil_to_speech_time / frame_size_ms),
            speech_to_sil_frmcnt_thres(speech_to_sil_time / frame_size_ms),
            voice_last_frame_count(0),
            noise_last_frame_count(0),
            hydre_frame_count(0) {}
    void Reset() {
        cur_win_pos = 0;
        win_sum = 0;
        win_state = std::vector<int>(win_size_frame, 0);
        pre_frame_state = FrameState::kFrameStateSil;
        cur_frame_state = FrameState::kFrameStateSil;
        voice_last_frame_count = 0;
        noise_last_frame_count = 0;
        hydre_frame_count = 0;
    }
    int GetWinSize() {
        return win_size_frame;
    }
    AudioChangeState DetectOneFrame(FrameState frameState, int frame_count) {
        int cur_frame_state = 0;
        if (frameState == FrameState::kFrameStateSpeech) {
            cur_frame_state = 1;
        } else if (frameState == FrameState::kFrameStateSil) {
            cur_frame_state = 0;
        } else {
            return AudioChangeState::kChangeStateInvalid;
        }
        win_sum -= win_state[cur_win_pos];
        win_sum += cur_frame_state;
        win_state[cur_win_pos] = cur_frame_state;
        cur_win_pos = (cur_win_pos + 1) % win_size_frame;
        if (pre_frame_state == FrameState::kFrameStateSil && win_sum >= sil_to_speech_frmcnt_thres) {
            pre_frame_state = FrameState::kFrameStateSpeech;
            return AudioChangeState::kChangeStateSil2Speech;
        }
        if (pre_frame_state == FrameState::kFrameStateSpeech && win_sum <= speech_to_sil_frmcnt_thres) {
            pre_frame_state = FrameState::kFrameStateSil;
            return AudioChangeState::kChangeStateSpeech2Sil;
        }
        if (pre_frame_state == FrameState::kFrameStateSil) {
            return AudioChangeState::kChangeStateSil2Sil;
        }
        if (pre_frame_state == FrameState::kFrameStateSpeech) {
            return AudioChangeState::kChangeStateSpeech2Speech;
        }
        return AudioChangeState::kChangeStateInvalid;
    }
    int FrameSizeMs() {
        return frame_size_ms;
    }
};
class E2EVadModel {
public:
    // Construct with default options and put all detection state into the
    // "no utterance started" condition. (Commented lines are remnants of
    // the Python original this class was ported from.)
    E2EVadModel() {
        this->vad_opts = VADXOptions();
//    this->windows_detector = WindowDetector(200,150,150,10);
        // this->encoder = encoder;
        // init variables
        this->is_final = false;
        this->data_buf_start_frame = 0;
        this->frm_cnt = 0;
        this->latest_confirmed_speech_frame = 0;
        this->lastest_confirmed_silence_frame = -1;
        this->continous_silence_frame_count = 0;
        this->vad_state_machine = VadStateMachine::kVadInStateStartPointNotDetected;
        this->confirmed_start_frame = -1;
        this->confirmed_end_frame = -1;
        this->number_end_time_detected = 0;
        this->sil_frame = 0;
        this->sil_pdf_ids = this->vad_opts.sil_pdf_ids;
        this->noise_average_decibel = -100.0;
        this->pre_end_silence_detected = false;
        this->next_seg = true;
//    this->output_data_buf = [];
        this->output_data_buf_offset = 0;
//    this->frame_probs = [];
        // End-silence threshold excludes the time already spent deciding
        // speech->sil inside the window detector.
        this->max_end_sil_frame_cnt_thresh =
                this->vad_opts.max_end_silence_time - this->vad_opts.speech_to_sil_time_thres;
        this->speech_noise_thres = this->vad_opts.speech_noise_thres;
        this->max_time_out = false;
//    this->decibel = [];
        this->ResetDetection();
    }
    // Score one chunk of frames and return the speech segments decided so
    // far as {start_ms, end_ms} pairs. A segment whose start (or end) was
    // emitted in an earlier call is reported with -1 in that slot, so
    // streaming callers can stitch partial segments together. With
    // is_final=true the last segment is forced closed and all state reset.
    std::vector<std::vector<int>>
    operator()(const std::vector<std::vector<float>> &score, const std::vector<float> &waveform, bool is_final = false,
               int max_end_sil = 800, int max_single_segment_time = 15000, float speech_noise_thres = 0.9,
               int sample_rate = 16000) {
        max_end_sil_frame_cnt_thresh = max_end_sil - vad_opts.speech_to_sil_time_thres;
        this->waveform = waveform;
        this->vad_opts.max_single_segment_time = max_single_segment_time;
        this->vad_opts.speech_noise_thres = speech_noise_thres;
        this->vad_opts.sample_rate = sample_rate;
        ComputeDecibel();
        ComputeScores(score);
        if (!is_final) {
            DetectCommonFrames();
        } else {
            DetectLastFrames();
        }
        //    std::vector<std::vector<int>> segments;
        //    for (size_t batch_num = 0; batch_num < score.size(); batch_num++) {
        // Collect segments accumulated in output_data_buf since the last call.
        std::vector<std::vector<int>> segment_batch;
        if (output_data_buf.size() > 0) {
            for (size_t i = output_data_buf_offset; i < output_data_buf.size(); i++) {
                if (!output_data_buf[i].contain_seg_start_point) {
                    continue;
                }
                if (!next_seg && !output_data_buf[i].contain_seg_end_point) {
                    continue;
                }
                // start_ms is -1 when this segment's start was already
                // reported by a previous call.
                int start_ms = next_seg ? output_data_buf[i].start_ms : -1;
                int end_ms;
                if (output_data_buf[i].contain_seg_end_point) {
                    end_ms = output_data_buf[i].end_ms;
                    next_seg = true;
                    output_data_buf_offset += 1;
                } else {
                    // Segment still open: report it and wait for its end.
                    end_ms = -1;
                    next_seg = false;
                }
                std::vector<int> segment = {start_ms, end_ms};
                segment_batch.push_back(segment);
            }
        }
        //    }
        if (is_final) {
            AllResetDetection();
        }
        return segment_batch;
    }
private:
    VADXOptions vad_opts;
    WindowDetector windows_detector = WindowDetector(200, 150, 150, 10);
    bool is_final;
    int data_buf_start_frame;            // first frame index still held in data_buf
    int frm_cnt;                         // total frames scored so far
    int latest_confirmed_speech_frame;
    int lastest_confirmed_silence_frame;
    int continous_silence_frame_count;
    VadStateMachine vad_state_machine;
    int confirmed_start_frame;
    int confirmed_end_frame;
    int number_end_time_detected;
    int sil_frame;
    std::vector<int> sil_pdf_ids;        // output indices treated as silence
    float noise_average_decibel;
    bool pre_end_silence_detected;
    bool next_seg;                       // true when the next reported segment carries its start_ms
    std::vector<E2EVadSpeechBufWithDoa> output_data_buf;
    int output_data_buf_offset;          // first not-yet-fully-reported entry
    std::vector<E2EVadFrameProb> frame_probs;
    int max_end_sil_frame_cnt_thresh;
    float speech_noise_thres;
    std::vector<std::vector<float>> scores;  // accumulated per-frame posteriors
    bool max_time_out;
    std::vector<float> decibel;          // per-frame energy in dB (see ComputeDecibel)
    std::vector<float> data_buf;         // samples not yet popped to output
    std::vector<float> data_buf_all;     // all samples seen this utterance
    std::vector<float> waveform;         // the most recent chunk
    // Full reset between utterances: restore every member (including the
    // audio/score buffers) to the freshly-constructed state. Called after a
    // final chunk has been processed.
    void AllResetDetection() {
        is_final = false;
        data_buf_start_frame = 0;
        frm_cnt = 0;
        latest_confirmed_speech_frame = 0;
        lastest_confirmed_silence_frame = -1;
        continous_silence_frame_count = 0;
        vad_state_machine = VadStateMachine::kVadInStateStartPointNotDetected;
        confirmed_start_frame = -1;
        confirmed_end_frame = -1;
        number_end_time_detected = 0;
        sil_frame = 0;
        sil_pdf_ids = vad_opts.sil_pdf_ids;
        noise_average_decibel = -100.0;
        pre_end_silence_detected = false;
        next_seg = true;
        output_data_buf.clear();
        output_data_buf_offset = 0;
        frame_probs.clear();
        max_end_sil_frame_cnt_thresh = vad_opts.max_end_silence_time - vad_opts.speech_to_sil_time_thres;
        speech_noise_thres = vad_opts.speech_noise_thres;
        scores.clear();
        max_time_out = false;
        decibel.clear();
        data_buf.clear();
        data_buf_all.clear();
        waveform.clear();
        ResetDetection();
    }
    // Partial reset after a segment end-point: clear the per-segment state
    // machine and window detector, but keep the audio/score buffers.
    void ResetDetection() {
        continous_silence_frame_count = 0;
        latest_confirmed_speech_frame = 0;
        lastest_confirmed_silence_frame = -1;
        confirmed_start_frame = -1;
        confirmed_end_frame = -1;
        vad_state_machine = VadStateMachine::kVadInStateStartPointNotDetected;
        windows_detector.Reset();
        sil_frame = 0;
        frame_probs.clear();
    }
    void ComputeDecibel() {
        int frame_sample_length = int(vad_opts.frame_length_ms * vad_opts.sample_rate / 1000);
        int frame_shift_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
        if (data_buf_all.empty()) {
            data_buf_all = waveform;
            data_buf = data_buf_all;
        } else {
            data_buf_all.insert(data_buf_all.end(), waveform.begin(), waveform.end());
        }
        for (int offset = 0; offset < waveform.size() - frame_sample_length + 1; offset += frame_shift_length) {
            float sum = 0.0;
            for (int i = 0; i < frame_sample_length; i++) {
                sum += waveform[offset + i] * waveform[offset + i];
            }
//      float decibel = 10 * log10(sum + 0.000001);
            this->decibel.push_back(10 * log10(sum + 0.000001));
        }
    }
    void ComputeScores(const std::vector<std::vector<float>> &scores) {
        vad_opts.nn_eval_block_size = scores.size();
        frm_cnt += scores.size();
        if (this->scores.empty()) {
            this->scores = scores;  // the first calculation
        } else {
            this->scores.insert(this->scores.end(), scores.begin(), scores.end());
        }
    }
    void PopDataBufTillFrame(int frame_idx) {
        while (data_buf_start_frame < frame_idx) {
            int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
            if (data_buf.size() >= frame_sample_length) {
                data_buf_start_frame += 1;
                data_buf.erase(data_buf.begin(), data_buf.begin() + frame_sample_length);
            } else {
                break;
            }
        }
    }
    // Move `frm_cnt` frames starting at `start_frm` from data_buf into the
    // current output segment (output_data_buf.back()), opening a new segment
    // when `first_frm_is_start_point` is set and closing it when
    // `last_frm_is_end_point` is set. `end_point_is_sent_end` means the end
    // point coincides with the end of the whole sentence/stream.
    void PopDataToOutputBuf(int start_frm, int frm_cnt, bool first_frm_is_start_point, bool last_frm_is_end_point,
                            bool end_point_is_sent_end) {
        // Align data_buf so its head corresponds to start_frm.
        PopDataBufTillFrame(start_frm);
        int expected_sample_number = int(frm_cnt * vad_opts.sample_rate * vad_opts.frame_in_ms / 1000);
        if (last_frm_is_end_point) {
            // The last frame spans frame_length_ms, which may exceed one frame shift.
            int extra_sample = std::max(0, int(vad_opts.frame_length_ms * vad_opts.sample_rate / 1000 -
                                               vad_opts.sample_rate * vad_opts.frame_in_ms / 1000));
            expected_sample_number += int(extra_sample);
        }
        if (end_point_is_sent_end) {
            // At sentence end, flush everything still buffered.
            expected_sample_number = std::max(expected_sample_number, int(data_buf.size()));
        }
        if (data_buf.size() < expected_sample_number) {
            std::cout << "error in calling pop data_buf\n";
        }
        if (output_data_buf.size() == 0 || first_frm_is_start_point) {
            // Open a fresh segment anchored at start_frm (in ms).
            output_data_buf.push_back(E2EVadSpeechBufWithDoa());
            output_data_buf[output_data_buf.size() - 1].Reset();
            output_data_buf[output_data_buf.size() - 1].start_ms = start_frm * vad_opts.frame_in_ms;
            output_data_buf[output_data_buf.size() - 1].end_ms = output_data_buf[output_data_buf.size() - 1].start_ms;
            output_data_buf[output_data_buf.size() - 1].doa = 0;
        }
        E2EVadSpeechBufWithDoa &cur_seg = output_data_buf.back();
        // Segments must be appended contiguously, frame by frame.
        if (cur_seg.end_ms != start_frm * vad_opts.frame_in_ms) {
            std::cout << "warning\n";
        }
        int out_pos = (int) cur_seg.buffer.size();
        int data_to_pop;
        if (end_point_is_sent_end) {
            data_to_pop = expected_sample_number;
        } else {
            data_to_pop = int(frm_cnt * vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
        }
        if (data_to_pop > int(data_buf.size())) {
            std::cout << "VAD data_to_pop is bigger than data_buf.size()!!!\n";
            data_to_pop = (int) data_buf.size();
            expected_sample_number = (int) data_buf.size();
        }
        cur_seg.doa = 0;
        // NOTE(review): both loops push data_buf.back() repeatedly, i.e. the
        // SAME last sample `expected_sample_number` times, rather than copying
        // sequential samples — and data_buf itself is never erased here even
        // though data_buf_start_frame advances below. Looks wrong if
        // cur_seg.buffer is ever consumed as audio; confirm whether the buffer
        // contents are actually used downstream (segment times appear to be
        // derived from start_ms/end_ms only).
        for (int sample_cpy_out = 0; sample_cpy_out < data_to_pop; sample_cpy_out++) {
            cur_seg.buffer.push_back(data_buf.back());
            out_pos++;
        }
        for (int sample_cpy_out = data_to_pop; sample_cpy_out < expected_sample_number; sample_cpy_out++) {
            cur_seg.buffer.push_back(data_buf.back());
            out_pos++;
        }
        if (cur_seg.end_ms != start_frm * vad_opts.frame_in_ms) {
            std::cout << "Something wrong with the VAD algorithm\n";
        }
        // Advance the logical head of data_buf past the popped frames.
        data_buf_start_frame += frm_cnt;
        cur_seg.end_ms = (start_frm + frm_cnt) * vad_opts.frame_in_ms;
        if (first_frm_is_start_point) {
            cur_seg.contain_seg_start_point = true;
        }
        if (last_frm_is_end_point) {
            cur_seg.contain_seg_end_point = true;
        }
    }
    void OnSilenceDetected(int valid_frame) {
        lastest_confirmed_silence_frame = valid_frame;
        if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
            PopDataBufTillFrame(valid_frame);
        }
        // silence_detected_callback_
        // pass
    }
    void OnVoiceDetected(int valid_frame) {
        latest_confirmed_speech_frame = valid_frame;
        PopDataToOutputBuf(valid_frame, 1, false, false, false);
    }
    void OnVoiceStart(int start_frame, bool fake_result = false) {
        if (vad_opts.do_start_point_detection) {
            // pass
        }
        if (confirmed_start_frame != -1) {
            std::cout << "not reset vad properly\n";
        } else {
            confirmed_start_frame = start_frame;
        }
        if (!fake_result && vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
            PopDataToOutputBuf(confirmed_start_frame, 1, true, false, false);
        }
    }
    void OnVoiceEnd(int end_frame, bool fake_result, bool is_last_frame) {
        for (int t = latest_confirmed_speech_frame + 1; t < end_frame; t++) {
            OnVoiceDetected(t);
        }
        if (vad_opts.do_end_point_detection) {
            // pass
        }
        if (confirmed_end_frame != -1) {
            std::cout << "not reset vad properly\n";
        } else {
            confirmed_end_frame = end_frame;
        }
        if (!fake_result) {
            sil_frame = 0;
            PopDataToOutputBuf(confirmed_end_frame, 1, false, true, is_last_frame);
        }
        number_end_time_detected++;
    }
    void MaybeOnVoiceEndIfLastFrame(bool is_final_frame, int cur_frm_idx) {
        if (is_final_frame) {
            OnVoiceEnd(cur_frm_idx, false, true);
            vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
        }
    }
    int GetLatency() {
        return int(LatencyFrmNumAtStartPoint() * vad_opts.frame_in_ms);
    }
    int LatencyFrmNumAtStartPoint() {
        int vad_latency = windows_detector.GetWinSize();
        if (vad_opts.do_extend) {
            vad_latency += int(vad_opts.lookback_time_start_point / vad_opts.frame_in_ms);
        }
        return vad_latency;
    }
    // Classify frame t as speech or silence from the NN silence-pdf scores,
    // gated by absolute decibel and SNR thresholds; also maintains the running
    // noise-decibel estimate used for SNR.
    FrameState GetFrameState(int t) {
        FrameState frame_state = FrameState::kFrameStateInvalid;
        float cur_decibel = decibel[t];
        float cur_snr = cur_decibel - noise_average_decibel;
        if (cur_decibel < vad_opts.decibel_thres) {
            frame_state = FrameState::kFrameStateSil;
            // NOTE(review): DetectOneFrame is invoked here AND again by the
            // caller (DetectCommonFrames / DetectLastFrames) for the same
            // frame, so low-decibel frames are processed twice by the state
            // machine. This mirrors the upstream port — confirm it is intended.
            DetectOneFrame(frame_state, t, false);
            return frame_state;
        }
        float sum_score = 0.0;
        float noise_prob = 0.0;
        assert(sil_pdf_ids.size() == vad_opts.silence_pdf_num);
        if (sil_pdf_ids.size() > 0) {
            // Sum the posteriors of all silence pdfs for this frame.
            std::vector<float> sil_pdf_scores;
            for (auto sil_pdf_id: sil_pdf_ids) {
                sil_pdf_scores.push_back(scores[t][sil_pdf_id]);
            }
            sum_score = accumulate(sil_pdf_scores.begin(), sil_pdf_scores.end(), 0.0);
            // Noise log-prob, scaled by the speech-to-noise ratio factor.
            noise_prob = log(sum_score) * vad_opts.speech_2_noise_ratio;
            float total_score = 1.0;
            // Speech probability = 1 - silence probability.
            sum_score = total_score - sum_score;
        }
        // NOTE(review): if sil_pdf_ids is empty, sum_score stays 0 and this is
        // log(0) = -inf; presumably sil_pdf_ids is always non-empty — confirm.
        float speech_prob = log(sum_score);
        if (vad_opts.output_frame_probs) {
            E2EVadFrameProb frame_prob;
            frame_prob.noise_prob = noise_prob;
            frame_prob.speech_prob = speech_prob;
            frame_prob.score = sum_score;
            frame_prob.frame_id = t;
            frame_probs.push_back(frame_prob);
        }
        // Compare in probability domain: speech posterior must exceed the
        // (scaled) silence posterior by speech_noise_thres.
        if (exp(speech_prob) >= exp(noise_prob) + speech_noise_thres) {
            if (cur_snr >= vad_opts.snr_thres && cur_decibel >= vad_opts.decibel_thres) {
                frame_state = FrameState::kFrameStateSpeech;
            } else {
                frame_state = FrameState::kFrameStateSil;
            }
        } else {
            frame_state = FrameState::kFrameStateSil;
            // Update the exponential-moving-average noise level on silence frames;
            // -100 is the "uninitialized" sentinel set at reset.
            if (noise_average_decibel < -99.9) {
                noise_average_decibel = cur_decibel;
            } else {
                noise_average_decibel =
                        (cur_decibel + noise_average_decibel * (vad_opts.noise_frame_num_used_for_snr - 1)) /
                        vad_opts.noise_frame_num_used_for_snr;
            }
        }
        return frame_state;
    }
    int DetectCommonFrames() {
        if (vad_state_machine == VadStateMachine::kVadInStateEndPointDetected) {
            return 0;
        }
        for (int i = vad_opts.nn_eval_block_size - 1; i >= 0; i--) {
            FrameState frame_state = FrameState::kFrameStateInvalid;
            frame_state = GetFrameState(frm_cnt - 1 - i);
            DetectOneFrame(frame_state, frm_cnt - 1 - i, false);
        }
        return 0;
    }
    int DetectLastFrames() {
        if (vad_state_machine == VadStateMachine::kVadInStateEndPointDetected) {
            return 0;
        }
        for (int i = vad_opts.nn_eval_block_size - 1; i >= 0; i--) {
            FrameState frame_state = FrameState::kFrameStateInvalid;
            frame_state = GetFrameState(frm_cnt - 1 - i);
            if (i != 0) {
                DetectOneFrame(frame_state, frm_cnt - 1 - i, false);
            } else {
                DetectOneFrame(frame_state, frm_cnt - 1, true);
            }
        }
        return 0;
    }
    // Core VAD state machine: feed one classified frame through the sliding
    // window detector and react to the resulting state transition
    // (sil->speech, speech->sil, speech->speech, sil->sil).
    void DetectOneFrame(FrameState cur_frm_state, int cur_frm_idx, bool is_final_frame) {
        FrameState tmp_cur_frm_state = FrameState::kFrameStateInvalid;
        if (cur_frm_state == FrameState::kFrameStateSpeech) {
            // NOTE(review): fabs(1.0) is a constant, so this branch is always
            // taken whenever fe_prior_thres < 1.0; mirrors the upstream port,
            // presumably a placeholder for a real frame-level prior — confirm.
            if (std::fabs(1.0) > vad_opts.fe_prior_thres) {
                tmp_cur_frm_state = FrameState::kFrameStateSpeech;
            } else {
                tmp_cur_frm_state = FrameState::kFrameStateSil;
            }
        } else if (cur_frm_state == FrameState::kFrameStateSil) {
            tmp_cur_frm_state = FrameState::kFrameStateSil;
        }
        // Smooth the raw frame decision over a sliding window.
        AudioChangeState state_change = windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx);
        int frm_shift_in_ms = vad_opts.frame_in_ms;
        if (AudioChangeState::kChangeStateSil2Speech == state_change) {
            // Silence -> speech: either a segment start, or speech resuming
            // inside an already-open segment.
            int silence_frame_count = continous_silence_frame_count;
            continous_silence_frame_count = 0;
            pre_end_silence_detected = false;
            int start_frame = 0;
            if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
                // Back-date the start point by the detector latency, but never
                // before the data still held in data_buf.
                start_frame = std::max(data_buf_start_frame, cur_frm_idx - LatencyFrmNumAtStartPoint());
                OnVoiceStart(start_frame);
                vad_state_machine = VadStateMachine::kVadInStateInSpeechSegment;
                for (int t = start_frame + 1; t <= cur_frm_idx; t++) {
                    OnVoiceDetected(t);
                }
            } else if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
                for (int t = latest_confirmed_speech_frame + 1; t < cur_frm_idx; t++) {
                    OnVoiceDetected(t);
                }
                // Force an end point if the segment exceeds the maximum length.
                if (cur_frm_idx - confirmed_start_frame + 1 > vad_opts.max_single_segment_time / frm_shift_in_ms) {
                    OnVoiceEnd(cur_frm_idx, false, false);
                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
                } else if (!is_final_frame) {
                    OnVoiceDetected(cur_frm_idx);
                } else {
                    MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
                }
            }
        } else if (AudioChangeState::kChangeStateSpeech2Sil == state_change) {
            // Speech -> silence: keep the segment open; the end point is only
            // confirmed later after enough continuous silence (Sil2Sil branch).
            continous_silence_frame_count = 0;
            if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
                // do nothing
            } else if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
                if (cur_frm_idx - confirmed_start_frame + 1 >
                    vad_opts.max_single_segment_time / frm_shift_in_ms) {
                    OnVoiceEnd(cur_frm_idx, false, false);
                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
                } else if (!is_final_frame) {
                    OnVoiceDetected(cur_frm_idx);
                } else {
                    MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
                }
            }
        } else if (AudioChangeState::kChangeStateSpeech2Speech == state_change) {
            // Continuing speech: only the max-segment-length timeout can end it.
            continous_silence_frame_count = 0;
            if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
                if (cur_frm_idx - confirmed_start_frame + 1 >
                    vad_opts.max_single_segment_time / frm_shift_in_ms) {
                    max_time_out = true;
                    OnVoiceEnd(cur_frm_idx, false, false);
                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
                } else if (!is_final_frame) {
                    OnVoiceDetected(cur_frm_idx);
                } else {
                    MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
                }
            }
        } else if (AudioChangeState::kChangeStateSil2Sil == state_change) {
            // Continuing silence: either give up waiting for speech (leading
            // silence too long / stream ended empty) or confirm an end point.
            continous_silence_frame_count += 1;
            if (vad_state_machine == VadStateMachine::kVadInStateStartPointNotDetected) {
                // Leading silence exceeded max_start_silence_time (single-utterance
                // mode), or the stream ended without any detected speech: emit a
                // fake (empty) start/end pair and finish.
                if ((vad_opts.detect_mode == static_cast<int>(VadDetectMode::kVadSingleUtteranceDetectMode) &&
                     (continous_silence_frame_count * frm_shift_in_ms > vad_opts.max_start_silence_time)) ||
                    (is_final_frame && number_end_time_detected == 0)) {
                    for (int t = lastest_confirmed_silence_frame + 1; t < cur_frm_idx; t++) {
                        OnSilenceDetected(t);
                    }
                    OnVoiceStart(0, true);
                    OnVoiceEnd(0, true, false);
                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
                } else {
                    // Confirm silence only once it is older than the detection latency.
                    if (cur_frm_idx >= LatencyFrmNumAtStartPoint()) {
                        OnSilenceDetected(cur_frm_idx - LatencyFrmNumAtStartPoint());
                    }
                }
            } else if (vad_state_machine == VadStateMachine::kVadInStateInSpeechSegment) {
                // Enough trailing silence: confirm the end point, looking back so
                // the silence tail is not included in the segment.
                if (continous_silence_frame_count * frm_shift_in_ms >= max_end_sil_frame_cnt_thresh) {
                    int lookback_frame = max_end_sil_frame_cnt_thresh / frm_shift_in_ms;
                    if (vad_opts.do_extend) {
                        lookback_frame -= vad_opts.lookahead_time_end_point / frm_shift_in_ms;
                        lookback_frame -= 1;
                        lookback_frame = std::max(0, lookback_frame);
                    }
                    OnVoiceEnd(cur_frm_idx - lookback_frame, false, false);
                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
                } else if (cur_frm_idx - confirmed_start_frame + 1 >
                           vad_opts.max_single_segment_time / frm_shift_in_ms) {
                    OnVoiceEnd(cur_frm_idx, false, false);
                    vad_state_machine = VadStateMachine::kVadInStateEndPointDetected;
                } else if (vad_opts.do_extend && !is_final_frame) {
                    // Extend the segment through short silences up to the lookahead.
                    if (continous_silence_frame_count <= vad_opts.lookahead_time_end_point / frm_shift_in_ms) {
                        OnVoiceDetected(cur_frm_idx);
                    }
                } else {
                    MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx);
                }
            }
        }
        // In multi-utterance mode, re-arm the detector after each end point.
        if (vad_state_machine == VadStateMachine::kVadInStateEndPointDetected &&
            vad_opts.detect_mode == static_cast<int>(VadDetectMode::kVadMutipleUtteranceDetectMode)) {
            ResetDetection();
        }
    }
};
funasr/runtime/onnxruntime/src/libfunasrapi.cpp
@@ -4,14 +4,14 @@
extern "C" {
#endif
    // APIs for qmasr
    _FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThreadNum, bool quantize)
    // APIs for funasr
    _FUNASRAPI FUNASR_HANDLE  FunASRInit(const char* szModelDir, int nThreadNum, bool quantize, bool use_vad)
    {
        Model* mm = create_model(szModelDir, nThreadNum, quantize);
        Model* mm = create_model(szModelDir, nThreadNum, quantize, use_vad);
        return mm;
    }
    _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
    _FUNASRAPI FUNASR_RESULT FunASRRecogBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
    {
        Model* pRecogObj = (Model*)handle;
        if (!pRecogObj)
@@ -21,7 +21,9 @@
        Audio audio(1);
        if (!audio.loadwav(szBuf, nLen, &sampling_rate))
            return nullptr;
        //audio.split();
        if(use_vad){
            audio.split(pRecogObj);
        }
        float* buff;
        int len;
@@ -31,7 +33,6 @@
        int nStep = 0;
        int nTotal = audio.get_queue_size();
        while (audio.fetch(buff, len, flag) > 0) {
            //pRecogObj->reset();
            string msg = pRecogObj->forward(buff, len, flag);
            pResult->msg += msg;
            nStep++;
@@ -42,7 +43,7 @@
        return pResult;
    }
    _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
    _FUNASRAPI FUNASR_RESULT FunASRRecogPCMBuffer(FUNASR_HANDLE handle, const char* szBuf, int nLen, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
    {
        Model* pRecogObj = (Model*)handle;
        if (!pRecogObj)
@@ -51,7 +52,9 @@
        Audio audio(1);
        if (!audio.loadpcmwav(szBuf, nLen, &sampling_rate))
            return nullptr;
        //audio.split();
        if(use_vad){
            audio.split(pRecogObj);
        }
        float* buff;
        int len;
@@ -61,7 +64,6 @@
        int nStep = 0;
        int nTotal = audio.get_queue_size();
        while (audio.fetch(buff, len, flag) > 0) {
            //pRecogObj->reset();
            string msg = pRecogObj->forward(buff, len, flag);
            pResult->msg += msg;
            nStep++;
@@ -72,7 +74,7 @@
        return pResult;
    }
    _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
    _FUNASRAPI FUNASR_RESULT FunASRRecogPCMFile(FUNASR_HANDLE handle, const char* szFileName, int sampling_rate, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
    {
        Model* pRecogObj = (Model*)handle;
        if (!pRecogObj)
@@ -81,7 +83,9 @@
        Audio audio(1);
        if (!audio.loadpcmwav(szFileName, &sampling_rate))
            return nullptr;
        //audio.split();
        if(use_vad){
            audio.split(pRecogObj);
        }
        float* buff;
        int len;
@@ -91,7 +95,6 @@
        int nStep = 0;
        int nTotal = audio.get_queue_size();
        while (audio.fetch(buff, len, flag) > 0) {
            //pRecogObj->reset();
            string msg = pRecogObj->forward(buff, len, flag);
            pResult->msg += msg;
            nStep++;
@@ -102,7 +105,7 @@
        return pResult;
    }
    _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback)
    _FUNASRAPI FUNASR_RESULT FunASRRecogFile(FUNASR_HANDLE handle, const char* szWavfile, FUNASR_MODE Mode, QM_CALLBACK fnCallback, bool use_vad)
    {
        Model* pRecogObj = (Model*)handle;
        if (!pRecogObj)
@@ -112,7 +115,9 @@
        Audio audio(1);
        if(!audio.loadwav(szWavfile, &sampling_rate))
            return nullptr;
        //audio.split();
        if(use_vad){
            audio.split(pRecogObj);
        }
        float* buff;
        int len;
funasr/runtime/onnxruntime/src/paraformer_onnx.cpp
@@ -3,11 +3,19 @@
using namespace std;
using namespace paraformer;
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize)
ModelImp::ModelImp(const char* path,int nNumThread, bool quantize, bool use_vad)
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),sessionOptions{}{
    string model_path;
    string cmvn_path;
    string config_path;
    // VAD model
    if(use_vad){
        string vad_path = pathAppend(path, "vad_model.onnx");
        string mvn_path = pathAppend(path, "vad.mvn");
        vadHandle = make_unique<FsmnVad>();
        vadHandle->init_vad(vad_path, mvn_path, model_sample_rate, 800, 15000, 0.9);
    }
    if(quantize)
    {
@@ -30,8 +38,10 @@
    //fbank_ = std::make_unique<knf::OnlineFbank>(fbank_opts);
    //sessionOptions.SetInterOpNumThreads(1);
    //sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
    sessionOptions.SetIntraOpNumThreads(nNumThread);
    sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
    sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    sessionOptions.DisableCpuMemArena();
#ifdef _WIN32
    wstring wstrPath = strToWstr(model_path);
@@ -67,6 +77,10 @@
void ModelImp::reset()
{
}
vector<std::vector<int>> ModelImp::vad_seg(std::vector<float>& pcm_data){
    return vadHandle->infer(pcm_data);
}
vector<float> ModelImp::FbankKaldi(float sample_rate, const float* waves, int len) {
@@ -172,66 +186,6 @@
      p += dim;
    }
  }
//   void ParaformerOnnxAsrModel::ForwardFunc(
//     const std::vector<std::vector<float>>& chunk_feats,
//     std::vector<std::vector<float>>* out_prob) {
//   Ort::MemoryInfo memory_info =
//       Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
//   // 1. Prepare onnx required data, splice cached_feature_ and chunk_feats
//   // chunk
// //  int num_frames = cached_feature_.size() + chunk_feats.size();
//   int num_frames = chunk_feats.size();
//   const int feature_dim = chunk_feats[0].size();
//   //  2. Generate 2 input nodes tensor
//   // speech node { batch,frame number,feature dim }
//   const int64_t paraformer_feats_shape[3] = {1, num_frames, feature_dim};
//   std::vector<float> paraformer_feats;
//   for (const auto & chunk_feat : chunk_feats) {
//     paraformer_feats.insert(paraformer_feats.end(), chunk_feat.begin(), chunk_feat.end());
//   }
//   Ort::Value paraformer_feats_ort = Ort::Value::CreateTensor<float>(
//           memory_info, paraformer_feats.data(), paraformer_feats.size(), paraformer_feats_shape, 3);
//   // speech_lengths node {speech length,}
//   const int64_t paraformer_length_shape[1] = {1};
//   std::vector<int32_t> paraformer_length;
//   paraformer_length.emplace_back(num_frames);
//   Ort::Value paraformer_length_ort = Ort::Value::CreateTensor<int32_t>(
//           memory_info, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1);
//   // 3. Put nodes into onnx input vector
//   std::vector<Ort::Value> paraformer_inputs;
//   paraformer_inputs.emplace_back(std::move(paraformer_feats_ort));
//   paraformer_inputs.emplace_back(std::move(paraformer_length_ort));
//   // 4. Onnx infer
//   std::vector<Ort::Value> paraformer_ort_outputs;
//   try{
//     VLOG(3) << "Start infer";
//     paraformer_ort_outputs = paraformer_session_->Run(
//             Ort::RunOptions{nullptr}, paraformer_in_names_.data(), paraformer_inputs.data(),
//             paraformer_inputs.size(), paraformer_out_names_.data(), paraformer_out_names_.size());
//   }catch (std::exception const& e) {
//     //  Catch "Non-zero status code returned error",usually because there is no asr result.
//     // Need funasr to resolve.
//     LOG(ERROR) << e.what();
//     return;
//   }
//   // 5. Change infer result to output shapes
//   float* logp_data = paraformer_ort_outputs[0].GetTensorMutableData<float>();
//   auto type_info = paraformer_ort_outputs[0].GetTensorTypeAndShapeInfo();
//   int num_outputs = type_info.GetShape()[1];
//   int output_dim = type_info.GetShape()[2];
//   out_prob->resize(num_outputs);
//   for (int i = 0; i < num_outputs; i++) {
//     (*out_prob)[i].resize(output_dim);
//     memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
//            sizeof(float) * output_dim);
//   }
// }
string ModelImp::forward(float* din, int len, int flag)
{
funasr/runtime/onnxruntime/src/paraformer_onnx.h
@@ -4,8 +4,7 @@
#ifndef PARAFORMER_MODELIMP_H
#define PARAFORMER_MODELIMP_H
#include "kaldi-native-fbank/csrc/feature-fbank.h"
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "precomp.h"
namespace paraformer {
@@ -13,6 +12,8 @@
    private:
        //std::unique_ptr<knf::OnlineFbank> fbank_;
        knf::FbankOptions fbank_opts;
        std::unique_ptr<FsmnVad> vadHandle;
        Vocab* vocab;
        vector<float> means_list;
@@ -27,7 +28,7 @@
        string greedy_search( float* in, int nLen);
        std::unique_ptr<Ort::Session> m_session;
        std::shared_ptr<Ort::Session> m_session;
        Ort::Env env_;
        Ort::SessionOptions sessionOptions;
@@ -36,13 +37,14 @@
        vector<const char*> m_szOutputNames;
    public:
        ModelImp(const char* path, int nNumThread=0, bool quantize=false);
        ModelImp(const char* path, int nNumThread=0, bool quantize=false, bool use_vad=false);
        ~ModelImp();
        void reset();
        vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
        string forward_chunk(float* din, int len, int flag);
        string forward(float* din, int len, int flag);
        string rescoring();
        std::vector<std::vector<int>> vad_seg(std::vector<float>& pcm_data);
    };
funasr/runtime/onnxruntime/src/precomp.h
@@ -26,6 +26,8 @@
#include <fftw3.h>
#include "onnxruntime_run_options_config_keys.h"
#include "onnxruntime_cxx_api.h"
#include "kaldi-native-fbank/csrc/feature-fbank.h"
#include "kaldi-native-fbank/csrc/online-feature.h"
// mine
@@ -33,6 +35,7 @@
#include "commonfunc.h"
#include <ComDefine.h>
#include "predefine_coe.h"
#include "FsmnVad.h"
#include <ComDefine.h>
//#include "alignedmem.h"
funasr/runtime/onnxruntime/tester/tester.cpp
@@ -14,9 +14,9 @@
int main(int argc, char *argv[])
{
    if (argc < 4)
    if (argc < 5)
    {
        printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) \n", argv[0]);
        printf("Usage: %s /path/to/model_dir /path/to/wav/file quantize(true or false) use_vad(true or false) \n", argv[0]);
        exit(-1);
    }
    struct timeval start, end;
@@ -24,8 +24,10 @@
    int nThreadNum = 1;
    // is quantize
    bool quantize = false;
    bool use_vad = false;
    istringstream(argv[3]) >> boolalpha >> quantize;
    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize);
    istringstream(argv[4]) >> boolalpha >> use_vad;
    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], nThreadNum, quantize, use_vad);
    if (!AsrHanlde)
    {
@@ -41,7 +43,7 @@
    gettimeofday(&start, NULL);
    float snippet_time = 0.0f;
    FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL);
    FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, argv[2], RASR_NONE, NULL, use_vad);
    gettimeofday(&end, NULL);
   
funasr/runtime/onnxruntime/tester/tester_rtf.cpp
@@ -19,7 +19,7 @@
std::atomic<int> index(0);
std::mutex mtx;
void runReg(FUNASR_HANDLE AsrHanlde, vector<string> wav_list,
void runReg(FUNASR_HANDLE AsrHandle, vector<string> wav_list,
            float* total_length, long* total_time, int core_id) {
    // cpu_set_t cpuset;
@@ -37,7 +37,7 @@
    // warm up
    for (size_t i = 0; i < 1; i++)
    {
        FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[0].c_str(), RASR_NONE, NULL);
        FUNASR_RESULT Result=FunASRRecogFile(AsrHandle, wav_list[0].c_str(), RASR_NONE, NULL);
    }
    while (true) {
@@ -48,7 +48,7 @@
        }
        gettimeofday(&start, NULL);
        FUNASR_RESULT Result=FunASRRecogFile(AsrHanlde, wav_list[i].c_str(), RASR_NONE, NULL);
        FUNASR_RESULT Result=FunASRRecogFile(AsrHandle, wav_list[i].c_str(), RASR_NONE, NULL);
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
@@ -112,8 +112,8 @@
    int nThreadNum = 1;
    nThreadNum = atoi(argv[4]);
    FUNASR_HANDLE AsrHanlde=FunASRInit(argv[1], 1, quantize);
    if (!AsrHanlde)
    FUNASR_HANDLE AsrHandle=FunASRInit(argv[1], 1, quantize);
    if (!AsrHandle)
    {
        printf("Cannot load ASR Model from: %s, there must be files model.onnx and vocab.txt", argv[1]);
        exit(-1);
@@ -130,7 +130,7 @@
    for (int i = 0; i < nThreadNum; i++)
    {
        threads.emplace_back(thread(runReg, AsrHanlde, wav_list, &total_length, &total_time, i));
        threads.emplace_back(thread(runReg, AsrHandle, wav_list, &total_length, &total_time, i));
    }
    for (auto& thread : threads)
@@ -142,6 +142,6 @@
    printf("total_time_comput %ld ms.\n", total_time / 1000);
    printf("total_rtf %05lf .\n", (double)total_time/ (total_length*1000000));
    FunASRUninit(AsrHanlde);
    FunASRUninit(AsrHandle);
    return 0;
}