From 1d1ef01b4e23630a99a3be7e9d1dce9550a793e9 Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: Thu, 11 May 2023 16:26:24 +0800
Subject: [PATCH] Merge branch 'main' into dev_smohan
---
funasr/bin/punctuation_infer_vadrealtime.py | 46 +++++++++++-----------------------------------
 1 file changed, 11 insertions(+), 35 deletions(-)
diff --git a/funasr/bin/punctuation_infer_vadrealtime.py b/funasr/bin/punctuation_infer_vadrealtime.py
index d6cc153..0dc01f5 100644
--- a/funasr/bin/punctuation_infer_vadrealtime.py
+++ b/funasr/bin/punctuation_infer_vadrealtime.py
@@ -23,7 +23,7 @@
from funasr.utils import config_argparse
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
-from funasr.punctuation.text_preprocessor import split_to_mini_sentence
+from funasr.datasets.preprocessor import split_to_mini_sentence
class Text2Punc:
@@ -61,7 +61,7 @@
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
)
- print("start decoding!!!")
+
@torch.no_grad()
def __call__(self, text: Union[list, str], cache: list, split_size=20):
@@ -69,7 +69,8 @@
precache = "".join(cache)
else:
precache = ""
- data = {"text": precache + text}
+ cache = []
+ data = {"text": precache + " " + text}
result = self.preprocessor(data=data, uid="12938712838719")
split_text = self.preprocessor.pop_split_text_data(result)
mini_sentences = split_to_mini_sentence(split_text, split_size)
@@ -89,7 +90,7 @@
data = {
"text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
"text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
- "vad_indexes": torch.from_numpy(np.array([len(cache)-1], dtype='int32')),
+ "vad_indexes": torch.from_numpy(np.array([len(cache)], dtype='int32')),
}
data = to_device(data, self.device)
y, _ = self.wrapped_model(**data)
@@ -202,10 +203,8 @@
**kwargs,
):
assert check_argument_types()
- logging.basicConfig(
- level=log_level,
- format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
- )
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
if ngpu >= 1 and torch.cuda.is_available():
device = "cuda"
@@ -225,7 +224,7 @@
):
results = []
split_size = 10
-
+ cache_in = param_dict["cache"]
if raw_inputs != None:
line = raw_inputs.strip()
key = "demo"
@@ -233,35 +232,12 @@
item = {'key': key, 'value': ""}
results.append(item)
return results
- #import pdb;pdb.set_trace()
- result, _, cache = text2punc(line, cache)
- item = {'key': key, 'value': result, 'cache': cache}
+ result, _, cache = text2punc(line, cache_in)
+ param_dict["cache"] = cache
+ item = {'key': key, 'value': result}
results.append(item)
return results
- for inference_text, _, _ in data_path_and_name_and_type:
- with open(inference_text, "r", encoding="utf-8") as fin:
- for line in fin:
- line = line.strip()
- segs = line.split("\t")
- if len(segs) != 2:
- continue
- key = segs[0]
- if len(segs[1]) == 0:
- continue
- result, _ = text2punc(segs[1])
- item = {'key': key, 'value': result}
- results.append(item)
- output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
- if output_path != None:
- output_file_name = "infer.out"
- Path(output_path).mkdir(parents=True, exist_ok=True)
- output_file_path = (Path(output_path) / output_file_name).absolute()
- with open(output_file_path, "w", encoding="utf-8") as fout:
- for item_i in results:
- key_out = item_i["key"]
- value_out = item_i["value"]
- fout.write(f"{key_out}\t{value_out}\n")
return results
return _forward
--
Gitblit v1.9.1
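
Below the patch, a minimal usage sketch (not part of the diff) of how a caller might drive the param_dict-based cache flow introduced above. `fake_text2punc` is a hypothetical stub standing in for the real Text2Punc callable; only its `(text, cache) -> (result, _, cache)` shape is taken from the patched `_forward`, everything else is illustrative.

```python
# Sketch, assuming a Text2Punc-like callable that returns (result, _, cache).
# `fake_text2punc` is a stand-in so the loop runs without FunASR installed.
from typing import List, Tuple


def fake_text2punc(text: str, cache: List[str]) -> Tuple[str, None, List[str]]:
    # Stand-in: pretend every segment ends mid-sentence, so its tokens are
    # returned as the new cache and prepended on the next call, mirroring
    # the `precache = "".join(cache)` logic in Text2Punc.__call__.
    joined = " ".join(cache + [text]) if cache else text
    return joined, None, text.split()


param_dict = {"cache": []}          # the patch keeps the cache inside param_dict
for segment in ["hello world", "this is a", "vad realtime demo"]:
    cache_in = param_dict["cache"]  # cache carried over from the previous call
    result, _, cache = fake_text2punc(segment, cache_in)
    param_dict["cache"] = cache     # write the updated cache back, as _forward now does
    print({"key": "demo", "value": result})
```

The point of the change is that the caller, not the model wrapper, owns the cache between streaming calls, so repeated invocations of `_forward` only need to pass the same `param_dict`.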