From 602fe75a1f0a8d64ccb6fc4d69ad510872fdfd13 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 17 三月 2023 20:30:40 +0800
Subject: [PATCH] rtf benchmark
---
funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py | 8 +++++---
1 files changed, 5 insertions(+), 3 deletions(-)
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
index 8e220e0..ec907c0 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
@@ -147,8 +147,10 @@
class OrtInferSession():
- def __init__(self, model_file, device_id=-1):
+ def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
+ device_id = str(device_id)
sess_opt = SessionOptions()
+ sess_opt.intra_op_num_threads = intra_op_num_threads
sess_opt.log_severity_level = 4
sess_opt.enable_cpu_mem_arena = False
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -166,7 +168,7 @@
}
EP_list = []
- if device_id != -1 and get_device() == 'GPU' \
+ if device_id != "-1" and get_device() == 'GPU' \
and cuda_ep in get_available_providers():
EP_list = [(cuda_ep, cuda_provider_options)]
EP_list.append((cpu_ep, cpu_provider_options))
@@ -176,7 +178,7 @@
sess_options=sess_opt,
providers=EP_list)
- if device_id != -1 and cuda_ep not in self.session.get_providers():
+ if device_id != "-1" and cuda_ep not in self.session.get_providers():
warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
'you can check their relations from the offical web site: '
--
Gitblit v1.9.1