From 219c2482ab755fbd4e49dfbdee91bf1a8a4ec49a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 19 五月 2023 11:33:27 +0800
Subject: [PATCH] websocket 2pass bugfix

---
 funasr/runtime/onnxruntime/src/fsmn-vad.cpp |   26 ++++++++++++++++----------
 1 files changed, 16 insertions(+), 10 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
index b1b0e63..516dc88 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad.cpp
@@ -6,6 +6,7 @@
 #include <fstream>
 #include "precomp.h"
 
+namespace funasr {
 void FsmnVad::InitVad(const std::string &vad_model, const std::string &vad_cmvn, const std::string &vad_config, int thread_num) {
     session_options_.SetIntraOpNumThreads(thread_num);
     session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
@@ -161,17 +162,21 @@
     }
   
     // get 4 caches outputs,each size is 128*19
-    for (int i = 1; i < 5; i++) {
-      float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
-      memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
-    }
+    // for (int i = 1; i < 5; i++) {
+    //   float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
+    //   memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
+    // }
 }
 
 void FsmnVad::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &vad_feats,
-                         const std::vector<float> &waves) {
+                         std::vector<float> &waves) {
     knf::OnlineFbank fbank(fbank_opts);
 
-    fbank.AcceptWaveform(sample_rate, &waves[0], waves.size());
+    std::vector<float> buf(waves.size());
+    for (int32_t i = 0; i != waves.size(); ++i) {
+        buf[i] = waves[i] * 32768;
+    }
+    fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
     int32_t frames = fbank.NumFramesReady();
     for (int32_t i = 0; i != frames; ++i) {
         const float *frame = fbank.GetFrame(i);
@@ -224,7 +229,7 @@
     }
 }
 
-std::vector<std::vector<float>> &FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats) {
+void FsmnVad::LfrCmvn(std::vector<std::vector<float>> &vad_feats) {
 
     std::vector<std::vector<float>> out_feats;
     int T = vad_feats.size();
@@ -263,15 +268,14 @@
         }
     }
     vad_feats = out_feats;
-    return vad_feats;
 }
 
 std::vector<std::vector<int>>
-FsmnVad::Infer(const std::vector<float> &waves) {
+FsmnVad::Infer(std::vector<float> &waves, bool input_finished) {
     std::vector<std::vector<float>> vad_feats;
     std::vector<std::vector<float>> vad_probs;
     FbankKaldi(vad_sample_rate_, vad_feats, waves);
-    vad_feats = LfrCmvn(vad_feats);
+    LfrCmvn(vad_feats);
     Forward(vad_feats, &vad_probs);
 
     E2EVadModel vad_scorer = E2EVadModel();
@@ -301,3 +305,5 @@
 
 FsmnVad::FsmnVad():env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options_{} {
 }
+
+} // namespace funasr

--
Gitblit v1.9.1