From e33bb15d269bb3e2e41f7a3540d9b92703bb5c50 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 15 三月 2023 10:51:52 +0800
Subject: [PATCH] Merge branch 'main' into dev_aky
---
funasr/models/e2e_vad.py | 40 +++++++++++++++++++++++++++++++++++++++-
1 files changed, 39 insertions(+), 1 deletions(-)
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index b9be89a..2c5673c 100755
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -215,6 +215,7 @@
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
+ self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
@@ -244,6 +245,7 @@
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
+ self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
@@ -441,7 +443,7 @@
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
-
+
def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
@@ -470,6 +472,42 @@
self.AllResetDetection()
return segments, in_cache
+ def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
+ is_final: bool = False
+ ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+ self.waveform = waveform # compute decibel for each frame
+ self.ComputeDecibel()
+ self.ComputeScores(feats, in_cache)
+ if not is_final:
+ self.DetectCommonFrames()
+ else:
+ self.DetectLastFrames()
+ segments = []
+ for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
+ segment_batch = []
+ if len(self.output_data_buf) > 0:
+ for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
+ if not self.output_data_buf[i].contain_seg_start_point:
+ continue
+ if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point:
+ continue
+ start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1
+ if self.output_data_buf[i].contain_seg_end_point:
+ end_ms = self.output_data_buf[i].end_ms
+ self.next_seg = True
+ self.output_data_buf_offset += 1
+ else:
+ end_ms = -1
+ self.next_seg = False
+ segment = [start_ms, end_ms]
+ segment_batch.append(segment)
+ if segment_batch:
+ segments.append(segment_batch)
+ if is_final:
+ # reset class variables and clear the dict for the next query
+ self.AllResetDetection()
+ return segments, in_cache
+
def DetectCommonFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
--
Gitblit v1.9.1