From d0cd484fdc21c06b8bc892bb2ab1c2a25fb1da8a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 31 三月 2023 15:05:37 +0800
Subject: [PATCH] export
---
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 37 ++++++++++++-------------------------
1 files changed, 12 insertions(+), 25 deletions(-)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 034475c..949172e 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -8,12 +8,11 @@
from .utils.utils import (ONNXRuntimeError,
OrtInferSession, get_logger,
read_yaml)
-from .utils.preprocessor import CodeMixTokenizerCommonPreprocessor
-from .utils.utils import split_to_mini_sentence
+from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
logging = get_logger()
-class TargetDelayTransformer():
+class CT_Transformer():
def __init__(self, model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
@@ -30,6 +29,7 @@
config_file = os.path.join(model_dir, 'punc.yaml')
config = read_yaml(config_file)
+ self.converter = TokenIDConverter(config['token_list'])
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
self.batch_size = 1
self.punc_list = config['punc_list']
@@ -41,23 +41,12 @@
self.punc_list[i] = "锛�"
elif self.punc_list[i] == "銆�":
self.period = i
- self.preprocessor = CodeMixTokenizerCommonPreprocessor(
- train=False,
- token_type=config['token_type'],
- token_list=config['token_list'],
- bpemodel=config['bpemodel'],
- text_cleaner=config['cleaner'],
- g2p_type=config['g2p'],
- text_name="text",
- non_linguistic_symbols=config['non_linguistic_symbols'],
- )
def __call__(self, text: Union[list, str], split_size=20):
- data = {"text": text}
- result = self.preprocessor(data=data, uid="12938712838719")
- split_text = self.preprocessor.pop_split_text_data(result)
+ split_text = code_mix_split_words(text)
+ split_text_id = self.converter.tokens2ids(split_text)
mini_sentences = split_to_mini_sentence(split_text, split_size)
- mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+ mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
cache_sent_id = []
@@ -68,17 +57,16 @@
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
- mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+ mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
data = {
- "text": mini_sentence_id[None,:].astype(np.int64),
+ "text": mini_sentence_id[None,:],
"text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
}
try:
outputs = self.infer(data['text'], data['text_lengths'])
y = outputs[0]
- _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
- punctuations = indices
- assert punctuations.size()[0] == len(mini_sentence)
+ punctuations = np.argmax(y,axis=-1)[0]
+ assert punctuations.size == len(mini_sentence)
except ONNXRuntimeError:
logging.warning("error")
@@ -98,12 +86,11 @@
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
cache_sent = mini_sentence[sentenceEnd + 1:]
- cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+ cache_sent_id = mini_sentence_id[sentenceEnd + 1:].tolist()
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
- punctuations_np = punctuations.cpu().numpy()
- new_mini_sentence_punc += [int(x) for x in punctuations_np]
+ new_mini_sentence_punc += [int(x) for x in punctuations]
words_with_punc = []
for i in range(len(mini_sentence)):
if i > 0:
--
Gitblit v1.9.1