From f13cfbc18e6b7e37d4e5a515cf18411aa0c56d55 Mon Sep 17 00:00:00 2001
From: 北念 <lzr265946@alibaba-inc.com>
Date: 星期二, 21 二月 2023 17:34:45 +0800
Subject: [PATCH] support hotword parameter passing in the pipeline forward

---
 funasr/bin/vad_inference.py |   15 ++++++---------
 1 files changed, 6 insertions(+), 9 deletions(-)

diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py
index 1cdb582..607f131 100644
--- a/funasr/bin/vad_inference.py
+++ b/funasr/bin/vad_inference.py
@@ -81,6 +81,7 @@
         self.device = device
         self.dtype = dtype
         self.frontend = frontend
+        self.batch_size = batch_size
 
     @torch.no_grad()
     def __call__(
@@ -106,14 +107,11 @@
             feats_len = feats_len.int()
         else:
             raise Exception("Need to extract feats first, please configure frontend configuration")
-        # batch = {"feats": feats, "waveform": speech, "is_final_send": True}
-        # segments = self.vad_model(**batch)
 
-        # b. Forward Encoder sreaming
-        segments = []
-        segments_tmp = []
-        step = 6000
+        # b. Forward Encoder streaming
         t_offset = 0
+        step = min(feats_len, 6000)
+        segments = [[]] * self.batch_size
         for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
             if t_offset + step >= feats_len - 1:
                 step = feats_len - t_offset
@@ -129,8 +127,8 @@
             batch = to_device(batch, device=self.device)
             segments_part = self.vad_model(**batch)
             if segments_part:
-                segments_tmp += segments_part[0]
-        segments.append(segments_tmp)
+                for batch_num in range(0, self.batch_size):
+                    segments[batch_num] += segments_part[batch_num]
         return segments
 
 
@@ -254,7 +252,6 @@
             assert all(isinstance(s, str) for s in keys), keys
             _bs = len(next(iter(batch.values())))
             assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
 
             # do vad segment
             results = speech2vadsegment(**batch)

--
Gitblit v1.9.1