From 4f8bce944e273e317cb84c7046ea514b9d958b4b Mon Sep 17 00:00:00 2001
From: zhuzizyf <42790740+zhuzizyf@users.noreply.github.com>
Date: Sat, 22 Apr 2023 14:54:49 +0800
Subject: [PATCH] Update FsmnVad.cc
---
funasr/runtime/onnxruntime/src/FsmnVad.cc | 34 +++++++++++++++++++++++++---------
1 file changed, 25 insertions(+), 9 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/FsmnVad.cc b/funasr/runtime/onnxruntime/src/FsmnVad.cc
index f75ead7..de63225 100644
--- a/funasr/runtime/onnxruntime/src/FsmnVad.cc
+++ b/funasr/runtime/onnxruntime/src/FsmnVad.cc
@@ -18,6 +18,7 @@
read_model(vad_model);
load_cmvn(vad_cmvn.c_str());
+ init_cache();
fbank_opts.frame_opts.dither = 0;
fbank_opts.mel_opts.num_bins = 80;
@@ -105,20 +106,18 @@
}
Ort::Value vad_feats_ort = Ort::Value::CreateTensor<float>(
memory_info, vad_feats.data(), vad_feats.size(), vad_feats_shape, 3);
- // cache node {batch,128,19,1}
- const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
- std::vector<float> cache_feats(128 * 19 * 1, 0);
- Ort::Value cache_feats_ort = Ort::Value::CreateTensor<float>(
- memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4);
-
+
// 3. Put nodes into onnx input vector
std::vector<Ort::Value> vad_inputs;
vad_inputs.emplace_back(std::move(vad_feats_ort));
// 4 caches
- for (int i = 0; i < 4; i++) {
- vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
- memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4)));
+ // cache node {batch,128,19,1}
+ const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
+ for (int i = 0; i < in_cache_.size(); i++) {
+ vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
+ memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
}
+
// 4. Onnx infer
std::vector<Ort::Value> vad_ort_outputs;
try {
@@ -142,6 +141,12 @@
(*out_prob)[i].resize(output_dim);
memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
sizeof(float) * output_dim);
+ }
+
+ // get 4 caches outputs,each size is 128*19
+ for (int i = 1; i < 5; i++) {
+ float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
+ memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
}
}
@@ -252,6 +257,17 @@
}
+void FsmnVad::init_cache(){
+  std::vector<float> cache_feats(128 * 19 * 1, 0);
+  for (int i=0;i<4;i++){
+    in_cache_.emplace_back(cache_feats);
+  }
+}
+
+void FsmnVad::Reset(){
+  in_cache_.clear();
+  init_cache();
+}
void FsmnVad::test() {
--
Gitblit v1.9.1