From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
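Subject: [PATCH] Commit: reformat funasr_onnx utils.py and add optional run_options argument to OrtInferSession.__call__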
---
runtime/python/onnxruntime/funasr_onnx/utils/utils.py | 137 +++++++++++++++++++++++++--------------------
1 files changed, 76 insertions(+), 61 deletions(-)
diff --git a/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index 260a85e..e176fb6 100644
--- a/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -8,9 +8,15 @@
import re
import numpy as np
import yaml
+
try:
- from onnxruntime import (GraphOptimizationLevel, InferenceSession,
- SessionOptions, get_available_providers, get_device)
+ from onnxruntime import (
+ GraphOptimizationLevel,
+ InferenceSession,
+ SessionOptions,
+ get_available_providers,
+ get_device,
+ )
except:
print("please pip3 install onnxruntime")
import jieba
@@ -33,7 +39,8 @@
return pad
-'''
+
+"""
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
@@ -66,26 +73,26 @@
)
mask = mask[ind].expand_as(xs).to(xs.device)
return mask
-'''
+"""
-class TokenIDConverter():
- def __init__(self, token_list: Union[List, str],
- ):
+
+class TokenIDConverter:
+ def __init__(
+ self,
+ token_list: Union[List, str],
+ ):
self.token_list = token_list
self.unk_symbol = token_list[-1]
self.token2id = {v: i for i, v in enumerate(self.token_list)}
self.unk_id = self.token2id[self.unk_symbol]
-
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
- def ids2tokens(self,
- integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
+ def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
if isinstance(integers, np.ndarray) and integers.ndim != 1:
- raise TokenIDConverterError(
- f"Must be 1 dim ndarray, but got {integers.ndim}")
+ raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
return [self.token_list[i] for i in integers]
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
@@ -93,7 +100,7 @@
return [self.token2id.get(i, self.unk_id) for i in tokens]
-class CharTokenizer():
+class CharTokenizer:
def __init__(
self,
symbol_value: Union[Path, str, Iterable[str]] = None,
@@ -128,7 +135,7 @@
if line.startswith(w):
if not self.remove_non_linguistic_symbols:
tokens.append(line[: len(w)])
- line = line[len(w):]
+ line = line[len(w) :]
break
else:
t = line[0]
@@ -149,7 +156,6 @@
f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
f")"
)
-
class Hypothesis(NamedTuple):
@@ -177,7 +183,7 @@
pass
-class OrtInferSession():
+class OrtInferSession:
def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
device_id = str(device_id)
sess_opt = SessionOptions()
@@ -186,54 +192,56 @@
sess_opt.enable_cpu_mem_arena = False
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
- cuda_ep = 'CUDAExecutionProvider'
+ cuda_ep = "CUDAExecutionProvider"
cuda_provider_options = {
"device_id": device_id,
"arena_extend_strategy": "kNextPowerOfTwo",
"cudnn_conv_algo_search": "EXHAUSTIVE",
"do_copy_in_default_stream": "true",
}
- cpu_ep = 'CPUExecutionProvider'
+ cpu_ep = "CPUExecutionProvider"
cpu_provider_options = {
"arena_extend_strategy": "kSameAsRequested",
}
EP_list = []
- if device_id != "-1" and get_device() == 'GPU' \
- and cuda_ep in get_available_providers():
+ if device_id != "-1" and get_device() == "GPU" and cuda_ep in get_available_providers():
EP_list = [(cuda_ep, cuda_provider_options)]
EP_list.append((cpu_ep, cpu_provider_options))
self._verify_model(model_file)
- self.session = InferenceSession(model_file,
- sess_options=sess_opt,
- providers=EP_list)
+ self.session = InferenceSession(model_file, sess_options=sess_opt, providers=EP_list)
if device_id != "-1" and cuda_ep not in self.session.get_providers():
- warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
- 'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
- 'you can check their relations from the offical web site: '
- 'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html',
- RuntimeWarning)
+ warnings.warn(
+ f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n"
+ "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
+ "you can check their relations from the offical web site: "
+ "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
+ RuntimeWarning,
+ )
- def __call__(self,
- input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
+ def __call__(self, input_content: List[Union[np.ndarray, np.ndarray]], run_options = None) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), input_content))
try:
- return self.session.run(self.get_output_names(), input_dict)
+ return self.session.run(self.get_output_names(), input_dict, run_options)
except Exception as e:
- raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e
+ raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e
- def get_input_names(self, ):
+ def get_input_names(
+ self,
+ ):
return [v.name for v in self.session.get_inputs()]
- def get_output_names(self,):
+ def get_output_names(
+ self,
+ ):
return [v.name for v in self.session.get_outputs()]
- def get_character_list(self, key: str = 'character'):
+ def get_character_list(self, key: str = "character"):
return self.meta_dict[key].splitlines()
- def have_key(self, key: str = 'character') -> bool:
+ def have_key(self, key: str = "character") -> bool:
self.meta_dict = self.session.get_modelmeta().custom_metadata_map
if key in self.meta_dict.keys():
return True
@@ -243,9 +251,10 @@
def _verify_model(model_path):
model_path = Path(model_path)
if not model_path.exists():
- raise FileNotFoundError(f'{model_path} does not exists.')
+ raise FileNotFoundError(f"{model_path} does not exists.")
if not model_path.is_file():
- raise FileExistsError(f'{model_path} is not a file.')
+ raise FileExistsError(f"{model_path} is not a file.")
+
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
@@ -255,10 +264,11 @@
length = len(words)
sentence_len = length // word_limit
for i in range(sentence_len):
- sentences.append(words[i * word_limit:(i + 1) * word_limit])
+ sentences.append(words[i * word_limit : (i + 1) * word_limit])
if length % word_limit > 0:
- sentences.append(words[sentence_len * word_limit:])
+ sentences.append(words[sentence_len * word_limit :])
return sentences
+
def code_mix_split_words(text: str):
words = []
@@ -280,22 +290,25 @@
words.append(current_word)
return words
-def isEnglish(text:str):
- if re.search('^[a-zA-Z\']+$', text):
+
+def isEnglish(text: str):
+ if re.search("^[a-zA-Z']+$", text):
return True
else:
return False
+
def join_chinese_and_english(input_list):
- line = ''
+ line = ""
for token in input_list:
if isEnglish(token):
- line = line + ' ' + token
+ line = line + " " + token
else:
line = line + token
line = line.strip()
return line
+
def code_mix_split_words_jieba(seg_dict_file: str):
jieba.load_userdict(seg_dict_file)
@@ -307,48 +320,50 @@
token_list_tmp = []
language_flag = None
for token in input_list:
- if isEnglish(token) and language_flag == 'Chinese':
+ if isEnglish(token) and language_flag == "Chinese":
token_list_all.append(token_list_tmp)
- langauge_list.append('Chinese')
+ langauge_list.append("Chinese")
token_list_tmp = []
- elif not isEnglish(token) and language_flag == 'English':
+ elif not isEnglish(token) and language_flag == "English":
token_list_all.append(token_list_tmp)
- langauge_list.append('English')
+ langauge_list.append("English")
token_list_tmp = []
-
+
token_list_tmp.append(token)
-
+
if isEnglish(token):
- language_flag = 'English'
+ language_flag = "English"
else:
- language_flag = 'Chinese'
-
+ language_flag = "Chinese"
+
if token_list_tmp:
token_list_all.append(token_list_tmp)
langauge_list.append(language_flag)
-
+
result_list = []
for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
- if language_flag == 'English':
+ if language_flag == "English":
result_list.extend(token_list_tmp)
else:
seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False)
result_list.extend(seg_list)
-
+
return result_list
+
return _fn
+
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
if not Path(yaml_path).exists():
- raise FileExistsError(f'The {yaml_path} does not exist.')
+ raise FileExistsError(f"The {yaml_path} does not exist.")
- with open(str(yaml_path), 'rb') as f:
+ with open(str(yaml_path), "rb") as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
@functools.lru_cache()
-def get_logger(name='funasr_onnx'):
+def get_logger(name="funasr_onnx"):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
@@ -368,8 +383,8 @@
return logger
formatter = logging.Formatter(
- '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
- datefmt="%Y/%m/%d %H:%M:%S")
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
+ )
sh = logging.StreamHandler()
sh.setFormatter(formatter)
--
Gitblit v1.9.1