From 32d2b3ec153e53176da710ebcc0aba5669effd8a Mon Sep 17 00:00:00 2001
From: yhliang <429259365@qq.com>
Date: Thu, 27 Apr 2023 17:45:00 +0800
Subject: [PATCH] update m2met2 docs

---
 funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py |   44 +++++++++++++++++++++++++++++++-------------
 1 file changed, 31 insertions(+), 13 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
index 3f6c3d1..b5b3312 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/e2e_vad.py
@@ -1,3 +1,7 @@
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 from enum import Enum
 from typing import List, Tuple, Dict, Any
 
@@ -189,6 +193,11 @@
 
 
 class E2EVadModel():
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(self, vad_post_args: Dict[str, Any]):
         super(E2EVadModel, self).__init__()
         self.vad_opts = VADXOptions(**vad_post_args)
@@ -439,10 +448,9 @@
                         - 1)) / self.vad_opts.noise_frame_num_used_for_snr
 
         return frame_state
-     
 
     def __call__(self, score: np.ndarray, waveform: np.ndarray,
-                is_final: bool = False, max_end_sil: int = 800
+                is_final: bool = False, max_end_sil: int = 800, online: bool = False
                 ):
         self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
         self.waveform = waveform  # compute decibel for each frame
@@ -457,20 +465,29 @@
             segment_batch = []
             if len(self.output_data_buf) > 0:
                 for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if not self.output_data_buf[i].contain_seg_start_point:
-                        continue
-                    if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
-                        continue
-                    start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
-                    if self.output_data_buf[i].contain_seg_end_point:
-                        end_ms = self.output_data_buf[i].end_ms
-                        self.next_seg = True
-                        self.output_data_buf_offset += 1
+                    if online:
+                        if not self.output_data_buf[i].contain_seg_start_point:
+                            continue
+                        if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
+                            continue
+                        start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
+                        if self.output_data_buf[i].contain_seg_end_point:
+                            end_ms = self.output_data_buf[i].end_ms
+                            self.next_seg = True
+                            self.output_data_buf_offset += 1
+                        else:
+                            end_ms = -1
+                            self.next_seg = False
                     else:
-                        end_ms = -1
-                        self.next_seg = False
+                        if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+                            i].contain_seg_end_point):
+                            continue
+                        start_ms = self.output_data_buf[i].start_ms
+                        end_ms = self.output_data_buf[i].end_ms
+                        self.output_data_buf_offset += 1
                     segment = [start_ms, end_ms]
                     segment_batch.append(segment)
+
             if segment_batch:
                 segments.append(segment_batch)
         if is_final:
@@ -605,3 +622,4 @@
         if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
                 self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
             self.ResetDetection()
+

--
Gitblit v1.9.1