From 6927d0baa7bcb2c86ec5e2517cb652e98e398f97 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期三, 14 六月 2023 00:34:42 +0800
Subject: [PATCH] rename websocket client&server; fix funasr-ws-client; update readme;

---
 funasr/models/e2e_vad.py |   24 +++++++++++++++++-------
 1 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index 594c27e..14d56a8 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -226,7 +226,6 @@
                                                self.vad_opts.frame_in_ms)
         self.encoder = encoder
         # init variables
-        self.is_final = False
         self.data_buf_start_frame = 0
         self.frm_cnt = 0
         self.latest_confirmed_speech_frame = 0
@@ -253,11 +252,10 @@
         self.data_buf = None
         self.data_buf_all = None
         self.waveform = None
-        self.ResetDetection()
         self.frontend = frontend
+        self.last_drop_frames = 0
 
     def AllResetDetection(self):
-        self.is_final = False
         self.data_buf_start_frame = 0
         self.frm_cnt = 0
         self.latest_confirmed_speech_frame = 0
@@ -284,7 +282,8 @@
         self.data_buf = None
         self.data_buf_all = None
         self.waveform = None
-        self.ResetDetection()
+        self.last_drop_frames = 0
+        self.windows_detector.Reset()
 
     def ResetDetection(self):
         self.continous_silence_frame_count = 0
@@ -296,6 +295,15 @@
         self.windows_detector.Reset()
         self.sil_frame = 0
         self.frame_probs = []
+
+        if self.output_data_buf:
+            assert self.output_data_buf[-1].contain_seg_end_point == True
+            drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
+            real_drop_frames = drop_frames - self.last_drop_frames
+            self.last_drop_frames = drop_frames
+            self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+            self.decibel = self.decibel[real_drop_frames:]
+            self.scores = self.scores[:, real_drop_frames:, :]
 
     def ComputeDecibel(self) -> None:
         frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
@@ -324,7 +332,7 @@
         while self.data_buf_start_frame < frame_idx:
             if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                 self.data_buf_start_frame += 1
-                self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
+                self.data_buf = self.data_buf_all[(self.data_buf_start_frame - self.last_drop_frames) * int(
                     self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
 
     def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
@@ -473,6 +481,8 @@
     def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
                 is_final: bool = False
                 ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+        if not in_cache:
+            self.AllResetDetection()
         self.waveform = waveform  # compute decibel for each frame
         self.ComputeDecibel()
         self.ComputeScores(feats, in_cache)
@@ -543,7 +553,7 @@
             return 0
         for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
             frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+            frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
             self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
 
         return 0
@@ -553,7 +563,7 @@
             return 0
         for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
             frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+            frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
             if i != 0:
                 self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
             else:

--
Gitblit v1.9.1