From 00f5ea6244384b2338b99984a55bf3f9e08dcc9c Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 二月 2023 10:32:12 +0800
Subject: [PATCH] Merge pull request #129 from alibaba-damo-academy/dev_zly
---
funasr/models/e2e_vad.py | 24 ++++--------------------
1 files changed, 4 insertions(+), 20 deletions(-)
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index 8afc8db..b64c677 100755
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -192,7 +192,7 @@
class E2EVadModel(nn.Module):
- def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], streaming=False):
+ def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
super(E2EVadModel, self).__init__()
self.vad_opts = VADXOptions(**vad_post_args)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@@ -227,7 +227,6 @@
self.data_buf = None
self.data_buf_all = None
self.waveform = None
- self.streaming = streaming
self.ResetDetection()
def AllResetDetection(self):
@@ -451,11 +450,7 @@
if not is_final_send:
self.DetectCommonFrames()
else:
- if self.streaming:
- self.DetectLastFrames()
- else:
- self.AllResetDetection()
- self.DetectAllFrames() # offline decode and is_final_send == True
+ self.DetectLastFrames()
segments = []
for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
segment_batch = []
@@ -468,7 +463,8 @@
self.output_data_buf_offset += 1 # need update this parameter
if segment_batch:
segments.append(segment_batch)
-
+ if is_final_send:
+ self.AllResetDetection()
return segments
def DetectCommonFrames(self) -> int:
@@ -492,18 +488,6 @@
else:
self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
- return 0
-
- def DetectAllFrames(self) -> int:
- if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
- return 0
- if self.vad_opts.nn_eval_block_size != self.vad_opts.dcd_block_size:
- frame_state = FrameState.kFrameStateInvalid
- for t in range(0, self.frm_cnt):
- frame_state = self.GetFrameState(t)
- self.DetectOneFrame(frame_state, t, t == self.frm_cnt - 1)
- else:
- pass
return 0
def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
--
Gitblit v1.9.1