From f98c4bf6d2bb5202488cd4243efdbca65288c313 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 27 二月 2023 14:26:32 +0800
Subject: [PATCH] onnx export

---
 funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py |    5 +++--
 1 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
index 8e220e0..7943abb 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
@@ -148,6 +148,7 @@
 
 class OrtInferSession():
     def __init__(self, model_file, device_id=-1):
+        device_id = str(device_id)
         sess_opt = SessionOptions()
         sess_opt.log_severity_level = 4
         sess_opt.enable_cpu_mem_arena = False
@@ -166,7 +167,7 @@
         }
 
         EP_list = []
-        if device_id != -1 and get_device() == 'GPU' \
+        if device_id != "-1" and get_device() == 'GPU' \
                 and cuda_ep in get_available_providers():
             EP_list = [(cuda_ep, cuda_provider_options)]
         EP_list.append((cpu_ep, cpu_provider_options))
@@ -176,7 +177,7 @@
                                         sess_options=sess_opt,
                                         providers=EP_list)
 
-        if device_id != -1 and cuda_ep not in self.session.get_providers():
+        if device_id != "-1" and cuda_ep not in self.session.get_providers():
             warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
                           'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
                           'you can check their relations from the offical web site: '

--
Gitblit v1.9.1