From 00f5ea6244384b2338b99984a55bf3f9e08dcc9c Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 二月 2023 10:32:12 +0800
Subject: [PATCH] Merge pull request #129 from alibaba-damo-academy/dev_zly

---
 funasr/models/e2e_vad.py |   24 ++++--------------------
 1 files changed, 4 insertions(+), 20 deletions(-)

diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index 8afc8db..b64c677 100755
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -192,7 +192,7 @@
 
 
 class E2EVadModel(nn.Module):
-    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], streaming=False):
+    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
         super(E2EVadModel, self).__init__()
         self.vad_opts = VADXOptions(**vad_post_args)
         self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@@ -227,7 +227,6 @@
         self.data_buf = None
         self.data_buf_all = None
         self.waveform = None
-        self.streaming = streaming
         self.ResetDetection()
 
     def AllResetDetection(self):
@@ -451,11 +450,7 @@
         if not is_final_send:
             self.DetectCommonFrames()
         else:
-            if self.streaming:
-                self.DetectLastFrames()
-            else:
-                self.AllResetDetection()
-                self.DetectAllFrames()  # offline decode and is_final_send == True
+            self.DetectLastFrames()
         segments = []
         for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
             segment_batch = []
@@ -468,7 +463,8 @@
                         self.output_data_buf_offset += 1  # need update this parameter
             if segment_batch:
                 segments.append(segment_batch)
-
+        if is_final_send:
+            self.AllResetDetection() 
         return segments
 
     def DetectCommonFrames(self) -> int:
@@ -492,18 +488,6 @@
             else:
                 self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
 
-        return 0
-
-    def DetectAllFrames(self) -> int:
-        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
-            return 0
-        if self.vad_opts.nn_eval_block_size != self.vad_opts.dcd_block_size:
-            frame_state = FrameState.kFrameStateInvalid
-            for t in range(0, self.frm_cnt):
-                frame_state = self.GetFrameState(t)
-                self.DetectOneFrame(frame_state, t, t == self.frm_cnt - 1)
-        else:
-            pass
         return 0
 
     def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:

--
Gitblit v1.9.1