From c2dee5e3c29eba79e591d9e9caebaef15ea4e56b Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: Thu, 29 Jun 2023 11:09:28 +0800
Subject: [PATCH] Merge pull request #687 from alibaba-damo-academy/dev_lhn

---
 funasr/runtime/onnxruntime/src/fsmn-vad.cpp |   67 +++++++++++++++++++--------------
 1 file changed, 39 insertions(+), 28 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
index b1b0e63..697828b 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -6,6 +6,7 @@
 #include <fstream>
 #include "precomp.h"
 
+namespace funasr {
 void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num) {
     session_options_.SetIntraOpNumThreads(thread_num);
     session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
@@ -36,14 +37,14 @@
         this->vad_max_len_ = post_conf["max_single_segment_time"].as<int>();
         this->vad_speech_noise_thres_ = post_conf["speech_noise_thres"].as<double>();
 
-        fbank_opts.frame_opts.dither = frontend_conf["dither"].as<float>();
-        fbank_opts.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
-        fbank_opts.frame_opts.samp_freq = (float)vad_sample_rate_;
-        fbank_opts.frame_opts.window_type = frontend_conf["window"].as<string>();
-        fbank_opts.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
-        fbank_opts.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
-        fbank_opts.energy_floor = 0;
-        fbank_opts.mel_opts.debug_mel = false;
+        fbank_opts_.frame_opts.dither = frontend_conf["dither"].as<float>();
+        fbank_opts_.mel_opts.num_bins = frontend_conf["n_mels"].as<int>();
+        fbank_opts_.frame_opts.samp_freq = (float)vad_sample_rate_;
+        fbank_opts_.frame_opts.window_type = frontend_conf["window"].as<string>();
+        fbank_opts_.frame_opts.frame_shift_ms = frontend_conf["frame_shift"].as<float>();
+        fbank_opts_.frame_opts.frame_length_ms = frontend_conf["frame_length"].as<float>();
+        fbank_opts_.energy_floor = 0;
+        fbank_opts_.mel_opts.debug_mel = false;
     }catch(exception const &e){
         LOG(ERROR) << "Error when load argument from vad config YAML.";
         exit(-1);
@@ -54,6 +55,7 @@
     try {
         vad_session_ = std::make_shared<Ort::Session>(
                 env_, vad_model, session_options_);
+        LOG(INFO) << "Successfully load model from " << vad_model;
     } catch (std::exception const &e) {
         LOG(ERROR) << "Error when load vad onnx model: " << e.what();
         exit(0);
@@ -108,7 +110,9 @@
 
 void FsmnVad::Forward(
         const std::vector<std::vector<float>> &chunk_feats,
-        std::vector<std::vector<float>> *out_prob) {
+        std::vector<std::vector<float>> *out_prob,
+        std::vector<std::vector<float>> *in_cache,
+        bool is_final) {
     Ort::MemoryInfo memory_info =
             Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
 
@@ -131,9 +135,9 @@
     // 4 caches
     // cache node {batch,128,19,1}
     const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
-    for (int i = 0; i < in_cache_.size(); i++) {
+    for (int i = 0; i < in_cache->size(); i++) {
       vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
-              memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
+              memory_info, (*in_cache)[i].data(), (*in_cache)[i].size(), cache_feats_shape, 4)));
     }
   
     // 4. Onnx infer
@@ -161,21 +165,27 @@
     }
   
     // get 4 caches outputs,each size is 128*19
-    for (int i = 1; i < 5; i++) {
-      float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
-      memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
+    if(!is_final){
+        for (int i = 1; i < 5; i++) {
+        float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
+        memcpy((*in_cache)[i-1].data(), data, sizeof(float) * 128*19);
+        }
     }
 }
 
 void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
-                         const std::vector<float> &waves) {
-    knf::OnlineFbank fbank(fbank_opts);
+                         std::vector<float> &waves) {
+    knf::OnlineFbank fbank(fbank_opts_);
 
-    fbank.AcceptWaveform(sample_rate, &waves[0], waves.size());
+    std::vector<float> buf(waves.size());
+    for (int32_t i = 0; i != waves.size(); ++i) {
+        buf[i] = waves[i] * 32768;
+    }
+    fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
     int32_t frames = fbank.NumFramesReady();
     for (int32_t i = 0; i != frames; ++i) {
         const float *frame = fbank.GetFrame(i);
-        std::vector<float> frame_vector(frame, frame + fbank_opts.mel_opts.num_bins);
+        std::vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
         vad_feats.emplace_back(frame_vector);
     }
 }
@@ -200,7 +210,7 @@
                 vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
                 if (means_lines[0] == "<LearnRateCoef>") {
                     for (int j = 3; j < means_lines.size() - 1; j++) {
-                        means_list.push_back(stof(means_lines[j]));
+                        means_list_.push_back(stof(means_lines[j]));
                     }
                     continue;
                 }
@@ -211,8 +221,8 @@
                 vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
                 if (vars_lines[0] == "<LearnRateCoef>") {
                     for (int j = 3; j < vars_lines.size() - 1; j++) {
-                        // vars_list.push_back(stof(vars_lines[j])*scale);
-                        vars_list.push_back(stof(vars_lines[j]));
+                        // vars_list_.push_back(stof(vars_lines[j])*scale);
+                        vars_list_.push_back(stof(vars_lines[j]));
                     }
                     continue;
                 }
@@ -224,7 +234,7 @@
     }
 }
 
-std::vector<std::vector<float>> &FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats) {
+void FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats) {
 
     std::vector<std::vector<float>> out_feats;
     int T = vad_feats.size();
@@ -258,21 +268,20 @@
     }
     // Apply cmvn
     for (auto &out_feat: out_feats) {
-        for (int j = 0; j < means_list.size(); j++) {
-            out_feat[j] = (out_feat[j] + means_list[j]) * vars_list[j];
+        for (int j = 0; j < means_list_.size(); j++) {
+            out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
         }
     }
     vad_feats = out_feats;
-    return vad_feats;
 }
 
 std::vector<std::vector<int>>
-FsmnVad::Infer(const std::vector<float> &waves) {
+FsmnVad::Infer(std::vector<float> &waves, bool input_finished) {
     std::vector<std::vector<float>> vad_feats;
     std::vector<std::vector<float>> vad_probs;
     FbankKaldi(vad_sample_rate_, vad_feats, waves);
-    vad_feats = LfrCmvn(vad_feats);
-    Forward(vad_feats, &vad_probs);
+    LfrCmvn(vad_feats);
+    Forward(vad_feats, &vad_probs, &in_cache_, input_finished);
 
     E2EVadModel vad_scorer = E2EVadModel();
     std::vector<std::vector<int>> vad_segments;
@@ -301,3 +310,5 @@
 
 FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} {
 }
+
+} // namespace funasr

--
Gitblit v1.9.1