lyblsgo
2023-04-24 fa0356b81dd3c99a2bd573d1f60d7b7131e00533
funasr/runtime/onnxruntime/src/FsmnVad.h
@@ -1,27 +1,22 @@
// Collaborators: zhuzizyf(China Telecom Shanghai)
#ifndef VAD_SERVER_FSMNVAD_H
#define VAD_SERVER_FSMNVAD_H
#include "e2e_vad.h"
#include "onnxruntime_cxx_api.h"
#include "kaldi-native-fbank/csrc/feature-fbank.h"
#include "kaldi-native-fbank/csrc/online-feature.h"
#include "precomp.h"
class FsmnVad {
public:
    FsmnVad();
    void test();
    void init_vad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
    void Test();
    void InitVad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
                  float vad_speech_noise_thres);
    std::vector<std::vector<int>> infer(const std::vector<float> &waves);
    std::vector<std::vector<int>> Infer(const std::vector<float> &waves);
    void Reset();
private:
    void read_model(const std::string &vad_model);
    void ReadModel(const std::string &vad_model);
    static void GetInputOutputInfo(
            const std::shared_ptr<Ort::Session> &session,
@@ -36,8 +31,8 @@
            const std::vector<std::vector<float>> &chunk_feats,
            std::vector<std::vector<float>> *out_prob);
    void load_cmvn(const char *filename);
    void init_cache();
    void LoadCmvn(const char *filename);
    void InitCache();
    std::shared_ptr<Ort::Session> vad_session_ = nullptr;
    Ort::Env env_;