From 865ae89f0a713f70dda16859638b25e7350275ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 13 Feb 2023 17:43:01 +0800
Subject: [PATCH] export model

---
 funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py |   49 ++++++++++++++++++++++++++++---------------------
 1 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py
similarity index 90%
rename from funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py
rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py
index 839adb4..ea3c0b7 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py
+++ b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py
@@ -14,6 +14,7 @@
 from typeguard import check_argument_types
 
 from .kaldifeat import compute_fbank_feats
+import warnings
 
 root_dir = Path(__file__).resolve().parent
 
@@ -21,24 +22,25 @@
 
 
 class TokenIDConverter():
-    def __init__(self, token_path: Union[Path, str],
+    def __init__(self, token_list: Union[List, str],
                  unk_symbol: str = "<unk>",):
         check_argument_types()
 
-        self.token_list = self.load_token(token_path)
-        self.unk_symbol = unk_symbol
+        # token_list is the in-memory vocabulary (no file loading here); its last entry is taken as the unk symbol.
+        self.token_list = token_list
+        self.unk_symbol = token_list[-1]
 
-    @staticmethod
-    def load_token(file_path: Union[Path, str]) -> List:
-        if not Path(file_path).exists():
-            raise TokenIDConverterError(f'The {file_path} does not exist.')
-
-        with open(str(file_path), 'rb') as f:
-            token_list = pickle.load(f)
-
-        if len(token_list) != len(set(token_list)):
-            raise TokenIDConverterError('The Token exists duplicated symbol.')
-        return token_list
+    # @staticmethod
+    # def load_token(file_path: Union[Path, str]) -> List:
+    #     if not Path(file_path).exists():
+    #         raise TokenIDConverterError(f'The {file_path} does not exist.')
+    #
+    #     with open(str(file_path), 'rb') as f:
+    #         token_list = pickle.load(f)
+    #
+    #     if len(token_list) != len(set(token_list)):
+    #         raise TokenIDConverterError('The Token exists duplicated symbol.')
+    #     return token_list
 
     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
@@ -268,31 +270,36 @@
 
 
 class OrtInferSession():
-    def __init__(self, config):
+    def __init__(self, model_file, device_id=-1):
         sess_opt = SessionOptions()
         sess_opt.log_severity_level = 4
         sess_opt.enable_cpu_mem_arena = False
         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
 
         cuda_ep = 'CUDAExecutionProvider'
+        cuda_provider_options = {
+            "device_id": device_id,
+            "arena_extend_strategy": "kNextPowerOfTwo",
+            "cudnn_conv_algo_search": "EXHAUSTIVE",
+            "do_copy_in_default_stream": "true",
+        }
         cpu_ep = 'CPUExecutionProvider'
         cpu_provider_options = {
             "arena_extend_strategy": "kSameAsRequested",
         }
 
         EP_list = []
-        if config['use_cuda'] and get_device() == 'GPU' \
+        if device_id != -1 and get_device() == 'GPU' \
                 and cuda_ep in get_available_providers():
-            EP_list = [(cuda_ep, config[cuda_ep])]
+            EP_list = [(cuda_ep, cuda_provider_options)]
         EP_list.append((cpu_ep, cpu_provider_options))
 
-        config['model_path'] = config['model_path']
-        self._verify_model(config['model_path'])
-        self.session = InferenceSession(config['model_path'],
+        self._verify_model(model_file)
+        self.session = InferenceSession(model_file,
                                         sess_options=sess_opt,
                                         providers=EP_list)
 
-        if config['use_cuda'] and cuda_ep not in self.session.get_providers():
+        if device_id != -1 and cuda_ep not in self.session.get_providers():
             warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
                           'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
                           'you can check their relations from the offical web site: '

--
Gitblit v1.9.1