From 3d7c0c5d2d1e097ae17dc4275c75050574d1f217 Mon Sep 17 00:00:00 2001
From: mengzhe.cmz <mengzhe.cmz@alibaba-inc.com>
Date: Mon, 08 May 2023 12:44:50 +0800
Subject: [PATCH] fix vad realtime space bug
---
funasr/bin/punctuation_infer_vadrealtime.py | 10 ++++------
1 file changed, 4 insertions(+), 6 deletions(-)
diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
index 81f9d7a..01d2420 100644
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -70,7 +70,7 @@
else:
precache = ""
cache = []
- data = {"text": precache + text}
+ data = {"text": precache + " " + text}
result = self.preprocessor(data=data, uid="12938712838719")
split_text = self.preprocessor.pop_split_text_data(result)
mini_sentences = split_to_mini_sentence(split_text, split_size)
@@ -90,7 +90,7 @@
data = {
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
- "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
+ "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
}
data = to_device(data, self.device)
y, _ = self.wrapped_model(**data)
@@ -203,10 +203,8 @@
**kwargs,
):
assert check_argument_types()
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
--
Gitblit v1.9.1