From 91027ddab49e5791fc42569b4db9dafca55735e6 Mon Sep 17 00:00:00 2001
From: 凌匀 <ailsa.zly@alibaba-inc.com>
Date: 星期四, 16 二月 2023 22:11:18 +0800
Subject: [PATCH] fix vad results bug
---
funasr/bin/vad_inference.py | 11 +++++------
1 files changed, 5 insertions(+), 6 deletions(-)
diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py
index 1cdb582..b0f8a77 100644
--- a/funasr/bin/vad_inference.py
+++ b/funasr/bin/vad_inference.py
@@ -81,6 +81,7 @@
self.device = device
self.dtype = dtype
self.frontend = frontend
+ self.batch_size = batch_size
@torch.no_grad()
def __call__(
@@ -110,10 +111,9 @@
# segments = self.vad_model(**batch)
# b. Forward Encoder sreaming
- segments = []
- segments_tmp = []
- step = 6000
t_offset = 0
+ step = min(feats_len, 6000)
+ segments = [[]] * self.batch_size
for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
if t_offset + step >= feats_len - 1:
step = feats_len - t_offset
@@ -129,8 +129,8 @@
batch = to_device(batch, device=self.device)
segments_part = self.vad_model(**batch)
if segments_part:
- segments_tmp += segments_part[0]
- segments.append(segments_tmp)
+ for batch_num in range(0, self.batch_size):
+ segments[batch_num] += segments_part[batch_num]
return segments
@@ -254,7 +254,6 @@
assert all(isinstance(s, str) for s in keys), keys
_bs = len(next(iter(batch.values())))
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
# do vad segment
results = speech2vadsegment(**batch)
--
Gitblit v1.9.1