From cf2f14345aa2c4f168ee51c200b8081c748980b8 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 12 Jan 2024 00:01:25 +0800
Subject: [PATCH] funasr1.0 fsmn-vad streaming

---
 funasr/models/fsmn_vad/model.py |   58 +++++++++-------------------------------------------------
 1 file changed, 9 insertions(+), 49 deletions(-)

diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py
index f6e0488..1ed0773 100644
--- a/funasr/models/fsmn_vad/model.py
+++ b/funasr/models/fsmn_vad/model.py
@@ -333,8 +333,8 @@
                 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                 0.000001))
 
-    def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
-        scores = self.encoder(feats, in_cache).to('cpu')  # return B * T * D
+    def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
+        scores = self.encoder(feats, cache).to('cpu')  # return B * T * D
         assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
         self.vad_opts.nn_eval_block_size = scores.shape[1]
         self.frm_cnt += scores.shape[1]  # count total frames
@@ -493,14 +493,14 @@
 
         return frame_state
 
-    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
+    def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
                 is_final: bool = False
                 ):
-        if not in_cache:
+        if not cache:
             self.AllResetDetection()
         self.waveform = waveform  # compute decibel for each frame
         self.ComputeDecibel()
-        self.ComputeScores(feats, in_cache)
+        self.ComputeScores(feats, cache)
         if not is_final:
             self.DetectCommonFrames()
         else:
@@ -521,7 +521,7 @@
         if is_final:
             # reset class variables and clear the dict for the next query
             self.AllResetDetection()
-        return segments, in_cache
+        return segments, cache
 
     def generate(self,
                  data_in,
@@ -561,7 +561,7 @@
         feats = speech
         feats_len = speech_lengths.max().item()
         waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
-        in_cache = kwargs.get("in_cache", {})
+        cache = kwargs.get("cache", {})
         batch_size = kwargs.get("batch_size", 1)
         step = min(feats_len, 6000)
         segments = [[]] * batch_size
@@ -576,11 +576,11 @@
                 "feats": feats[:, t_offset:t_offset + step, :],
                 "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
                 "is_final": is_final,
-                "in_cache": in_cache
+                "cache": cache
             }
 
 
-            segments_part, in_cache = self.forward(**batch)
+            segments_part, cache = self.forward(**batch)
             if segments_part:
                 for batch_num in range(0, batch_size):
                     segments[batch_num] += segments_part[batch_num]
@@ -603,46 +603,6 @@
             results.append(result_i)
  
         return results, meta_data
-
-    def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
-                       is_final: bool = False, max_end_sil: int = 800
-                       ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
-        if not in_cache:
-            self.AllResetDetection()
-        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
-        self.waveform = waveform  # compute decibel for each frame
-
-        self.ComputeScores(feats, in_cache)
-        self.ComputeDecibel()
-        if not is_final:
-            self.DetectCommonFrames()
-        else:
-            self.DetectLastFrames()
-        segments = []
-        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
-            segment_batch = []
-            if len(self.output_data_buf) > 0:
-                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if not self.output_data_buf[i].contain_seg_start_point:
-                        continue
-                    if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
-                        continue
-                    start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
-                    if self.output_data_buf[i].contain_seg_end_point:
-                        end_ms = self.output_data_buf[i].end_ms
-                        self.next_seg = True
-                        self.output_data_buf_offset += 1
-                    else:
-                        end_ms = -1
-                        self.next_seg = False
-                    segment = [start_ms, end_ms]
-                    segment_batch.append(segment)
-            if segment_batch:
-                segments.append(segment_batch)
-        if is_final:
-            # reset class variables and clear the dict for the next query
-            self.AllResetDetection()
-        return segments, in_cache
 
     def DetectCommonFrames(self) -> int:
         if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:

--
Gitblit v1.9.1