From 9ddfac27d2a9ec6b136ab92539f5e786647def8f Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Fri, 21 Apr 2023 21:46:06 +0800
Subject: [PATCH] Merge pull request #397 from zhuzizyf/patch-1
---
funasr/runtime/onnxruntime/src/e2e_vad.h | 62 ++++++++++++++++++++-----------
1 file changed, 40 insertions(+), 22 deletions(-)
diff --git a/funasr/runtime/onnxruntime/src/e2e_vad.h b/funasr/runtime/onnxruntime/src/e2e_vad.h
index f0c4975..e029dc3 100644
--- a/funasr/runtime/onnxruntime/src/e2e_vad.h
+++ b/funasr/runtime/onnxruntime/src/e2e_vad.h
@@ -294,8 +294,8 @@
std::vector<std::vector<int>>
operator()(const std::vector<std::vector<float>> &score, const std::vector<float> &waveform, bool is_final = false,
- int max_end_sil = 800, int max_single_segment_time = 15000, float speech_noise_thres = 0.9,
- int sample_rate = 16000) {
+ bool online = false, int max_end_sil = 800, int max_single_segment_time = 15000,
+ float speech_noise_thres = 0.8, int sample_rate = 16000) {
max_end_sil_frame_cnt_thresh = max_end_sil - vad_opts.speech_to_sil_time_thres;
this->waveform = waveform;
this->vad_opts.max_single_segment_time = max_single_segment_time;
@@ -309,33 +309,44 @@
} else {
DetectLastFrames();
}
- // std::vector<std::vector<int>> segments;
- // for (size_t batch_num = 0; batch_num < score.size(); batch_num++) {
+
std::vector<std::vector<int>> segment_batch;
if (output_data_buf.size() > 0) {
for (size_t i = output_data_buf_offset; i < output_data_buf.size(); i++) {
+ int start_ms;
+ int end_ms;
+ if (online) {
+
if (!output_data_buf[i].contain_seg_start_point) {
- continue;
+ continue;
}
if (!next_seg && !output_data_buf[i].contain_seg_end_point) {
- continue;
+ continue;
}
- int start_ms = next_seg ? output_data_buf[i].start_ms : -1;
- int end_ms;
+ start_ms = next_seg ? output_data_buf[i].start_ms : -1;
+
if (output_data_buf[i].contain_seg_end_point) {
- end_ms = output_data_buf[i].end_ms;
- next_seg = true;
- output_data_buf_offset += 1;
+ end_ms = output_data_buf[i].end_ms;
+ next_seg = true;
+ output_data_buf_offset += 1;
} else {
- end_ms = -1;
- next_seg = false;
+ end_ms = -1;
+ next_seg = false;
}
+ } else {
+ if (!is_final &&
+ (!output_data_buf[i].contain_seg_start_point || !output_data_buf[i].contain_seg_end_point)) {
+ continue;
+ }
+ start_ms = output_data_buf[i].start_ms;
+ end_ms = output_data_buf[i].end_ms;
+ output_data_buf_offset += 1;
+ }
std::vector<int> segment = {start_ms, end_ms};
segment_batch.push_back(segment);
}
}
- // }
if (is_final) {
AllResetDetection();
}
@@ -444,15 +455,22 @@
}
void PopDataBufTillFrame(int frame_idx) {
- while (data_buf_start_frame < frame_idx) {
- int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000);
- if (data_buf.size() >= frame_sample_length) {
- data_buf_start_frame += 1;
- data_buf.erase(data_buf.begin(), data_buf.begin() + frame_sample_length);
- } else {
- break;
- }
+ int frame_sample_length = int(vad_opts.frame_in_ms * vad_opts.sample_rate / 1000); // samples per frame; now computed once instead of on every loop iteration
+ int start_pos=-1; // -1 means no frame has been popped; otherwise the sample offset into data_buf_all where the kept data starts
+ int data_length= data_buf.size(); // remaining sample count; seeded from data_buf, later recomputed from data_buf_all
+ while (data_buf_start_frame < frame_idx) { // advance data_buf_start_frame one frame at a time until it reaches frame_idx
+ if (data_length >= frame_sample_length) { // only pop while at least one whole frame remains
+ data_buf_start_frame += 1;
+ start_pos= data_buf_start_frame* frame_sample_length; // offset derived from the frame counter and used to index data_buf_all
+ data_length=data_buf_all.size()-start_pos; // NOTE(review): unsigned size minus int narrowed to int; assumes data_buf_all.size() >= start_pos — confirm
+ } else {
+ break; // less than one frame left, stop popping
}
+ }
+ if (start_pos!=-1){ // at least one frame was popped in the loop above
+ data_buf.resize(data_length); // single resize + copy replaces the previous per-frame erase() from the front of data_buf
+ std::copy(data_buf_all.begin() + start_pos, data_buf_all.end(), data_buf.begin()); // data_buf becomes the tail of data_buf_all starting at start_pos
+ }
}
void PopDataToOutputBuf(int start_frm, int frm_cnt, bool first_frm_is_start_point, bool last_frm_is_end_point,
--
Gitblit v1.9.1