zhifu gao
2023-03-16 d783b24ba7d8a03dabfa2139fcbf40c216e0ea3d
funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py
@@ -148,7 +148,9 @@
class OrtInferSession():
    def __init__(self, model_file, device_id=-1):
        device_id = str(device_id)
        sess_opt = SessionOptions()
        sess_opt.intra_op_num_threads = 4
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
@@ -166,7 +168,7 @@
        }
        EP_list = []
        if device_id != -1 and get_device() == 'GPU' \
        if device_id != "-1" and get_device() == 'GPU' \
                and cuda_ep in get_available_providers():
            EP_list = [(cuda_ep, cuda_provider_options)]
        EP_list.append((cpu_ep, cpu_provider_options))
@@ -176,7 +178,7 @@
                                        sess_options=sess_opt,
                                        providers=EP_list)
        if device_id != -1 and cuda_ep not in self.session.get_providers():
        if device_id != "-1" and cuda_ep not in self.session.get_providers():
            warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
                          'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
                          'you can check their relations from the offical web site: '