From f77c5803f4d61099e572be8d877b1c4a4d6087cd Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: 星期三, 10 五月 2023 12:02:06 +0800
Subject: [PATCH] Merge pull request #485 from alibaba-damo-academy/main

---
 funasr/runtime/onnxruntime/src/e2e-vad.h |   72 ++++++++++++++----------------------
 1 files changed, 28 insertions(+), 44 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/e2e-vad.h b/funasr/runtime/onnxruntime/src/e2e-vad.h
index 90f2635..5ece1f8 100644
--- a/funasr/runtime/onnxruntime/src/e2e-vad.h
+++ b/funasr/runtime/onnxruntime/src/e2e-vad.h
@@ -1,7 +1,10 @@
 /**
  * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  * MIT License  (https://opensource.org/licenses/MIT)
+ * Contributed by zhuzizyf(China Telecom).
 */
+
+#pragma once 
 
 #include <utility>
 #include <vector>
@@ -13,7 +16,7 @@
 #include <numeric>
 #include <cassert>
 
-
+namespace funasr {
 enum class VadStateMachine {
     kVadInStateStartPointNotDetected = 1,
     kVadInStateInSpeechSegment = 2,
@@ -381,10 +384,11 @@
     int max_end_sil_frame_cnt_thresh;
     float speech_noise_thres;
     std::vector<std::vector<float>> scores;
+    int idx_pre_chunk = 0;
     bool max_time_out;
     std::vector<float> decibel;
-    std::vector<float> data_buf;
-    std::vector<float> data_buf_all;
+    int data_buf_size = 0;
+    int data_buf_all_size = 0;
     std::vector<float> waveform;
 
     void AllResetDetection() {
@@ -409,10 +413,11 @@
         max_end_sil_frame_cnt_thresh = vad_opts.max_end_silence_time - vad_opts.speech_to_sil_time_thres;
         speech_noise_thres = vad_opts.speech_noise_thres;
         scores.clear();
+        idx_pre_chunk = 0;
         max_time_out = false;
         decibel.clear();
-        data_buf.clear();
-        data_buf_all.clear();
+        int data_buf_size = 0;
+        int data_buf_all_size = 0;
         waveform.clear();
         ResetDetection();
     }
@@ -432,18 +437,17 @@
     void ComputeDecibel() {
         int frame_sample_length = int(vad_opts.frame_length_ms * vad_opts.sample_rate / 1000);
         int frame_shift_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
-        if (data_buf_all.empty()) {
-            data_buf_all = waveform;
-            data_buf = data_buf_all;
+        if (data_buf_all_size == 0) {
+          data_buf_all_size = waveform.size();
+          data_buf_size = data_buf_all_size;
         } else {
-            data_buf_all.insert(data_buf_all.end(), waveform.begin(), waveform.end());
+          data_buf_all_size += waveform.size();
         }
-        for (int offset = 0; offset < waveform.size() - frame_sample_length + 1; offset += frame_shift_length) {
+        for (int offset = 0; offset + frame_sample_length -1 < waveform.size(); offset += frame_shift_length) {
             float sum = 0.0;
             for (int i = 0; i < frame_sample_length; i++) {
                 sum += waveform[offset + i] * waveform[offset + i];
             }
-//      float decibel = 10 * log10(sum + 0.000001);
             this->decibel.push_back(10 * log10(sum + 0.000001));
         }
     }
@@ -451,29 +455,16 @@
     void ComputeScores(const std::vector<std::vector<float>> &scores) {
         vad_opts.nn_eval_block_size = scores.size();
         frm_cnt += scores.size();
-        if (this->scores.empty()) {
-            this->scores = scores;  // the first calculation
-        } else {
-            this->scores.insert(this->scores.end(), scores.begin(), scores.end());
-        }
+        this->scores = scores;
     }
 
     void PopDataBufTillFrame(int frame_idx) {
       int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
-      int start_pos=-1;
-      int data_length= data_buf.size();
       while (data_buf_start_frame < frame_idx) {
-        if (data_length >= frame_sample_length) {
+        if (data_buf_size >= frame_sample_length) {
           data_buf_start_frame += 1;
-          start_pos= data_buf_start_frame* frame_sample_length;
-          data_length=data_buf_all.size()-start_pos;
-        } else {
-          break;
+          data_buf_size = data_buf_all_size - data_buf_start_frame * frame_sample_length;
         }
-      }
-      if (start_pos!=-1){
-        data_buf.resize(data_length);
-        std::copy(data_buf_all.begin() + start_pos, data_buf_all.end(), data_buf.begin());
       }
     }
 
@@ -487,9 +478,9 @@
             expected_sample_number += int(extra_sample);
         }
         if (end_point_is_sent_end) {
-            expected_sample_number = std::max(expected_sample_number, int(data_buf.size()));
+            expected_sample_number = std::max(expected_sample_number, data_buf_size);
         }
-        if (data_buf.size() < expected_sample_number) {
+        if (data_buf_size < expected_sample_number) {
             std::cout << "error in calling pop data_buf\n";
         }
         if (output_data_buf.size() == 0 || first_frm_is_start_point) {
@@ -503,27 +494,20 @@
         if (cur_seg.end_ms != start_frm * vad_opts.frame_in_ms) {
             std::cout << "warning\n";
         }
-        int out_pos = (int) cur_seg.buffer.size();
+
         int data_to_pop;
         if (end_point_is_sent_end) {
             data_to_pop = expected_sample_number;
         } else {
             data_to_pop = int(frm_cnt * vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
         }
-        if (data_to_pop > int(data_buf.size())) {
+        if (data_to_pop > data_buf_size) {
             std::cout << "VAD data_to_pop is bigger than data_buf.size()!!!\n";
-            data_to_pop = (int) data_buf.size();
-            expected_sample_number = (int) data_buf.size();
+            data_to_pop = data_buf_size;
+            expected_sample_number = data_buf_size;
         }
         cur_seg.doa = 0;
-        for (int sample_cpy_out = 0; sample_cpy_out < data_to_pop; sample_cpy_out++) {
-            cur_seg.buffer.push_back(data_buf.back());
-            out_pos++;
-        }
-        for (int sample_cpy_out = data_to_pop; sample_cpy_out < expected_sample_number; sample_cpy_out++) {
-            cur_seg.buffer.push_back(data_buf.back());
-            out_pos++;
-        }
+        
         if (cur_seg.end_ms != start_frm * vad_opts.frame_in_ms) {
             std::cout << "Something wrong with the VAD algorithm\n";
         }
@@ -619,7 +603,7 @@
         if (sil_pdf_ids.size() > 0) {
             std::vector<float> sil_pdf_scores;
             for (auto sil_pdf_id: sil_pdf_ids) {
-                sil_pdf_scores.push_back(scores[t][sil_pdf_id]);
+                sil_pdf_scores.push_back(scores[t - idx_pre_chunk][sil_pdf_id]);
             }
             sum_score = accumulate(sil_pdf_scores.begin(), sil_pdf_scores.end(), 0.0);
             noise_prob = log(sum_score) * vad_opts.speech_2_noise_ratio;
@@ -663,6 +647,7 @@
             frame_state = GetFrameState(frm_cnt - 1 - i);
             DetectOneFrame(frame_state, frm_cnt - 1 - i, false);
         }
+        idx_pre_chunk += scores.size();
         return 0;
     }
 
@@ -797,5 +782,4 @@
 
 };
 
-
-
+} // namespace funasr

--
Gitblit v1.9.1