From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 runtime/python/onnxruntime/funasr_onnx/utils/utils.py |  138 +++++++++++++++++++++++++--------------------
 1 files changed, 76 insertions(+), 62 deletions(-)

diff --git a/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index 768b813..e176fb6 100644
--- a/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -2,16 +2,21 @@
 
 import functools
 import logging
-import pickle
 from pathlib import Path
 from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
 
 import re
 import numpy as np
 import yaml
+
 try:
-    from onnxruntime import (GraphOptimizationLevel, InferenceSession,
-                             SessionOptions, get_available_providers, get_device)
+    from onnxruntime import (
+        GraphOptimizationLevel,
+        InferenceSession,
+        SessionOptions,
+        get_available_providers,
+        get_device,
+    )
 except:
     print("please pip3 install onnxruntime")
 import jieba
@@ -34,7 +39,8 @@
 
     return pad
 
-'''
+
+"""
 def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
     if length_dim == 0:
         raise ValueError("length_dim cannot be 0: {}".format(length_dim))
@@ -67,26 +73,26 @@
         )
         mask = mask[ind].expand_as(xs).to(xs.device)
     return mask
-'''
+"""
 
-class TokenIDConverter():
-    def __init__(self, token_list: Union[List, str],
-                 ):
+
+class TokenIDConverter:
+    def __init__(
+        self,
+        token_list: Union[List, str],
+    ):
 
         self.token_list = token_list
         self.unk_symbol = token_list[-1]
         self.token2id = {v: i for i, v in enumerate(self.token_list)}
         self.unk_id = self.token2id[self.unk_symbol]
 
-
     def get_num_vocabulary_size(self) -> int:
         return len(self.token_list)
 
-    def ids2tokens(self,
-                   integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
         if isinstance(integers, np.ndarray) and integers.ndim != 1:
-            raise TokenIDConverterError(
-                f"Must be 1 dim ndarray, but got {integers.ndim}")
+            raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
         return [self.token_list[i] for i in integers]
 
     def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
@@ -94,7 +100,7 @@
         return [self.token2id.get(i, self.unk_id) for i in tokens]
 
 
-class CharTokenizer():
+class CharTokenizer:
     def __init__(
         self,
         symbol_value: Union[Path, str, Iterable[str]] = None,
@@ -129,7 +135,7 @@
                 if line.startswith(w):
                     if not self.remove_non_linguistic_symbols:
                         tokens.append(line[: len(w)])
-                    line = line[len(w):]
+                    line = line[len(w) :]
                     break
             else:
                 t = line[0]
@@ -150,7 +156,6 @@
             f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
             f")"
         )
-
 
 
 class Hypothesis(NamedTuple):
@@ -178,7 +183,7 @@
     pass
 
 
-class OrtInferSession():
+class OrtInferSession:
     def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
         device_id = str(device_id)
         sess_opt = SessionOptions()
@@ -187,54 +192,56 @@
         sess_opt.enable_cpu_mem_arena = False
         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
 
-        cuda_ep = 'CUDAExecutionProvider'
+        cuda_ep = "CUDAExecutionProvider"
         cuda_provider_options = {
             "device_id": device_id,
             "arena_extend_strategy": "kNextPowerOfTwo",
             "cudnn_conv_algo_search": "EXHAUSTIVE",
             "do_copy_in_default_stream": "true",
         }
-        cpu_ep = 'CPUExecutionProvider'
+        cpu_ep = "CPUExecutionProvider"
         cpu_provider_options = {
             "arena_extend_strategy": "kSameAsRequested",
         }
 
         EP_list = []
-        if device_id != "-1" and get_device() == 'GPU' \
-                and cuda_ep in get_available_providers():
+        if device_id != "-1" and get_device() == "GPU" and cuda_ep in get_available_providers():
             EP_list = [(cuda_ep, cuda_provider_options)]
         EP_list.append((cpu_ep, cpu_provider_options))
 
         self._verify_model(model_file)
-        self.session = InferenceSession(model_file,
-                                        sess_options=sess_opt,
-                                        providers=EP_list)
+        self.session = InferenceSession(model_file, sess_options=sess_opt, providers=EP_list)
 
         if device_id != "-1" and cuda_ep not in self.session.get_providers():
-            warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
-                          'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
-                          'you can check their relations from the offical web site: '
-                          'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html',
-                          RuntimeWarning)
+            warnings.warn(
+                f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n"
+                "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
+                "you can check their relations from the offical web site: "
+                "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
+                RuntimeWarning,
+            )
 
-    def __call__(self,
-                 input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
+    def __call__(self, input_content: List[Union[np.ndarray, np.ndarray]], run_options = None) -> np.ndarray:
         input_dict = dict(zip(self.get_input_names(), input_content))
         try:
-            return self.session.run(self.get_output_names(), input_dict)
+            return self.session.run(self.get_output_names(), input_dict, run_options)
         except Exception as e:
-            raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e
+            raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e
 
-    def get_input_names(self, ):
+    def get_input_names(
+        self,
+    ):
         return [v.name for v in self.session.get_inputs()]
 
-    def get_output_names(self,):
+    def get_output_names(
+        self,
+    ):
         return [v.name for v in self.session.get_outputs()]
 
-    def get_character_list(self, key: str = 'character'):
+    def get_character_list(self, key: str = "character"):
         return self.meta_dict[key].splitlines()
 
-    def have_key(self, key: str = 'character') -> bool:
+    def have_key(self, key: str = "character") -> bool:
         self.meta_dict = self.session.get_modelmeta().custom_metadata_map
         if key in self.meta_dict.keys():
             return True
@@ -244,9 +251,10 @@
     def _verify_model(model_path):
         model_path = Path(model_path)
         if not model_path.exists():
-            raise FileNotFoundError(f'{model_path} does not exists.')
+            raise FileNotFoundError(f"{model_path} does not exists.")
         if not model_path.is_file():
-            raise FileExistsError(f'{model_path} is not a file.')
+            raise FileExistsError(f"{model_path} is not a file.")
+
 
 def split_to_mini_sentence(words: list, word_limit: int = 20):
     assert word_limit > 1
@@ -256,10 +264,11 @@
     length = len(words)
     sentence_len = length // word_limit
     for i in range(sentence_len):
-        sentences.append(words[i * word_limit:(i + 1) * word_limit])
+        sentences.append(words[i * word_limit : (i + 1) * word_limit])
     if length % word_limit > 0:
-        sentences.append(words[sentence_len * word_limit:])
+        sentences.append(words[sentence_len * word_limit :])
     return sentences
+
 
 def code_mix_split_words(text: str):
     words = []
@@ -281,22 +290,25 @@
             words.append(current_word)
     return words
 
-def isEnglish(text:str):
-    if re.search('^[a-zA-Z\']+$', text):
+
+def isEnglish(text: str):
+    if re.search("^[a-zA-Z']+$", text):
         return True
     else:
         return False
 
+
 def join_chinese_and_english(input_list):
-    line = ''
+    line = ""
     for token in input_list:
         if isEnglish(token):
-            line = line + ' ' + token
+            line = line + " " + token
         else:
             line = line + token
 
     line = line.strip()
     return line
+
 
 def code_mix_split_words_jieba(seg_dict_file: str):
     jieba.load_userdict(seg_dict_file)
@@ -308,48 +320,50 @@
         token_list_tmp = []
         language_flag = None
         for token in input_list:
-            if isEnglish(token) and language_flag == 'Chinese':
+            if isEnglish(token) and language_flag == "Chinese":
                 token_list_all.append(token_list_tmp)
-                langauge_list.append('Chinese')
+                langauge_list.append("Chinese")
                 token_list_tmp = []
-            elif not isEnglish(token) and language_flag == 'English':
+            elif not isEnglish(token) and language_flag == "English":
                 token_list_all.append(token_list_tmp)
-                langauge_list.append('English')
+                langauge_list.append("English")
                 token_list_tmp = []
-    
+
             token_list_tmp.append(token)
-    
+
             if isEnglish(token):
-                language_flag = 'English'
+                language_flag = "English"
             else:
-                language_flag = 'Chinese'
-    
+                language_flag = "Chinese"
+
         if token_list_tmp:
             token_list_all.append(token_list_tmp)
             langauge_list.append(language_flag)
-    
+
         result_list = []
         for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
-            if language_flag == 'English':
+            if language_flag == "English":
                 result_list.extend(token_list_tmp)
             else:
                 seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False)
                 result_list.extend(seg_list)
-    
+
         return result_list
+
     return _fn
+
 
 def read_yaml(yaml_path: Union[str, Path]) -> Dict:
     if not Path(yaml_path).exists():
-        raise FileExistsError(f'The {yaml_path} does not exist.')
+        raise FileExistsError(f"The {yaml_path} does not exist.")
 
-    with open(str(yaml_path), 'rb') as f:
+    with open(str(yaml_path), "rb") as f:
         data = yaml.load(f, Loader=yaml.Loader)
     return data
 
 
 @functools.lru_cache()
-def get_logger(name='funasr_onnx'):
+def get_logger(name="funasr_onnx"):
     """Initialize and get a logger by name.
     If the logger has not been initialized, this method will initialize the
     logger by adding one or two handlers, otherwise the initialized logger will
@@ -369,8 +383,8 @@
             return logger
 
     formatter = logging.Formatter(
-        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
-        datefmt="%Y/%m/%d %H:%M:%S")
+        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
+    )
 
     sh = logging.StreamHandler()
     sh.setFormatter(formatter)

--
Gitblit v1.9.1