From 3372b13d24aceef7002cfa0fc8222b3085c15110 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期五, 02 六月 2023 22:02:31 +0800
Subject: [PATCH] add fsmn-vad-online
---
funasr/runtime/onnxruntime/src/fsmn-vad.cpp | 51 ++++++++++++++++++++++++++++-----------------------
1 files changed, 28 insertions(+), 23 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
index 516dc88..697828b 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -37,14 +37,14 @@
this->vad_max_len_ = post_conf["max_single_segment_time"].as<int>();
this->vad_speech_noise_thres_ = post_conf["speech_noise_thres"].as<double>();
- fbank_opts.frame_opts.dither = frontend_conf["dither"].as<float>();
- fbank_opts.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
- fbank_opts.frame_opts.samp_freq = (float)vad_sample_rate_;
- fbank_opts.frame_opts.window_type = frontend_conf["window"].as<string>();
- fbank_opts.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
- fbank_opts.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
- fbank_opts.energy_floor = 0;
- fbank_opts.mel_opts.debug_mel = false;
+ fbank_opts_.frame_opts.dither = frontend_conf["dither"].as<float>();
+ fbank_opts_.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
+ fbank_opts_.frame_opts.samp_freq = (float)vad_sample_rate_;
+ fbank_opts_.frame_opts.window_type = frontend_conf["window"].as<string>();
+ fbank_opts_.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
+ fbank_opts_.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
+ fbank_opts_.energy_floor = 0;
+ fbank_opts_.mel_opts.debug_mel = false;
}catch(exception const &e){
LOG(ERROR) << "Error when load argument from vad config YAML.";
exit(-1);
@@ -55,6 +55,7 @@
try {
vad_session_ = std::make_shared<Ort::Session>(
env_, vad_model, session_options_);
+ LOG(INFO) << "Successfully load model from " << vad_model;
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load vad onnx model: " << e.what();
exit(0);
@@ -109,7 +110,9 @@
void FsmnVad::Forward(
const std::vector<std::vector<float>> &chunk_feats,
- std::vector<std::vector<float>> *out_prob) {
+ std::vector<std::vector<float>> *out_prob,
+ std::vector<std::vector<float>> *in_cache,
+ bool is_final) {
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
@@ -132,9 +135,9 @@
// 4 caches
// cache node {batch,128,19,1}
const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
- for (int i = 0; i < in_cache_.size(); i++) {
+ for (int i = 0; i < in_cache->size(); i++) {
vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
- memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
+ memory_info, (*in_cache)[i].data(), (*in_cache)[i].size(), cache_feats_shape, 4)));
}
// 4. Onnx infer
@@ -162,15 +165,17 @@
}
// get 4 caches outputs,each size is 128*19
- // for (int i = 1; i < 5; i++) {
- // float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
- // memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
- // }
+ if(!is_final){
+ for (int i = 1; i < 5; i++) {
+ float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
+ memcpy((*in_cache)[i-1].data(), data, sizeof(float) * 128*19);
+ }
+ }
}
void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
std::vector<float> &waves) {
- knf::OnlineFbank fbank(fbank_opts);
+ knf::OnlineFbank fbank(fbank_opts_);
std::vector<float> buf(waves.size());
for (int32_t i = 0; i != waves.size(); ++i) {
@@ -180,7 +185,7 @@
int32_t frames = fbank.NumFramesReady();
for (int32_t i = 0; i != frames; ++i) {
const float *frame = fbank.GetFrame(i);
- std::vector<float> frame_vector(frame, frame + fbank_opts.mel_opts.num_bins);
+ std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
vad_feats.emplace_back(frame_vector);
}
}
@@ -205,7 +210,7 @@
vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
if (means_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < means_lines.size() - 1; j++) {
- means_list.push_back(stof(means_lines[j]));
+ means_list_.push_back(stof(means_lines[j]));
}
continue;
}
@@ -216,8 +221,8 @@
vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
if (vars_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < vars_lines.size() - 1; j++) {
- // vars_list.push_back(stof(vars_lines[j])*scale);
- vars_list.push_back(stof(vars_lines[j]));
+ // vars_list_.push_back(stof(vars_lines[j])*scale);
+ vars_list_.push_back(stof(vars_lines[j]));
}
continue;
}
@@ -263,8 +268,8 @@
}
// Apply cmvn
for (auto &out_feat: out_feats) {
- for (int j = 0; j < means_list.size(); j++) {
- out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j];
+ for (int j = 0; j < means_list_.size(); j++) {
+ out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
}
}
vad_feats = out_feats;
@@ -276,7 +281,7 @@
std::vector<std::vector<float>> vad_probs;
FbankKaldi(vad_sample_rate_, vad_feats, waves);
LfrCmvn(vad_feats);
- Forward(vad_feats, &vad_probs);
+ Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
E2EVadModel vad_scorer = E2EVadModel();
std::vector<std::vector<int>> vad_segments;
--
Gitblit v1.9.1