Kyle He
2025-08-14 82a07e2f6ec60aa25a3931e9ee0d99ead642484a
runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -2,16 +2,21 @@
import functools
import logging
import pickle
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import re
import numpy as np
import yaml
try:
    from onnxruntime import (
        GraphOptimizationLevel,
        InferenceSession,
        SessionOptions,
        get_available_providers,
        get_device,
    )
except ImportError:
    print("please pip3 install onnxruntime")
import jieba
@@ -34,7 +39,8 @@
    return pad
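
# The torch-based make_pad_mask implementation below is intentionally disabled:
# it is wrapped in a module-level string literal and kept only for reference.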
"""
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
@@ -67,26 +73,26 @@
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask
"""
class TokenIDConverter:
    def __init__(self, token_list: Union[List, str]):
        self.token_list = token_list
        self.unk_symbol = token_list[-1]
        self.token2id = {v: i for i, v in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]
    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)
    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]
    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
@@ -94,7 +100,7 @@
        return [self.token2id.get(i, self.unk_id) for i in tokens]
class CharTokenizer:
    def __init__(
        self,
        symbol_value: Union[Path, str, Iterable[str], None] = None,
@@ -129,7 +135,7 @@
                if line.startswith(w):
                    if not self.remove_non_linguistic_symbols:
                        tokens.append(line[: len(w)])
                    line = line[len(w) :]
                    break
            else:
                t = line[0]
@@ -150,7 +156,6 @@
            f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
            f")"
        )
class Hypothesis(NamedTuple):
@@ -178,7 +183,7 @@
    pass
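
# Thin wrapper around onnxruntime.InferenceSession: prefers the CUDA execution
# provider when a GPU device id is requested and available, otherwise falls back
# to the CPU provider.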
class OrtInferSession:
    def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
        device_id = str(device_id)
        sess_opt = SessionOptions()
@@ -187,54 +192,56 @@
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
        cuda_ep = "CUDAExecutionProvider"
        cuda_provider_options = {
            "device_id": device_id,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": "true",
        }
        cpu_ep = "CPUExecutionProvider"
        cpu_provider_options = {
            "arena_extend_strategy": "kSameAsRequested",
        }
        EP_list = []
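        # Use CUDA only when a non-negative device id was requested, this is a GPU
        # build of onnxruntime, and the CUDA provider is actually present; the CPU
        # provider is always appended as the final fallback.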
        if device_id != "-1" and get_device() == 'GPU' \
                and cuda_ep in get_available_providers():
        if device_id != "-1" and get_device() == "GPU" and cuda_ep in get_available_providers():
            EP_list = [(cuda_ep, cuda_provider_options)]
        EP_list.append((cpu_ep, cpu_provider_options))
        self._verify_model(model_file)
        self.session = InferenceSession(model_file, sess_options=sess_opt, providers=EP_list)
        if device_id != "-1" and cuda_ep not in self.session.get_providers():
            warnings.warn(
                f"{cuda_ep} is not available in the current environment; inference will fall back to {cpu_ep}.\n"
                "Please ensure the installed onnxruntime-gpu version matches your CUDA and cuDNN versions; "
                "their compatibility matrix is listed on the official site: "
                "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
                RuntimeWarning,
            )
    def __call__(self, input_content: List[np.ndarray], run_options=None) -> np.ndarray:
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(self.get_output_names(), input_dict, run_options)
        except Exception as e:
            raise ONNXRuntimeError("ONNXRuntime inference failed.") from e
    def get_input_names(self):
        return [v.name for v in self.session.get_inputs()]
    def get_output_names(self):
        return [v.name for v in self.session.get_outputs()]
    def get_character_list(self, key: str = "character"):
        return self.meta_dict[key].splitlines()
    def have_key(self, key: str = "character") -> bool:
        self.meta_dict = self.session.get_modelmeta().custom_metadata_map
        if key in self.meta_dict.keys():
            return True
@@ -244,9 +251,10 @@
    def _verify_model(model_path):
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist.")
        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")
def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
@@ -256,10 +264,11 @@
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit : (i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit :])
    return sentences
def code_mix_split_words(text: str):
    words = []
@@ -281,22 +290,25 @@
            words.append(current_word)
    return words
def isEnglish(text: str):
    if re.search("^[a-zA-Z']+$", text):
        return True
    else:
        return False
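
# Rebuild a string from mixed tokens, inserting a space only before English
# words so that Chinese characters remain unspaced.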
def join_chinese_and_english(input_list):
    line = ""
    for token in input_list:
        if isEnglish(token):
            line = line + " " + token
        else:
            line = line + token
    line = line.strip()
    return line
def code_mix_split_words_jieba(seg_dict_file: str):
    jieba.load_userdict(seg_dict_file)
@@ -308,48 +320,50 @@
        token_list_tmp = []
        language_flag = None
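        # Group consecutive tokens into same-language runs, flushing the current
        # run whenever the language switches between English and Chinese.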
        for token in input_list:
            if isEnglish(token) and language_flag == "Chinese":
                token_list_all.append(token_list_tmp)
                langauge_list.append("Chinese")
                token_list_tmp = []
            elif not isEnglish(token) and language_flag == "English":
                token_list_all.append(token_list_tmp)
                langauge_list.append("English")
                token_list_tmp = []
            token_list_tmp.append(token)
            if isEnglish(token):
                language_flag = "English"
            else:
                language_flag = "Chinese"
        if token_list_tmp:
            token_list_all.append(token_list_tmp)
            langauge_list.append(language_flag)
        result_list = []
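        # English runs pass through unchanged; Chinese runs are re-segmented with
        # jieba.cut using HMM=False, so new-word discovery is disabled and splits
        # follow the dictionary (including the loaded user dictionary).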
        for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
            if language_flag == "English":
                result_list.extend(token_list_tmp)
            else:
                seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False)
                result_list.extend(seg_list)
        return result_list
    return _fn
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    if not Path(yaml_path).exists():
        raise FileNotFoundError(f"The {yaml_path} does not exist.")
    with open(str(yaml_path), "rb") as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data
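
# lru_cache makes repeated calls with the same name return the already-configured
# logger instead of attaching duplicate handlers.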
@functools.lru_cache()
def get_logger(name="funasr_onnx"):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
@@ -369,8 +383,8 @@
            return logger
    formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
    )
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)