From 7bb2dfba0cb98c0eaaa18b2dfbb47a647eac9d58 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 23 Feb 2023 11:49:39 +0800
Subject: [PATCH] bugfix
---
funasr/bin/vad_inference.py | 15 ++++++---------
1 file changed, 6 insertions(+), 9 deletions(-)
diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py
index 1cdb582..607f131 100644
--- a/funasr/bin/vad_inference.py
+++ b/funasr/bin/vad_inference.py
@@ -81,6 +81,7 @@
self.device = device
self.dtype = dtype
self.frontend = frontend
+ self.batch_size = batch_size
@torch.no_grad()
def __call__(
@@ -106,14 +107,11 @@
feats_len = feats_len.int()
else:
raise Exception("Need to extract feats first, please configure frontend configuration")
- # batch = {"feats": feats, "waveform": speech, "is_final_send": True}
- # segments = self.vad_model(**batch)
- # b. Forward Encoder sreaming
- segments = []
- segments_tmp = []
- step = 6000
+ # b. Forward Encoder streaming
t_offset = 0
+ step = min(feats_len, 6000)
+ segments = [[]] * self.batch_size
for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
if t_offset + step >= feats_len - 1:
step = feats_len - t_offset
@@ -129,8 +127,8 @@
batch = to_device(batch, device=self.device)
segments_part = self.vad_model(**batch)
if segments_part:
- segments_tmp += segments_part[0]
- segments.append(segments_tmp)
+ for batch_num in range(0, self.batch_size):
+ segments[batch_num] += segments_part[batch_num]
return segments
@@ -254,7 +252,6 @@
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# do vad segment
results = speech2vadsegment(**batch)
--
Gitblit v1.9.1