From c1d4bd297a4418ef44882079c4845cfe64ed0b21 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 27 四月 2023 19:36:32 +0800
Subject: [PATCH] docs

---
 funasr/bin/punctuation_infer_vadrealtime.py |   42 +++++++++---------------------------------
 1 files changed, 9 insertions(+), 33 deletions(-)

diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
index d6cc153..b2db1bf 100644
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -23,7 +23,7 @@
 from funasr.utils import config_argparse
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
 
 
 class Text2Punc:
@@ -69,6 +69,7 @@
             precache = "".join(cache)
         else:
             precache = ""
+            cache = []
         data = {"text": precache + text}
         result = self.preprocessor(data=data, uid="12938712838719")
         split_text = self.preprocessor.pop_split_text_data(result)
@@ -89,7 +90,7 @@
             data = {
                 "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                 "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
-                "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
+                "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
             }
             data = to_device(data, self.device)
             y, _ = self.wrapped_model(**data)
@@ -202,10 +203,8 @@
     **kwargs,
 ):
     assert check_argument_types()
-    logging.basicConfig(
-        level=log_level,
-        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
-    )
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
 
     if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
@@ -225,7 +224,7 @@
     ):
         results = []
         split_size = 10
-
+        cache_in = param_dict["cache"]
         if raw_inputs != None:
             line = raw_inputs.strip()
             key = "demo"
@@ -233,35 +232,12 @@
                 item = {'key': key, 'value': ""}
                 results.append(item)
                 return results
-            #import pdb;pdb.set_trace()
-            result, _, cache = text2punc(line, cache)
-            item = {'key': key, 'value': result, 'cache': cache}
+            result, _, cache = text2punc(line, cache_in)
+            param_dict["cache"] = cache
+            item = {'key': key, 'value': result}
             results.append(item)
             return results
 
-        for inference_text, _, _ in data_path_and_name_and_type:
-            with open(inference_text, "r", encoding="utf-8") as fin:
-                for line in fin:
-                    line = line.strip()
-                    segs = line.split("\t")
-                    if len(segs) != 2:
-                        continue
-                    key = segs[0]
-                    if len(segs[1]) == 0:
-                        continue
-                    result, _ = text2punc(segs[1])
-                    item = {'key': key, 'value': result}
-                    results.append(item)
-        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
-        if output_path != None:
-            output_file_name = "infer.out"
-            Path(output_path).mkdir(parents=True, exist_ok=True)
-            output_file_path = (Path(output_path) / output_file_name).absolute()
-            with open(output_file_path, "w", encoding="utf-8") as fout:
-                for item_i in results:
-                    key_out = item_i["key"]
-                    value_out = item_i["value"]
-                    fout.write(f"{key_out}\t{value_out}\n")
         return results
 
     return _forward

--
Gitblit v1.9.1