From d2e13af377e953a39544f893acdc065a8e538096 Mon Sep 17 00:00:00 2001
From: smohan-speech <smohan@mail.ustc.edu.cn>
Date: 星期四, 11 五月 2023 17:15:02 +0800
Subject: [PATCH] update run.sh and readme

---
 funasr/bin/punctuation_infer_vadrealtime.py |   12 +++++-------
 1 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
index 81f9d7a..0dc01f5 100644
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -61,7 +61,7 @@
             text_name="text",
             non_linguistic_symbols=train_args.non_linguistic_symbols,
         )
-        print("start decoding!!!")
+        
 
     @torch.no_grad()
     def __call__(self, text: Union[list, str], cache: list, split_size=20):
@@ -70,7 +70,7 @@
         else:
             precache = ""
             cache = []
-        data = {"text": precache + text}
+        data = {"text": precache + " " + text}
         result = self.preprocessor(data=data, uid="12938712838719")
         split_text = self.preprocessor.pop_split_text_data(result)
         mini_sentences = split_to_mini_sentence(split_text, split_size)
@@ -90,7 +90,7 @@
             data = {
                 "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                 "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
-                "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
+                "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
             }
             data = to_device(data, self.device)
             y, _ = self.wrapped_model(**data)
@@ -203,10 +203,8 @@
     **kwargs,
 ):
     assert check_argument_types()
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
 
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"

--
Gitblit v1.9.1