From 435a5906e538de4c975c7847acfd99772881e3f1 Mon Sep 17 00:00:00 2001
From: 凌匀 <ailsa.zly@alibaba-inc.com>
Date: 星期三, 12 四月 2023 17:43:29 +0800
Subject: [PATCH] support onnxruntime of streaming vad & bug fix

---
 funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py |   35 ++++++++++++++++++++++-------------
 1 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
index 3f6c3d1..f540765 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
@@ -439,10 +439,9 @@
                         - 1)) / self.vad_opts.noise_frame_num_used_for_snr
 
         return frame_state
-     
 
     def __call__(self, score: np.ndarray, waveform: np.ndarray,
-                is_final: bool = False, max_end_sil: int = 800
+                is_final: bool = False, max_end_sil: int = 800, online: bool = False
                 ):
         self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
         self.waveform = waveform  # compute decibel for each frame
@@ -457,20 +456,29 @@
             segment_batch = []
             if len(self.output_data_buf) > 0:
                 for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if not self.output_data_buf[i].contain_seg_start_point:
-                        continue
-                    if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
-                        continue
-                    start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
-                    if self.output_data_buf[i].contain_seg_end_point:
-                        end_ms = self.output_data_buf[i].end_ms
-                        self.next_seg = True
-                        self.output_data_buf_offset += 1
+                    if online:
+                        if not self.output_data_buf[i].contain_seg_start_point:
+                            continue
+                        if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
+                            continue
+                        start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
+                        if self.output_data_buf[i].contain_seg_end_point:
+                            end_ms = self.output_data_buf[i].end_ms
+                            self.next_seg = True
+                            self.output_data_buf_offset += 1
+                        else:
+                            end_ms = -1
+                            self.next_seg = False
                     else:
-                        end_ms = -1
-                        self.next_seg = False
+                        if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+                            i].contain_seg_end_point:
+                            continue
+                        start_ms = self.output_data_buf[i].start_ms
+                        end_ms = self.output_data_buf[i].end_ms
+                        self.output_data_buf_offset += 1
                     segment = [start_ms, end_ms]
                     segment_batch.append(segment)
+
             if segment_batch:
                 segments.append(segment_batch)
         if is_final:
@@ -605,3 +613,4 @@
         if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
                 self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
             self.ResetDetection()
+

--
Gitblit v1.9.1