From fa0356b81dd3c99a2bd573d1f60d7b7131e00533 Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: Mon, 24 Apr 2023 11:23:40 +0800
Subject: [PATCH] rename src/e2e_vad.h
---
funasr/runtime/onnxruntime/src/FsmnVad.cc | 52 ++++++++++++++++++++++++++++++++--------------------
1 files changed, 32 insertions(+), 20 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/FsmnVad.cc b/funasr/runtime/onnxruntime/src/FsmnVad.cc
index 6de482e..0f87cb2 100644
--- a/funasr/runtime/onnxruntime/src/FsmnVad.cc
+++ b/funasr/runtime/onnxruntime/src/FsmnVad.cc
@@ -1,11 +1,10 @@
#include <fstream>
-#include "FsmnVad.h"
#include "precomp.h"
//#include "glog/logging.h"
-void FsmnVad::init_vad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
+void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
float vad_speech_noise_thres) {
session_options_.SetIntraOpNumThreads(1);
session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
@@ -15,8 +14,9 @@
this->vad_max_len_=vad_max_len;
this->vad_speech_noise_thres_=vad_speech_noise_thres;
- read_model(vad_model);
- load_cmvn(vad_cmvn.c_str());
+ ReadModel(vad_model);
+ LoadCmvn(vad_cmvn.c_str());
+ InitCache();
fbank_opts.frame_opts.dither = 0;
fbank_opts.mel_opts.num_bins = 80;
@@ -29,7 +29,7 @@
}
-void FsmnVad::read_model(const std::string &vad_model) {
+void FsmnVad::ReadModel(const std::string &vad_model) {
try {
vad_session_ = std::make_shared<Ort::Session>(
env_, vad_model.c_str(), session_options_);
@@ -104,20 +104,18 @@
}
Ort::Value vad_feats_ort = Ort::Value::CreateTensor<float>(
memory_info, vad_feats.data(), vad_feats.size(), vad_feats_shape, 3);
- // cache node {batch,128,19,1}
- const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
- std::vector<float> cache_feats(128 * 19 * 1, 0);
- Ort::Value cache_feats_ort = Ort::Value::CreateTensor<float>(
- memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4);
-
+
// 3. Put nodes into onnx input vector
std::vector<Ort::Value> vad_inputs;
vad_inputs.emplace_back(std::move(vad_feats_ort));
// 4 caches
- for (int i = 0; i < 4; i++) {
- vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
- memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4)));
+ // cache node {batch,128,19,1}
+ const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
+ for (int i = 0; i < in_cache_.size(); i++) {
+ vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
+ memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
}
+
// 4. Onnx infer
std::vector<Ort::Value> vad_ort_outputs;
try {
@@ -142,8 +140,13 @@
memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
sizeof(float) * output_dim);
}
+
+ // get 4 caches outputs,each size is 128*19
+ for (int i = 1; i < 5; i++) {
+ float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
+ memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
+ }
}
-
void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
const std::vector<float> &waves) {
@@ -158,7 +161,7 @@
}
}
-void FsmnVad::load_cmvn(const char *filename)
+void FsmnVad::LoadCmvn(const char *filename)
{
using namespace std;
ifstream cmvn_stream(filename);
@@ -236,7 +239,7 @@
}
std::vector<std::vector<int>>
-FsmnVad::infer(const std::vector<float> &waves) {
+FsmnVad::Infer(const std::vector<float> &waves) {
std::vector<std::vector<float>> vad_feats;
std::vector<std::vector<float>> vad_probs;
FbankKaldi(vad_sample_rate_, vad_feats, waves);
@@ -245,17 +248,26 @@
E2EVadModel vad_scorer = E2EVadModel();
std::vector<std::vector<int>> vad_segments;
- vad_segments = vad_scorer(vad_probs, waves, true, vad_silence_duration_, vad_max_len_,
+ vad_segments = vad_scorer(vad_probs, waves, true, false, vad_silence_duration_, vad_max_len_,
vad_speech_noise_thres_, vad_sample_rate_);
return vad_segments;
}
+// (Re)build the ONNX cache inputs: four zero-filled buffers of 128*19 floats.
+// assign() replaces any existing contents, so calling this again (for example
+// through a second InitVad()) cannot leave extra or stale entries in in_cache_.
+void FsmnVad::InitCache(){
+    in_cache_.assign(4, std::vector<float>(128 * 19 * 1, 0.0f));
+}
-void FsmnVad::test() {
+void FsmnVad::Reset(){ // Discard the cache state carried between inference calls and start again from zeros.
+ in_cache_.clear(); // drop the current cache buffers
+ InitCache(); // repopulate in_cache_ with four zero-filled 128*19 buffers
+};
+void FsmnVad::Test() { // no-op: the body is empty
}
FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} {
-
}
--
Gitblit v1.9.1