From 435a5906e538de4c975c7847acfd99772881e3f1 Mon Sep 17 00:00:00 2001
From: 凌匀 <ailsa.zly@alibaba-inc.com>
Date: Wed, 12 Apr 2023 17:43:29 +0800
Subject: [PATCH] Support streaming VAD in onnxruntime & bug fix
---
funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py | 35 ++++++++++++++++++++++-------------
1 file changed, 22 insertions(+), 13 deletions(-)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
index 3f6c3d1..f540765 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
@@ -439,10 +439,9 @@
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
-
def __call__(self, score: np.ndarray, waveform: np.ndarray,
- is_final: bool = False, max_end_sil: int = 800
+ is_final: bool = False, max_end_sil: int = 800, online: bool = False
):
self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
self.waveform = waveform # compute decibel for each frame
@@ -457,20 +456,29 @@
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
- if not self.output_data_buf[i].contain_seg_start_point:
- continue
- if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
- continue
- start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
- if self.output_data_buf[i].contain_seg_end_point:
- end_ms = self.output_data_buf[i].end_ms
- self.next_seg = True
- self.output_data_buf_offset += 1
+ if online:
+ if not self.output_data_buf[i].contain_seg_start_point:
+ continue
+ if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
+ continue
+ start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
+ if self.output_data_buf[i].contain_seg_end_point:
+ end_ms = self.output_data_buf[i].end_ms
+ self.next_seg = True
+ self.output_data_buf_offset += 1
+ else:
+ end_ms = -1
+ self.next_seg = False
else:
- end_ms = -1
- self.next_seg = False
+ if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+ i].contain_seg_end_point:
+ continue
+ start_ms = self.output_data_buf[i].start_ms
+ end_ms = self.output_data_buf[i].end_ms
+ self.output_data_buf_offset += 1
segment = [start_ms, end_ms]
segment_batch.append(segment)
+
if segment_batch:
segments.append(segment_batch)
if is_final:
@@ -605,3 +613,4 @@
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
self.ResetDetection()
+
--
Gitblit v1.9.1