cdevelop
2023-11-15 eff2570faf3dae7908db87edf4ef1a6ea88e5b33
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
 
#pragma once 
#include "precomp.h"
 
namespace funasr {
class FsmnVadOnline : public VadModel {
/**
 * Author: Speech Lab of DAMO Academy, Alibaba Group
 * Deep-FSMN for Large Vocabulary Continuous Speech Recognition
 * https://arxiv.org/abs/1803.05030
*/
 
public:
    explicit FsmnVadOnline(FsmnVad* fsmnvad_handle);
    ~FsmnVadOnline();
    void Test();
    std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
    void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
    void Reset();
    // 2pass
    std::unique_ptr<Audio> audio_handle = nullptr;
 
private:
    E2EVadModel vad_scorer = E2EVadModel();
    // std::unique_ptr<FsmnVad> fsmnvad_handle_;
    FsmnVad* fsmnvad_handle_ = nullptr;
 
    void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
                    std::vector<float> &waves);
    int OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished);
    void InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num){}
    void InitCache();
    void InitOnline(std::shared_ptr<Ort::Session> &vad_session,
                    Ort::Env &env,
                    std::vector<const char *> &vad_in_names,
                    std::vector<const char *> &vad_out_names,
                    knf::FbankOptions &fbank_opts,
                    std::vector<float> &means_list,
                    std::vector<float> &vars_list,
                    int vad_sample_rate,
                    int vad_silence_duration,
                    int vad_max_len,
                    double vad_speech_noise_thres);
 
    static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
        int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
        if (frame_num >= 1 && sample_length >= frame_sample_length)
            return frame_num;
        else
            return 0;
    }
    void ResetCache() {
        reserve_waveforms_.clear();
        input_cache_.clear();
        lfr_splice_cache_.clear();
    }
 
    // from fsmnvad_handle_
    std::shared_ptr<Ort::Session> vad_session_ = nullptr;
    Ort::Env env_;
    Ort::SessionOptions session_options_;
    std::vector<const char *> vad_in_names_;
    std::vector<const char *> vad_out_names_;
    knf::FbankOptions fbank_opts_;
    std::vector<float> means_list_;
    std::vector<float> vars_list_;
 
    std::vector<std::vector<float>> in_cache_;
    // The reserved waveforms by fbank
    std::vector<float> reserve_waveforms_;
    // waveforms reserved after last shift position
    std::vector<float> input_cache_;
    // lfr reserved cache
    std::vector<std::vector<float>> lfr_splice_cache_;
 
    int vad_sample_rate_ = MODEL_SAMPLE_RATE;
    int vad_silence_duration_ = VAD_SILENCE_DURATION;
    int vad_max_len_ = VAD_MAX_LEN;
    double vad_speech_noise_thres_ = VAD_SPEECH_NOISE_THRES;
    int lfr_m = VAD_LFR_M;
    int lfr_n = VAD_LFR_N;
    int frame_sample_length_ = vad_sample_rate_ / 1000 * 25;;
    int frame_shift_sample_length_ = vad_sample_rate_ / 1000 * 10;
};
 
} // namespace funasr