From c542eacb0aadcbc49c63db40429fca4e08f807a4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 21 Jul 2023 10:27:35 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/runtime/onnxruntime/src/fsmn-vad.cpp | 183 ++++++++++++++++++++++++++++-----------------
 1 file changed, 112 insertions(+), 71 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
index 0f87cb2..697828b 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -1,43 +1,65 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
#include <fstream>
#include "precomp.h"
-//#include "glog/logging.h"
-
-void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, int vad_sample_rate, int vad_silence_duration, int vad_max_len,
- float vad_speech_noise_thres) {
- session_options_.SetIntraOpNumThreads(1);
+namespace funasr {
+void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num) {
+ session_options_.SetIntraOpNumThreads(thread_num);
session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
session_options_.DisableCpuMemArena();
- this->vad_sample_rate_ = vad_sample_rate;
- this->vad_silence_duration_=vad_silence_duration;
- this->vad_max_len_=vad_max_len;
- this->vad_speech_noise_thres_=vad_speech_noise_thres;
- ReadModel(vad_model);
+ ReadModel(vad_model.c_str());
LoadCmvn(vad_cmvn.c_str());
+ LoadConfigFromYaml(vad_config.c_str());
InitCache();
-
- fbank_opts.frame_opts.dither = 0;
- fbank_opts.mel_opts.num_bins = 80;
- fbank_opts.frame_opts.samp_freq = vad_sample_rate;
- fbank_opts.frame_opts.window_type = "hamming";
- fbank_opts.frame_opts.frame_shift_ms = 10;
- fbank_opts.frame_opts.frame_length_ms = 25;
- fbank_opts.energy_floor = 0;
- fbank_opts.mel_opts.debug_mel = false;
-
}
-void FsmnVad::ReadModel(const std::string &vad_model) {
+void FsmnVad::LoadConfigFromYaml(const char* filename){
+
+ YAML::Node config;
+ try{
+ config = YAML::LoadFile(filename);
+ }catch(exception const &e){
+ LOG(ERROR) << "Error loading file, yaml file error or not exist.";
+ exit(-1);
+ }
+
+ try{
+ YAML::Node frontend_conf = config["frontend_conf"];
+ YAML::Node post_conf = config["vad_post_conf"];
+
+ this->vad_sample_rate_ = frontend_conf["fs"].as<int>();
+ this->vad_silence_duration_ = post_conf["max_end_silence_time"].as<int>();
+ this->vad_max_len_ = post_conf["max_single_segment_time"].as<int>();
+ this->vad_speech_noise_thres_ = post_conf["speech_noise_thres"].as<double>();
+
+ fbank_opts_.frame_opts.dither = frontend_conf["dither"].as<float>();
+ fbank_opts_.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
+ fbank_opts_.frame_opts.samp_freq = (float)vad_sample_rate_;
+ fbank_opts_.frame_opts.window_type = frontend_conf["window"].as<string>();
+ fbank_opts_.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
+ fbank_opts_.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
+ fbank_opts_.energy_floor = 0;
+ fbank_opts_.mel_opts.debug_mel = false;
+ }catch(exception const &e){
+ LOG(ERROR) << "Error when load argument from vad config YAML.";
+ exit(-1);
+ }
+}
+
+void FsmnVad::ReadModel(const char* vad_model) {
try {
vad_session_ = std::make_shared<Ort::Session>(
- env_, vad_model.c_str(), session_options_);
+ env_, vad_model, session_options_);
+ LOG(INFO) << "Successfully load model from " << vad_model;
} catch (std::exception const &e) {
- //LOG(ERROR) << "Error when load onnx model: " << e.what();
+ LOG(ERROR) << "Error when load vad onnx model: " << e.what();
exit(0);
}
- //LOG(INFO) << "vad onnx:";
GetInputOutputInfo(vad_session_, &vad_in_names_, &vad_out_names_);
}
@@ -88,7 +110,9 @@
void FsmnVad::Forward(
const std::vector<std::vector<float>> &chunk_feats,
- std::vector<std::vector<float>> *out_prob) {
+ std::vector<std::vector<float>> *out_prob,
+ std::vector<std::vector<float>> *in_cache,
+ bool is_final) {
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
@@ -111,21 +135,20 @@
// 4 caches
// cache node {batch,128,19,1}
const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
- for (int i = 0; i < in_cache_.size(); i++) {
+ for (int i = 0; i < in_cache->size(); i++) {
vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
- memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
+ memory_info, (*in_cache)[i].data(), (*in_cache)[i].size(), cache_feats_shape, 4)));
}
// 4. Onnx infer
std::vector<Ort::Value> vad_ort_outputs;
try {
- // VLOG(3) << "Start infer";
vad_ort_outputs = vad_session_->Run(
Ort::RunOptions{nullptr}, vad_in_names_.data(), vad_inputs.data(),
vad_inputs.size(), vad_out_names_.data(), vad_out_names_.size());
} catch (std::exception const &e) {
- // LOG(ERROR) << e.what();
- return;
+ LOG(ERROR) << "Error when run vad onnx forword: " << (e.what());
+ exit(0);
}
// 5. Change infer result to output shapes
@@ -142,61 +165,76 @@
}
// get 4 caches outputs,each size is 128*19
- for (int i = 1; i < 5; i++) {
- float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
- memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
+ if(!is_final){
+ for (int i = 1; i < 5; i++) {
+ float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
+ memcpy((*in_cache)[i-1].data(), data, sizeof(float) * 128*19);
+ }
}
}
void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
- const std::vector<float> &waves) {
- knf::OnlineFbank fbank(fbank_opts);
+ std::vector<float> &waves) {
+ knf::OnlineFbank fbank(fbank_opts_);
- fbank.AcceptWaveform(sample_rate, &waves[0], waves.size());
+ std::vector<float> buf(waves.size());
+ for (int32_t i = 0; i != waves.size(); ++i) {
+ buf[i] = waves[i] * 32768;
+ }
+ fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
int32_t frames = fbank.NumFramesReady();
for (int32_t i = 0; i != frames; ++i) {
const float *frame = fbank.GetFrame(i);
- std::vector<float> frame_vector(frame, frame + fbank_opts.mel_opts.num_bins);
+ std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
vad_feats.emplace_back(frame_vector);
}
}
void FsmnVad::LoadCmvn(const char *filename)
{
- using namespace std;
- ifstream cmvn_stream(filename);
- string line;
+ try{
+ using namespace std;
+ ifstream cmvn_stream(filename);
+ if (!cmvn_stream.is_open()) {
+ LOG(ERROR) << "Failed to open file: " << filename;
+ exit(0);
+ }
+ string line;
- while (getline(cmvn_stream, line)) {
- istringstream iss(line);
- vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
- if (line_item[0] == "<AddShift>") {
- getline(cmvn_stream, line);
- istringstream means_lines_stream(line);
- vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
- if (means_lines[0] == "<LearnRateCoef>") {
- for (int j = 3; j < means_lines.size() - 1; j++) {
- means_list.push_back(stof(means_lines[j]));
+ while (getline(cmvn_stream, line)) {
+ istringstream iss(line);
+ vector<string> line_item{istream_iterator<string>{iss}, istream_iterator<string>{}};
+ if (line_item[0] == "<AddShift>") {
+ getline(cmvn_stream, line);
+ istringstream means_lines_stream(line);
+ vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
+ if (means_lines[0] == "<LearnRateCoef>") {
+ for (int j = 3; j < means_lines.size() - 1; j++) {
+ means_list_.push_back(stof(means_lines[j]));
+ }
+ continue;
}
- continue;
+ }
+ else if (line_item[0] == "<Rescale>") {
+ getline(cmvn_stream, line);
+ istringstream vars_lines_stream(line);
+ vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
+ if (vars_lines[0] == "<LearnRateCoef>") {
+ for (int j = 3; j < vars_lines.size() - 1; j++) {
+ // vars_list_.push_back(stof(vars_lines[j])*scale);
+ vars_list_.push_back(stof(vars_lines[j]));
+ }
+ continue;
+ }
}
}
- else if (line_item[0] == "<Rescale>") {
- getline(cmvn_stream, line);
- istringstream vars_lines_stream(line);
- vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
- if (vars_lines[0] == "<LearnRateCoef>") {
- for (int j = 3; j < vars_lines.size() - 1; j++) {
- // vars_list.push_back(stof(vars_lines[j])*scale);
- vars_list.push_back(stof(vars_lines[j]));
- }
- continue;
- }
- }
+ }catch(std::exception const &e) {
+ LOG(ERROR) << "Error when load vad cmvn : " << e.what();
+ exit(0);
}
}
-std::vector<std::vector<float>> &FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats, int lfr_m, int lfr_n) {
+void FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats) {
std::vector<std::vector<float>> out_feats;
int T = vad_feats.size();
@@ -230,28 +268,26 @@
}
// Apply cmvn
for (auto &out_feat: out_feats) {
- for (int j = 0; j < means_list.size(); j++) {
- out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j];
+ for (int j = 0; j < means_list_.size(); j++) {
+ out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
}
}
vad_feats = out_feats;
- return vad_feats;
}
std::vector<std::vector<int>>
-FsmnVad::Infer(const std::vector<float> &waves) {
+FsmnVad::Infer(std::vector<float> &waves, bool input_finished) {
std::vector<std::vector<float>> vad_feats;
std::vector<std::vector<float>> vad_probs;
FbankKaldi(vad_sample_rate_, vad_feats, waves);
- vad_feats = LfrCmvn(vad_feats, 5, 1);
- Forward(vad_feats, &vad_probs);
+ LfrCmvn(vad_feats);
+ Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
E2EVadModel vad_scorer = E2EVadModel();
std::vector<std::vector<int>> vad_segments;
vad_segments = vad_scorer(vad_probs, waves, true, false, vad_silence_duration_, vad_max_len_,
vad_speech_noise_thres_, vad_sample_rate_);
return vad_segments;
-
}
void FsmnVad::InitCache(){
@@ -269,5 +305,10 @@
void FsmnVad::Test() {
}
+FsmnVad::~FsmnVad() {
+}
+
FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} {
}
+
+} // namespace funasr
--
Gitblit v1.9.1