From 5c4bd2b718d5f2935ea6609b5051c4182b8b3c50 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 24 三月 2023 13:46:34 +0800
Subject: [PATCH] Merge pull request #293 from alibaba-damo-academy/dev_zly

---
 egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py  |    2 +-
 funasr/bin/vad_inference_online.py                                  |   10 +++++++---
 funasr/models/e2e_vad.py                                            |    3 ++-
 egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py |    2 +-
 4 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py
index d70ed25..02e919d 100644
--- a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py
@@ -22,7 +22,7 @@
     sample_offset = 0
     
     step = 160 * 10
-    param_dict = {'in_cache': dict()}
+    param_dict = {'in_cache': dict(), 'max_end_sil': 800}
     for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
         if sample_offset + step >= speech_length - 1:
             step = speech_length - sample_offset
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py
index fb56908..a8cc912 100644
--- a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py
@@ -22,7 +22,7 @@
     sample_offset = 0
     
     step = 80 * 10
-    param_dict = {'in_cache': dict()}
+    param_dict = {'in_cache': dict(), 'max_end_sil': 800}
     for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
         if sample_offset + step >= speech_length - 1:
             step = speech_length - sample_offset
diff --git a/funasr/bin/vad_inference_online.py b/funasr/bin/vad_inference_online.py
index dadfd8c..f35e5a1 100644
--- a/funasr/bin/vad_inference_online.py
+++ b/funasr/bin/vad_inference_online.py
@@ -30,7 +30,8 @@
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.bin.vad_inference import Speech2VadSegment
 
-
+header_colors = '\033[95m'
+end_colors = '\033[0m'
 
 
 class Speech2VadSegmentOnline(Speech2VadSegment):
@@ -55,7 +56,7 @@
     @torch.no_grad()
     def __call__(
             self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-            in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False
+            in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
     ) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
         """Inference
 
@@ -86,7 +87,8 @@
                 "feats": feats,
                 "waveform": waveforms,
                 "in_cache": in_cache,
-                "is_final": is_final
+                "is_final": is_final,
+                "max_end_sil": max_end_sil
             }
             # a. To device
             batch = to_device(batch, device=self.device)
@@ -217,6 +219,7 @@
         vad_results = []
         batch_in_cache = param_dict['in_cache'] if param_dict is not None else dict()
         is_final = param_dict['is_final'] if param_dict is not None else False
+        max_end_sil = param_dict['max_end_sil'] if param_dict is not None else 800
         for keys, batch in loader:
             assert isinstance(batch, dict), type(batch)
             assert all(isinstance(s, str) for s in keys), keys
@@ -224,6 +227,7 @@
             assert len(keys) == _bs, f"{len(keys)} != {_bs}"
             batch['in_cache'] = batch_in_cache
             batch['is_final'] = is_final
+            batch['max_end_sil'] = max_end_sil
 
             # do vad segment
             _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
old mode 100755
new mode 100644
index 2c5673c..e6cd7c0
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -473,8 +473,9 @@
         return segments, in_cache
 
     def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
-                is_final: bool = False
+                is_final: bool = False, max_end_sil: int = 800
                 ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
         self.waveform = waveform  # compute decibel for each frame
         self.ComputeDecibel()
         self.ComputeScores(feats, in_cache)

--
Gitblit v1.9.1