From d6fdd1c79364668afbf171bd18586dd1d7570d20 Mon Sep 17 00:00:00 2001
From: 凌匀 <ailsa.zly@alibaba-inc.com>
Date: Thu, 16 Feb 2023 14:56:32 +0800
Subject: [PATCH] support vad streaming decoder

---
 funasr/models/e2e_vad.py |  117 ++++++++++++++++++++++++++++++++++++++++------------------
 1 file changed, 81 insertions(+), 36 deletions(-)

diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index 98504d6..8afc8db 100755
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -5,7 +5,6 @@
 from torch import nn
 import math
 from funasr.models.encoder.fsmn_encoder import FSMN
-# from checkpoint import load_checkpoint
 
 
 class VadStateMachine(Enum):
@@ -136,7 +135,7 @@
 
         self.win_size_frame = int(window_size_ms / frame_size_ms)
         self.win_sum = 0
-        self.win_state = [0 for i in range(0, self.win_size_frame)]  # 鍒濆鍖栫獥
+        self.win_state = [0] * self.win_size_frame  # initialize the window state
 
         self.cur_win_pos = 0
         self.pre_frame_state = FrameState.kFrameStateSil
@@ -151,7 +150,7 @@
     def Reset(self) -> None:
         self.cur_win_pos = 0
         self.win_sum = 0
-        self.win_state = [0 for i in range(0, self.win_size_frame)]
+        self.win_state = [0] * self.win_size_frame
         self.pre_frame_state = FrameState.kFrameStateSil
         self.cur_frame_state = FrameState.kFrameStateSil
         self.voice_last_frame_count = 0
@@ -192,8 +191,8 @@
         return int(self.frame_size_ms)
 
 
-class E2EVadModel(torch.nn.Module):
-    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]):
+class E2EVadModel(nn.Module):
+    def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], streaming=False):
         super(E2EVadModel, self).__init__()
         self.vad_opts = VADXOptions(**vad_post_args)
         self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
@@ -212,13 +211,13 @@
         self.confirmed_start_frame = -1
         self.confirmed_end_frame = -1
         self.number_end_time_detected = 0
-        self.is_callback_with_sign = False
         self.sil_frame = 0
         self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
         self.noise_average_decibel = -100.0
         self.pre_end_silence_detected = False
 
         self.output_data_buf = []
+        self.output_data_buf_offset = 0
         self.frame_probs = []
         self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
         self.speech_noise_thres = self.vad_opts.speech_noise_thres
@@ -226,10 +225,13 @@
         self.max_time_out = False
         self.decibel = []
         self.data_buf = None
+        self.data_buf_all = None
         self.waveform = None
+        self.streaming = streaming
         self.ResetDetection()
 
     def AllResetDetection(self):
+        self.encoder.cache_reset()  # reset the in_cache in self.encoder for next query or next long sentence
         self.is_final_send = False
         self.data_buf_start_frame = 0
         self.frm_cnt = 0
@@ -240,13 +242,13 @@
         self.confirmed_start_frame = -1
         self.confirmed_end_frame = -1
         self.number_end_time_detected = 0
-        self.is_callback_with_sign = False
         self.sil_frame = 0
         self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
         self.noise_average_decibel = -100.0
         self.pre_end_silence_detected = False
 
         self.output_data_buf = []
+        self.output_data_buf_offset = 0
         self.frame_probs = []
         self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
         self.speech_noise_thres = self.vad_opts.speech_noise_thres
@@ -254,6 +256,7 @@
         self.max_time_out = False
         self.decibel = []
         self.data_buf = None
+        self.data_buf_all = None
         self.waveform = None
         self.ResetDetection()
 
@@ -271,26 +274,32 @@
     def ComputeDecibel(self) -> None:
         frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
         frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
-        self.data_buf = self.waveform[0]  # 鎸囧悜self.waveform[0]
+        if self.data_buf_all is None:
+            self.data_buf_all = self.waveform[0]  # self.data_buf points to self.waveform[0]
+            self.data_buf = self.data_buf_all
+        else:
+            self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0]))
         for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
             self.decibel.append(
                 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                 0.000001))
 
-    def ComputeScores(self, feats: torch.Tensor, feats_lengths: int) -> None:
-        self.scores = self.encoder(feats)  # return B * T * D
-        self.frm_cnt = feats_lengths # frame
-        # return self.scores
+    def ComputeScores(self, feats: torch.Tensor) -> None:
+        scores = self.encoder(feats)  # return B * T * D
+        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
+        self.vad_opts.nn_eval_block_size = scores.shape[1]
+        self.frm_cnt += scores.shape[1]  # count total frames
+        if self.scores is None:
+            self.scores = scores  # the first calculation
+        else:
+            self.scores = torch.cat((self.scores, scores), dim=1)
 
     def PopDataBufTillFrame(self, frame_idx: int) -> None:  # need check again
         while self.data_buf_start_frame < frame_idx:
             if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
                 self.data_buf_start_frame += 1
