From 0bb5d87d1ee98289bbe241e1f2caf1ab8e64c69c Mon Sep 17 00:00:00 2001
From: lyblsgo <lyblsgo@163.com>
Date: Sat, 22 Apr 2023 20:34:07 +0800
Subject: [PATCH] Merge branch 'dev_knf' of https://github.com/alibaba-damo-academy/FunASR into dev_knf

---
 funasr/runtime/onnxruntime/src/OnlineFeature.h  |   59 ++++++++++++++
 funasr/runtime/onnxruntime/src/FsmnVad.h        |    4 +
 funasr/runtime/onnxruntime/src/FsmnVad.cc       |   34 ++++++--
 funasr/runtime/onnxruntime/src/OnlineFeature.cc |  133 +++++++++++++++++++++++++++++++++
 4 files changed, 221 insertions(+), 9 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/FsmnVad.cc b/funasr/runtime/onnxruntime/src/FsmnVad.cc
index f75ead7..de63225 100644
--- a/funasr/runtime/onnxruntime/src/FsmnVad.cc
+++ b/funasr/runtime/onnxruntime/src/FsmnVad.cc
@@ -18,6 +18,7 @@
 
     read_model(vad_model);
     load_cmvn(vad_cmvn.c_str());
+    init_cache();
 
     fbank_opts.frame_opts.dither = 0;
     fbank_opts.mel_opts.num_bins = 80;
@@ -105,20 +106,18 @@
     }
     Ort::Value vad_feats_ort = Ort::Value::CreateTensor<float>(
             memory_info, vad_feats.data(), vad_feats.size(), vad_feats_shape, 3);
-    // cache node {batch,128,19,1}
-    const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
-    std::vector<float> cache_feats(128 * 19 * 1, 0);
-    Ort::Value cache_feats_ort = Ort::Value::CreateTensor<float>(
-            memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4);
-
+    
     // 3. Put nodes into onnx input vector
     std::vector<Ort::Value> vad_inputs;
     vad_inputs.emplace_back(std::move(vad_feats_ort));
     // 4 caches
-    for (int i = 0; i < 4; i++) {
-        vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
-                memory_info, cache_feats.data(), cache_feats.size(), cache_feats_shape, 4)));
+    // cache node {batch,128,19,1}
+    const int64_t cache_feats_shape[4] = {1, 128, 19, 1};
+    for (int i = 0; i < in_cache_.size(); i++) {
+      vad_inputs.emplace_back(std::move(Ort::Value::CreateTensor<float>(
+              memory_info, in_cache_[i].data(), in_cache_[i].size(), cache_feats_shape, 4)));
     }
+  
     // 4. Onnx infer
     std::vector<Ort::Value> vad_ort_outputs;
     try {
@@ -142,6 +141,12 @@
         (*out_prob)[i].resize(output_dim);
         memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
                sizeof(float) * output_dim);
+    }
+  
+    // get 4 caches outputs,each size is 128*19
+    for (int i = 1; i < 5; i++) {
+      float* data = vad_ort_outputs[i].GetTensorMutableData<float>();
+      memcpy(in_cache_[i-1].data(), data, sizeof(float) * 128*19);
     }
 }
 
@@ -252,6 +257,17 @@
 
 }
 
+// Re-creates the four VAD cache buffers, each holding 128 * 19 zero-initialized floats.
+// assign() replaces the previous contents, so repeated calls never leave more than four
+// buffers behind (emplace_back in a loop appended four more on every call).
+void FsmnVad::init_cache() {
+  in_cache_.assign(4, std::vector<float>(128 * 19 * 1, 0.0f));
+}
+
+// Restores the cache state to its initial all-zero value.
+void FsmnVad::Reset() {
+  init_cache();
+}
 
 void FsmnVad::test() {
 
diff --git a/funasr/runtime/onnxruntime/src/FsmnVad.h b/funasr/runtime/onnxruntime/src/FsmnVad.h
index d7ec554..78302ae 100644
--- a/funasr/runtime/onnxruntime/src/FsmnVad.h
+++ b/funasr/runtime/onnxruntime/src/FsmnVad.h
@@ -16,6 +16,7 @@
                   float vad_speech_noise_thres);
 
     std::vector<std::vector<int>> infer(const std::vector<float> &waves);
+    void Reset();
 
 private:
 
@@ -35,12 +36,15 @@
             std::vector<std::vector<float>> *out_prob);
 
     void load_cmvn(const char *filename);
+    void init_cache();
 
     std::shared_ptr<Ort::Session> vad_session_ = nullptr;
     Ort::Env env_;
     Ort::SessionOptions session_options_;
     std::vector<const char *> vad_in_names_;
     std::vector<const char *> vad_out_names_;
+    std::vector<std::vector<float>> in_cache_;
+    
     knf::FbankOptions fbank_opts;
     std::vector<float> means_list;
     std::vector<float> vars_list;
