From 4b3c4929988a673b4211c255b701ff63a4365155 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 30 三月 2023 16:06:16 +0800
Subject: [PATCH] Merge branch 'dev_cmz2' of github.com:alibaba-damo-academy/FunASR into dev_cmz2 add
---
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 25 ++++++++++++-------------
1 files changed, 12 insertions(+), 13 deletions(-)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 64ced69..034475c 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -32,8 +32,7 @@
self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
self.batch_size = 1
- self.encoder_conf = config["encoder_conf"]
- self.punc_list = config.punc_list
+ self.punc_list = config['punc_list']
self.period = 0
for i in range(len(self.punc_list)):
if self.punc_list[i] == ",":
@@ -44,13 +43,13 @@
self.period = i
self.preprocessor = CodeMixTokenizerCommonPreprocessor(
train=False,
- token_type=config.token_type,
- token_list=config.token_list,
- bpemodel=config.bpemodel,
- text_cleaner=config.cleaner,
- g2p_type=config.g2p,
+ token_type=config['token_type'],
+ token_list=config['token_list'],
+ bpemodel=config['bpemodel'],
+ text_cleaner=config['cleaner'],
+ g2p_type=config['g2p'],
text_name="text",
- non_linguistic_symbols=config.non_linguistic_symbols,
+ non_linguistic_symbols=config['non_linguistic_symbols'],
)
def __call__(self, text: Union[list, str], split_size=20):
@@ -71,8 +70,8 @@
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
data = {
- "text": mini_sentence_id,
- "text_lengths": len(mini_sentence_id),
+ "text": mini_sentence_id[None,:].astype(np.int64),
+ "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
}
try:
outputs = self.infer(data['text'], data['text_lengths'])
@@ -126,8 +125,8 @@
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
return new_mini_sentence_out, new_mini_sentence_punc_out
- def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
-
- outputs = self.ort_infer(feats)
+ def infer(self, feats: np.ndarray,
+ feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ outputs = self.ort_infer([feats, feats_len])
return outputs
--
Gitblit v1.9.1