-                self.data_buf = self.waveform[0][self.data_buf_start_frame * int(
+                self.data_buf = self.data_buf_all[self.data_buf_start_frame * int(
                     self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
-                # for i in range(0, int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)):
-                #     self.data_buf.popleft()
-                # self.data_buf_start_frame += 1
 
     def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
                            last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
@@ -301,8 +310,9 @@
                                self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
             expected_sample_number += int(extra_sample)
         if end_point_is_sent_end:
-            # expected_sample_number = max(expected_sample_number, len(self.data_buf))
-            pass
+            expected_sample_number = max(expected_sample_number, len(self.data_buf))
+        if len(self.data_buf) < expected_sample_number:
+            print('error in calling pop data_buf\n')
 
         if len(self.output_data_buf) == 0 or first_frm_is_start_point:
             self.output_data_buf.append(E2EVadSpeechBufWithDoa())
@@ -312,15 +322,18 @@
             self.output_data_buf[-1].doa = 0
         cur_seg = self.output_data_buf[-1]
         if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
-            print('warning')
+            print('warning\n')
         out_pos = len(cur_seg.buffer)  # cur_seg.buff鐜板湪娌″仛浠讳綍鎿嶄綔
         data_to_pop = 0
         if end_point_is_sent_end:
             data_to_pop = expected_sample_number
         else:
             data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
-        # if data_to_pop > len(self.data_buf_)
-        #   pass
+        if data_to_pop > len(self.data_buf):
+            print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
+            data_to_pop = len(self.data_buf)
+            expected_sample_number = len(self.data_buf)
+
         cur_seg.doa = 0
         for sample_cpy_out in range(0, data_to_pop):
             # cur_seg.buffer[out_pos ++] = data_buf_.back();
@@ -329,7 +342,7 @@
             # cur_seg.buffer[out_pos++] = data_buf_.back()
             out_pos += 1
         if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
-            print('warning')
+            print('Something wrong with the VAD algorithm\n')
         self.data_buf_start_frame += frm_cnt
         cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
         if first_frm_is_start_point:
@@ -346,14 +359,13 @@
 
     def OnVoiceDetected(self, valid_frame: int) -> None:
         self.latest_confirmed_speech_frame = valid_frame
-        if True:  # is_new_api_enable_ = True
-            self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
+        self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
 
     def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
         if self.vad_opts.do_start_point_detection:
             pass
         if self.confirmed_start_frame != -1:
-            print('warning')
+            print('not reset vad properly\n')
         else:
             self.confirmed_start_frame = start_frame
 
@@ -366,7 +378,7 @@
         if self.vad_opts.do_end_point_detection:
             pass
         if self.confirmed_end_frame != -1:
-            print('warning')
+            print('not reset vad properly\n')
         else:
             self.confirmed_end_frame = end_frame
         if not fake_result:
@@ -406,7 +418,6 @@
             sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
             sum_score = sum(sil_pdf_scores)
             noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
-            # total_score = sum(self.scores[0][t][:])
             total_score = 1.0
             sum_score = total_score - sum_score
         speech_prob = math.log(sum_score)
@@ -433,25 +444,59 @@
 
         return frame_state
 
-    def forward(self, feats: torch.Tensor, feats_lengths: int, waveform: torch.tensor) -> List[List[List[int]]]:
-        self.AllResetDetection()
+    def forward(self, feats: torch.Tensor, waveform: torch.tensor, is_final_send: bool = False) -> List[List[List[int]]]:
         self.waveform = waveform  # compute decibel for each frame
         self.ComputeDecibel()
-        self.ComputeScores(feats, feats_lengths)
-        assert len(self.decibel) == len(self.scores[0])  # 淇濊瘉甯ф暟涓�鑷�
-        self.DetectLastFrames()
+        self.ComputeScores(feats)
+        if not is_final_send:
+            self.DetectCommonFrames()
+        else:
+            if self.streaming:
+                self.DetectLastFrames()
+            else:
+                self.AllResetDetection()
+                self.DetectAllFrames()  # offline decode and is_final_send == True
         segments = []
         for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
             segment_batch = []
-            for i in range(0, len(self.output_data_buf)):
-                segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
-                segment_batch.append(segment)
-            segments.append(segment_batch)
+            if len(self.output_data_buf) > 0:
+                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
+                    if self.output_data_buf[i].contain_seg_start_point and self.output_data_buf[
+                        i].contain_seg_end_point:
+                        segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
+                        segment_batch.append(segment)
+                        self.output_data_buf_offset += 1  # need update this parameter
+            if segment_batch:
+                segments.append(segment_batch)
+
         return segments
 
+    def DetectCommonFrames(self) -> int:
+        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+            return 0
+        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+            frame_state = FrameState.kFrameStateInvalid
+            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+            self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+
+        return 0
+
     def DetectLastFrames(self) -> int:
         if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
             return 0
+        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+            frame_state = FrameState.kFrameStateInvalid
+            frame_state = self.GetFrameState(self.frm_cnt - 1 - i)
+            if i != 0:
+                self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+            else:
+                self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
+
+        return 0
+
+    def DetectAllFrames(self) -> int:
+        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+            return 0
         if self.vad_opts.nn_eval_block_size != self.vad_opts.dcd_block_size:
             frame_state = FrameState.kFrameStateInvalid
             for t in range(0, self.frm_cnt):

--
Gitblit v1.9.1