diff --git a/funasr/runtime/onnxruntime/src/OnlineFeature.cc b/funasr/runtime/onnxruntime/src/OnlineFeature.cc
new file mode 100644
index 0000000..a2bbafd
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/OnlineFeature.cc
@@ -0,0 +1,133 @@
+//
+// Created by zhuzizyf(China Telecom Shanghai) on 4/22/23.
+//
+
+#include "OnlineFeature.h"
+
+#include <utility>
+
+// Stores the configuration and derives the frame length/shift in samples for 25 ms / 10 ms framing
+// (NOTE(review): presumably fbank_opts uses the same framing - confirm against the caller).
+OnlineFeature::OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m, int lfr_n,
+                             std::vector<std::vector<float>> cmvns)
+  : fbank_opts_(std::move(fbank_opts)),  // initializers follow the member declaration order
+    cmvns_(std::move(cmvns)),
+    sample_rate_(sample_rate),
+    frame_sample_length_(sample_rate * 25 / 1000),  // multiply first: no early truncation for e.g. 22050 Hz
+    frame_shift_sample_length_(sample_rate * 10 / 1000),
+    lfr_m_(lfr_m), lfr_n_(lfr_n) {}
+
+// Feeds one chunk of audio through fbank, LFR splicing and CMVN; the result is left in vad_feats.
+// waves is taken by value, so every edit made to it below stays local to this call.
+void OnlineFeature::extractFeats(vector<std::vector<float>> &vad_feats,
+                                 vector<float> waves, bool input_finished) {
+  input_finished_ = input_finished;
+  onlineFbank(vad_feats, waves);  // appends the fbank frames of this chunk to vad_feats
+  // Update the waveform/frame caches, then apply online LFR and CMVN.
+  if (vad_feats.size() > 0) {
+    if (!reserve_waveforms_.empty()) {
+      waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
+    }
+    if (lfr_splice_cache_.empty()) {  // nothing cached yet: seed with (lfr_m_ - 1) / 2 copies of frame 0
+      for (int i = 0; i < (lfr_m_ - 1) / 2; i++) {
+        lfr_splice_cache_.emplace_back(vad_feats[0]);
+      }
+    }
+    if (vad_feats.size() + lfr_splice_cache_.size() >= lfr_m_) {  // enough frames for one LFR window
+      vad_feats.insert(vad_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
+      int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
+      int minus_frame = reserve_waveforms_.empty() ? (lfr_m_ - 1) / 2 : 0;
+      int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats);  // replaces vad_feats with the spliced, normalized rows
+      int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
+      reserve_waveforms_.clear();
+      reserve_waveforms_.insert(reserve_waveforms_.begin(),
+                                waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
+                                waves.begin() + frame_from_waves * frame_shift_sample_length_);
+      int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
+      waves.erase(waves.begin() + sample_length, waves.end());  // trims the local copy only
+    } else {  // too few frames: cache them; NOTE(review): vad_feats still holds the raw fbank frames here - confirm intended
+      reserve_waveforms_.clear();
+      reserve_waveforms_.insert(reserve_waveforms_.begin(),
+                                waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
+      lfr_splice_cache_.insert(lfr_splice_cache_.end(), vad_feats.begin(), vad_feats.end());
+    }
+  } else {
+    if (input_finished_) {  // no new frame on the last chunk: emit what is still cached, then reset
+      if (!reserve_waveforms_.empty()) {
+        waves = reserve_waveforms_;  // assigns the local copy only
+      }
+      vad_feats = lfr_splice_cache_;
+      OnlineLfrCmvn(vad_feats);
+      reset_cache();
+    }
+  }
+}
+
+// Low-frame-rate splicing + CMVN: every lfr_n_ frames, lfr_m_ consecutive frames are concatenated
+// into one output row; vad_feats is replaced by the rows. Unconsumed frames go to lfr_splice_cache_.
+// Uses input_finished_: on the final chunk short windows are padded instead of deferred.
+// Returns the index (into the input frames) of the first frame kept in the cache.
+int OnlineFeature::OnlineLfrCmvn(vector<vector<float>> &vad_feats) {
+  vector<vector<float>> out_feats;
+  const int T = vad_feats.size();
+  // Divide in floating point: with integer division ceil() did nothing and a final partial step was lost.
+  const int T_lrf = ceil((T - (lfr_m_ - 1) / 2) / static_cast<float>(lfr_n_));
+  // Window index at which splicing stopped; all T_lrf windows unless the loop breaks early.
+  int lfr_splice_frame_idxs = T_lrf;
+
+  for (int i = 0; i < T_lrf; i++) {
+    const int start = i * lfr_n_;
+    if (T - start < lfr_m_ && !input_finished_) {
+      // Not enough frames for a full window yet; keep them for the next call.
+      lfr_splice_frame_idxs = i;
+      break;
+    }
+    // A fresh buffer per window: reusing one buffer without clearing it made padded windows grow.
+    vector<float> p;
+    for (int j = 0; j < lfr_m_; j++) {
+      // On the final chunk a short window is padded by repeating the last frame.
+      const vector<float> &frame = vad_feats[std::min(start + j, T - 1)];
+      p.insert(p.end(), frame.begin(), frame.end());
+    }
+    out_feats.emplace_back(std::move(p));
+  }
+
+  // Never let the index go negative, so the iterator below stays valid for very short or empty input.
+  lfr_splice_frame_idxs = std::max(0, std::min(T - 1, lfr_splice_frame_idxs * lfr_n_));
+  lfr_splice_cache_.assign(vad_feats.begin() + lfr_splice_frame_idxs, vad_feats.end());
+
+  // Apply cmvn: add cmvns_[0], then scale by cmvns_[1].
+  for (auto &out_feat: out_feats) {
+    for (size_t j = 0; j < cmvns_[0].size(); j++) {
+      out_feat[j] = (out_feat[j] + cmvns_[0][j]) * cmvns_[1][j];
+    }
+  }
+  vad_feats = std::move(out_feats);
+  return lfr_splice_frame_idxs;
+}
+
+// Appends the fbank frames of this chunk (plus samples cached by the previous call) to vad_feats.
+// If at least one frame is produced, waves is trimmed to the samples those frames cover.
+void OnlineFeature::onlineFbank(vector<std::vector<float>> &vad_feats,
+                                vector<float> &waves) {
+  // Prepend the samples left over from the previous call.
+  waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
+  const int num_frames = compute_frame_num(waves.size(), frame_sample_length_, frame_shift_sample_length_);
+  // Samples after the last frame-shift position are carried over to the next call.
+  input_cache_.assign(waves.begin() + num_frames * frame_shift_sample_length_, waves.end());
+  if (num_frames == 0) {
+    return;
+  }
+  // Drop the samples that no complete frame covers.
+  const int used_samples = (num_frames - 1) * frame_shift_sample_length_ + frame_sample_length_;
+  waves.erase(waves.begin() + used_samples, waves.end());
+  // The extractor is built only once there is something to process; a new one is used per chunk.
+  knf::OnlineFbank fbank(fbank_opts_);
+  fbank.AcceptWaveform(sample_rate_, waves.data(), waves.size());
+  const int32_t ready = fbank.NumFramesReady();
+  const int32_t num_bins = fbank_opts_.mel_opts.num_bins;
+  for (int32_t idx = 0; idx < ready; ++idx) {
+    const float *row = fbank.GetFrame(idx);
+    // Build the row in place from the num_bins values of this frame.
+    vad_feats.emplace_back(row, row + num_bins);
+  }
+}
diff --git a/funasr/runtime/onnxruntime/src/OnlineFeature.h b/funasr/runtime/onnxruntime/src/OnlineFeature.h
new file mode 100644
index 0000000..bd613ab
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/OnlineFeature.h
@@ -0,0 +1,59 @@
+//
+// Created by zhuzizyf(China Telecom Shanghai) on 4/22/23.
+//
+
+
+#pragma once
+
+#include "kaldi-native-fbank/csrc/feature-fbank.h"
+#include "kaldi-native-fbank/csrc/online-feature.h"
+#include <vector>
+
+// NOTE(review): a header-scope using-directive leaks into every includer; kept because OnlineFeature.cc relies on it.
+using namespace std;
+
+// Streaming feature front end (fbank, LFR splicing, CMVN) that carries caches between calls.
+class OnlineFeature {
+public:
+  OnlineFeature(int sample_rate, knf::FbankOptions fbank_opts, int lfr_m, int lfr_n,
+                std::vector<std::vector<float>> cmvns);
+
+  // Processes one chunk; input_finished marks the last chunk (short windows are then padded, not deferred).
+  void extractFeats(vector<vector<float>> &vad_feats, vector<float> waves, bool input_finished);
+
+private:
+  void onlineFbank(vector<vector<float>> &vad_feats, vector<float> &waves);
+
+  int OnlineLfrCmvn(vector<vector<float>> &vad_feats);
+
+  // Number of frames of frame_sample_length that fit in sample_length samples; 0 if not even one fits.
+  static int compute_frame_num(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
+    const int frame_num = (sample_length - frame_sample_length) / frame_shift_sample_length + 1;
+    return (frame_num >= 1 && sample_length >= frame_sample_length) ? frame_num : 0;
+  }
+
+  // Forgets all state carried between extractFeats() calls.
+  void reset_cache() {
+    reserve_waveforms_.clear();
+    input_cache_.clear();
+    lfr_splice_cache_.clear();
+    input_finished_ = false;
+  }
+
+  // Members are initialized in declaration order; the frame lengths read sample_rate_, so they must follow it.
+  knf::FbankOptions fbank_opts_;
+  // Waveform samples reserved by extractFeats() between calls
+  std::vector<float> reserve_waveforms_;
+  // Samples after the last frame-shift position, not yet turned into a frame
+  std::vector<float> input_cache_;
+  // Frames held back for the next LFR splice
+  std::vector<std::vector<float>> lfr_splice_cache_;
+  std::vector<std::vector<float>> cmvns_;  // [0] is added to each value, [1] multiplies the sum
+
+  int sample_rate_ = 16000;
+  int frame_sample_length_ = sample_rate_ * 25 / 1000;
+  int frame_shift_sample_length_ = sample_rate_ * 10 / 1000;
+  int lfr_m_;
+  int lfr_n_;
+  bool input_finished_ = false;
+};

--
Gitblit v1.9.1