zhuzizyf
2023-04-22 4f8bce944e273e317cb84c7046ea514b9d958b4b
funasr/runtime/onnxruntime/src/FsmnVad.cc
@@ -1,3 +1,4 @@
// Collaborators: zhuzizyf(China Telecom Shanghai)
#include <fstream>
#include "FsmnVad.h"
@@ -17,6 +18,7 @@
    read_model(vad_model);
    load_cmvn(vad_cmvn.c_str());
    init_cache();
    fbank_opts.frame_opts.dither = 0;
    fbank_opts.mel_opts.num_bins = 80;
@@ -104,20 +106,18 @@
    }
    Ort::Value vad_feats_ort = Ort::Value::CreateTensor<float>(
            memory_info, vad_feats.data(), vad_feats.size(), vad_feats_shape, 3);
    // cache node {batch,128,19,1}
    const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
    std::vector<float> cache_feats(128 * 19 * 1, 0);
    Ort::Value cache_feats_ort = Ort::Value::CreateTensor<float>(
            memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4);
    // 3. Put nodes into onnx input vector
    std::vector<Ort::Value> vad_inputs;
    vad_inputs.emplace_back(std::move(vad_feats_ort));
    // 4 caches
    for (int i = 0; i < 4; i++) {
        vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
                memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4)));
    // cache node {batch,128,19,1}
    const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
    for (int i = 0; i < in_cache_.size(); i++) {
      vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
              memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
    }
    // 4. Onnx infer
    std::vector<Ort::Value> vad_ort_outputs;
    try {
@@ -141,6 +141,12 @@
        (*out_prob)[i].resize(output_dim);
        memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
               sizeof(float) * output_dim);
    }
    // get 4 caches outputs,each size is 128*19
    for (int i = 1; i < 5; i++) {
      float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
      memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
    }
}
@@ -245,12 +251,23 @@
    E2EVadModel vad_scorer = E2EVadModel();
    std::vector<std::vector<int>> vad_segments;
    vad_segments = vad_scorer(vad_probs, waves, true, vad_silence_duration_, vad_max_len_,
    vad_segments = vad_scorer(vad_probs, waves, true, false, vad_silence_duration_, vad_max_len_,
                              vad_speech_noise_thres_, vad_sample_rate_);
    return vad_segments;
}
void FsmnVad::init_cache(){
  std::vector<float> cache_feats(128 * 19 * 1, 0);
  for (int i=0;i<4;i++){
    in_cache_.emplace_back(cache_feats);
  }
};
void FsmnVad::Reset(){
  in_cache_.clear();
  init_cache();
};
void FsmnVad::test() {