From 1d4dda939cd8998119d2a9dd4a170e8db657f21e Mon Sep 17 00:00:00 2001
From: 九耳 <mengzhe.cmz@alibaba-inc.com>
Date: 星期四, 30 三月 2023 15:44:37 +0800
Subject: [PATCH] fix

---
 funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py |   10 +++++-----
 1 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 8ea4517..034475c 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -70,8 +70,8 @@
             mini_sentence = cache_sent + mini_sentence
             mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
             data = {
-                "text": mini_sentence_id,
-                "text_lengths": len(mini_sentence_id),
+                "text": mini_sentence_id[None,:].astype(np.int64),
+                "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
             }
             try:
                 outputs = self.infer(data['text'], data['text_lengths'])
@@ -125,8 +125,8 @@
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
         return new_mini_sentence_out, new_mini_sentence_punc_out
 
-    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
-
-        outputs = self.ort_infer(feats)
+    def infer(self, feats: np.ndarray,
+              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        outputs = self.ort_infer([feats, feats_len])
         return outputs
 

--
Gitblit v1.9.1