// Provenance (recovered from scraped page header): lyblsgo, 2023-04-21,
// revision 7bfc4a84fc2d882f34928a033a6d5b60ff72fe19
 
#ifndef VAD_SERVER_FSMNVAD_H
#define VAD_SERVER_FSMNVAD_H
 
#include <memory>
#include <string>
#include <vector>

#include "e2e_vad.h"
#include "onnxruntime_cxx_api.h"
#include "kaldi-native-fbank/csrc/feature-fbank.h"
#include "kaldi-native-fbank/csrc/online-feature.h"
 
 
class FsmnVad {
public:
    FsmnVad();
    void test();
    void init_vad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
                  float vad_speech_noise_thres);
 
    std::vector<std::vector<int>> infer(const std::vector<float> &waves);
 
private:
 
    void read_model(const std::string &vad_model);
 
    static void GetInputOutputInfo(
            const std::shared_ptr<Ort::Session> &session,
            std::vector<const char *> *in_names, std::vector<const char *> *out_names);
 
    void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
                    const std::vector<float> &waves);
 
    std::vector<std::vector<float>> &LfrCmvn(std::vector<std::vector<float>> &vad_feats, int lfr_m, int lfr_n);
 
    void Forward(
            const std::vector<std::vector<float>> &chunk_feats,
            std::vector<std::vector<float>> *out_prob);
 
    void load_cmvn(const char *filename);
 
    std::shared_ptr<Ort::Session> vad_session_ = nullptr;
    Ort::Env env_;
    Ort::SessionOptions session_options_;
    std::vector<const char *> vad_in_names_;
    std::vector<const char *> vad_out_names_;
    knf::FbankOptions fbank_opts;
    std::vector<float> means_list;
    std::vector<float> vars_list;
    int vad_sample_rate_ = 16000;
    int vad_silence_duration_ = 800;
    int vad_max_len_ = 15000;
    double vad_speech_noise_thres_ = 0.9;
};
 
 
#endif //VAD_SERVER_FSMNVAD_H