From 7df8452a8520cf3e4609114b47510371b2c621a2 Mon Sep 17 00:00:00 2001
From: 九耳 <mengzhe.cmz@alibaba-inc.com>
Date: 星期四, 30 三月 2023 16:15:49 +0800
Subject: [PATCH] fix

---
 funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py |    5 ++---
 1 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 034475c..c00a3d7 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -76,9 +76,8 @@
             try:
                 outputs = self.infer(data['text'], data['text_lengths'])
                 y = outputs[0]
-                _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
-                punctuations = indices
-                assert punctuations.size()[0] == len(mini_sentence)
+                punctuations = np.argmax(y,axis=-1)[0]
+                assert punctuations.size == len(mini_sentence)
             except ONNXRuntimeError:
                 logging.warning("error")
 

--
Gitblit v1.9.1