From 3f15e4268a6a7753f9b3d1736d87f74d55112175 Mon Sep 17 00:00:00 2001
From: lingyunfly <121302812+lingyunfly@users.noreply.github.com>
Date: Thu, 13 Apr 2023 15:03:23 +0800
Subject: [PATCH] Update e2e_vad.py

---
 funasr/models/e2e_vad.py |   13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
old mode 100755
new mode 100644
index 2c5673c..440a049
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -192,7 +192,7 @@
 
 
 class E2EVadModel(nn.Module):
-    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
+    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
         super(E2EVadModel, self).__init__()
         self.vad_opts = VADXOptions(**vad_post_args)
         self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@@ -229,6 +229,7 @@
         self.data_buf_all = None
         self.waveform = None
         self.ResetDetection()
+        self.frontend = frontend
 
     def AllResetDetection(self):
         self.is_final = False
@@ -459,8 +460,8 @@
             segment_batch = []
             if len(self.output_data_buf) > 0:
                 for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
-                        i].contain_seg_end_point:
+                    if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+                        i].contain_seg_end_point):
                         continue
                     segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
                     segment_batch.append(segment)
@@ -473,11 +474,13 @@
         return segments, in_cache
 
     def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
-                is_final: bool = False
+                is_final: bool = False, max_end_sil: int = 800
                 ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
         self.waveform = waveform  # compute decibel for each frame
-        self.ComputeDecibel()
+        
         self.ComputeScores(feats, in_cache)
+        self.ComputeDecibel()
         if not is_final:
             self.DetectCommonFrames()
         else:

--
Gitblit v1.9.1