From 0b15e6ea5cccbea3c590958d60e623800bbe3dfb Mon Sep 17 00:00:00 2001
From: 九耳 <mengzhe.cmz@alibaba-inc.com>
Date: 星期四, 30 三月 2023 16:27:07 +0800
Subject: [PATCH] fix
---
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 18 ++++++++----------
1 files changed, 8 insertions(+), 10 deletions(-)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 8ea4517..3f649bc 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -70,15 +70,14 @@
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
data = {
- "text": mini_sentence_id,
- "text_lengths": len(mini_sentence_id),
+ "text": mini_sentence_id[None,:].astype(np.int64),
+ "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
}
try:
outputs = self.infer(data['text'], data['text_lengths'])
y = outputs[0]
- _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
- punctuations = indices
- assert punctuations.size()[0] == len(mini_sentence)
+ punctuations = np.argmax(y,axis=-1)[0]
+ assert punctuations.size == len(mini_sentence)
except ONNXRuntimeError:
logging.warning("error")
@@ -102,8 +101,7 @@
mini_sentence = mini_sentence[0:sentenceEnd + 1]
punctuations = punctuations[0:sentenceEnd + 1]
- punctuations_np = punctuations.cpu().numpy()
- new_mini_sentence_punc += [int(x) for x in punctuations_np]
+ new_mini_sentence_punc += [int(x) for x in punctuations]
words_with_punc = []
for i in range(len(mini_sentence)):
if i > 0:
@@ -125,8 +123,8 @@
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
return new_mini_sentence_out, new_mini_sentence_punc_out
- def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
-
- outputs = self.ort_infer(feats)
+ def infer(self, feats: np.ndarray,
+ feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ outputs = self.ort_infer([feats, feats_len])
return outputs
--
Gitblit v1.9.1