| | |
| | | |
| | | import functools |
| | | import logging |
| | | import pickle |
| | | from pathlib import Path |
| | | from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union |
| | | |
| | | import re |
| | | import numpy as np |
| | | import yaml |
| | | |
| | | try: |
| | | from onnxruntime import (GraphOptimizationLevel, InferenceSession, |
| | | SessionOptions, get_available_providers, get_device) |
| | | from onnxruntime import ( |
| | | GraphOptimizationLevel, |
| | | InferenceSession, |
| | | SessionOptions, |
| | | get_available_providers, |
| | | get_device, |
| | | ) |
| | | except: |
| | | print("please pip3 install onnxruntime") |
| | | import jieba |
| | |
| | | |
| | | return pad |
| | | |
| | | ''' |
| | | |
| | | """ |
| | | def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None): |
| | | if length_dim == 0: |
| | | raise ValueError("length_dim cannot be 0: {}".format(length_dim)) |
| | |
| | | ) |
| | | mask = mask[ind].expand_as(xs).to(xs.device) |
| | | return mask |
| | | ''' |
| | | """ |
| | | |
| | | class TokenIDConverter(): |
| | | def __init__(self, token_list: Union[List, str], |
| | | ): |
| | | |
| | | class TokenIDConverter: |
| | | def __init__( |
| | | self, |
| | | token_list: Union[List, str], |
| | | ): |
| | | |
| | | self.token_list = token_list |
| | | self.unk_symbol = token_list[-1] |
| | | self.token2id = {v: i for i, v in enumerate(self.token_list)} |
| | | self.unk_id = self.token2id[self.unk_symbol] |
| | | |
| | | |
| | | def get_num_vocabulary_size(self) -> int: |
| | | return len(self.token_list) |
| | | |
| | | def ids2tokens(self, |
| | | integers: Union[np.ndarray, Iterable[int]]) -> List[str]: |
| | | def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]: |
| | | if isinstance(integers, np.ndarray) and integers.ndim != 1: |
| | | raise TokenIDConverterError( |
| | | f"Must be 1 dim ndarray, but got {integers.ndim}") |
| | | raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}") |
| | | return [self.token_list[i] for i in integers] |
| | | |
| | | def tokens2ids(self, tokens: Iterable[str]) -> List[int]: |
| | |
| | | return [self.token2id.get(i, self.unk_id) for i in tokens] |
| | | |
| | | |
| | | class CharTokenizer(): |
| | | class CharTokenizer: |
| | | def __init__( |
| | | self, |
| | | symbol_value: Union[Path, str, Iterable[str]] = None, |
| | |
| | | if line.startswith(w): |
| | | if not self.remove_non_linguistic_symbols: |
| | | tokens.append(line[: len(w)]) |
| | | line = line[len(w):] |
| | | line = line[len(w) :] |
| | | break |
| | | else: |
| | | t = line[0] |
| | |
| | | f'non_linguistic_symbols="{self.non_linguistic_symbols}"' |
| | | f")" |
| | | ) |
| | | |
| | | |
| | | |
| | | class Hypothesis(NamedTuple): |
| | |
| | | pass |
| | | |
| | | |
| | | class OrtInferSession(): |
| | | class OrtInferSession: |
| | | def __init__(self, model_file, device_id=-1, intra_op_num_threads=4): |
| | | device_id = str(device_id) |
| | | sess_opt = SessionOptions() |
| | |
| | | sess_opt.enable_cpu_mem_arena = False |
| | | sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL |
| | | |
| | | cuda_ep = 'CUDAExecutionProvider' |
| | | cuda_ep = "CUDAExecutionProvider" |
| | | cuda_provider_options = { |
| | | "device_id": device_id, |
| | | "arena_extend_strategy": "kNextPowerOfTwo", |
| | | "cudnn_conv_algo_search": "EXHAUSTIVE", |
| | | "do_copy_in_default_stream": "true", |
| | | } |
| | | cpu_ep = 'CPUExecutionProvider' |
| | | cpu_ep = "CPUExecutionProvider" |
| | | cpu_provider_options = { |
| | | "arena_extend_strategy": "kSameAsRequested", |
| | | } |
| | | |
| | | EP_list = [] |
| | | if device_id != "-1" and get_device() == 'GPU' \ |
| | | and cuda_ep in get_available_providers(): |
| | | if device_id != "-1" and get_device() == "GPU" and cuda_ep in get_available_providers(): |
| | | EP_list = [(cuda_ep, cuda_provider_options)] |
| | | EP_list.append((cpu_ep, cpu_provider_options)) |
| | | |
| | | self._verify_model(model_file) |
| | | self.session = InferenceSession(model_file, |
| | | sess_options=sess_opt, |
| | | providers=EP_list) |
| | | self.session = InferenceSession(model_file, sess_options=sess_opt, providers=EP_list) |
| | | |
| | | if device_id != "-1" and cuda_ep not in self.session.get_providers(): |
| | | warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n' |
| | | 'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, ' |
| | | 'you can check their relations from the offical web site: ' |
| | | 'https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html', |
| | | RuntimeWarning) |
| | | warnings.warn( |
| | | f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n" |
| | | "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, " |
| | | "you can check their relations from the offical web site: " |
| | | "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html", |
| | | RuntimeWarning, |
| | | ) |
| | | |
| | | def __call__(self, |
| | | input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray: |
| | | def __call__(self, input_content: List[Union[np.ndarray, np.ndarray]], run_options = None) -> np.ndarray: |
| | | input_dict = dict(zip(self.get_input_names(), input_content)) |
| | | try: |
| | | return self.session.run(self.get_output_names(), input_dict) |
| | | return self.session.run(self.get_output_names(), input_dict, run_options) |
| | | except Exception as e: |
| | | raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e |
| | | raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e |
| | | |
| | | def get_input_names(self, ): |
| | | def get_input_names( |
| | | self, |
| | | ): |
| | | return [v.name for v in self.session.get_inputs()] |
| | | |
| | | def get_output_names(self,): |
| | | def get_output_names( |
| | | self, |
| | | ): |
| | | return [v.name for v in self.session.get_outputs()] |
| | | |
| | | def get_character_list(self, key: str = 'character'): |
| | | def get_character_list(self, key: str = "character"): |
| | | return self.meta_dict[key].splitlines() |
| | | |
| | | def have_key(self, key: str = 'character') -> bool: |
| | | def have_key(self, key: str = "character") -> bool: |
| | | self.meta_dict = self.session.get_modelmeta().custom_metadata_map |
| | | if key in self.meta_dict.keys(): |
| | | return True |
| | |
| | | def _verify_model(model_path): |
| | | model_path = Path(model_path) |
| | | if not model_path.exists(): |
| | | raise FileNotFoundError(f'{model_path} does not exists.') |
| | | raise FileNotFoundError(f"{model_path} does not exists.") |
| | | if not model_path.is_file(): |
| | | raise FileExistsError(f'{model_path} is not a file.') |
| | | raise FileExistsError(f"{model_path} is not a file.") |
| | | |
| | | |
| | | def split_to_mini_sentence(words: list, word_limit: int = 20): |
| | | assert word_limit > 1 |
| | |
| | | length = len(words) |
| | | sentence_len = length // word_limit |
| | | for i in range(sentence_len): |
| | | sentences.append(words[i * word_limit:(i + 1) * word_limit]) |
| | | sentences.append(words[i * word_limit : (i + 1) * word_limit]) |
| | | if length % word_limit > 0: |
| | | sentences.append(words[sentence_len * word_limit:]) |
| | | sentences.append(words[sentence_len * word_limit :]) |
| | | return sentences |
| | | |
| | | |
| | | def code_mix_split_words(text: str): |
| | | words = [] |
| | |
| | | words.append(current_word) |
| | | return words |
| | | |
| | | def isEnglish(text:str): |
| | | if re.search('^[a-zA-Z\']+$', text): |
| | | |
| | | def isEnglish(text: str): |
| | | if re.search("^[a-zA-Z']+$", text): |
| | | return True |
| | | else: |
| | | return False |
| | | |
| | | |
| | | def join_chinese_and_english(input_list): |
| | | line = '' |
| | | line = "" |
| | | for token in input_list: |
| | | if isEnglish(token): |
| | | line = line + ' ' + token |
| | | line = line + " " + token |
| | | else: |
| | | line = line + token |
| | | |
| | | line = line.strip() |
| | | return line |
| | | |
| | | |
| | | def code_mix_split_words_jieba(seg_dict_file: str): |
| | | jieba.load_userdict(seg_dict_file) |
| | |
| | | token_list_tmp = [] |
| | | language_flag = None |
| | | for token in input_list: |
| | | if isEnglish(token) and language_flag == 'Chinese': |
| | | if isEnglish(token) and language_flag == "Chinese": |
| | | token_list_all.append(token_list_tmp) |
| | | langauge_list.append('Chinese') |
| | | langauge_list.append("Chinese") |
| | | token_list_tmp = [] |
| | | elif not isEnglish(token) and language_flag == 'English': |
| | | elif not isEnglish(token) and language_flag == "English": |
| | | token_list_all.append(token_list_tmp) |
| | | langauge_list.append('English') |
| | | langauge_list.append("English") |
| | | token_list_tmp = [] |
| | | |
| | | |
| | | token_list_tmp.append(token) |
| | | |
| | | |
| | | if isEnglish(token): |
| | | language_flag = 'English' |
| | | language_flag = "English" |
| | | else: |
| | | language_flag = 'Chinese' |
| | | |
| | | language_flag = "Chinese" |
| | | |
| | | if token_list_tmp: |
| | | token_list_all.append(token_list_tmp) |
| | | langauge_list.append(language_flag) |
| | | |
| | | |
| | | result_list = [] |
| | | for token_list_tmp, language_flag in zip(token_list_all, langauge_list): |
| | | if language_flag == 'English': |
| | | if language_flag == "English": |
| | | result_list.extend(token_list_tmp) |
| | | else: |
| | | seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False) |
| | | result_list.extend(seg_list) |
| | | |
| | | |
| | | return result_list |
| | | |
| | | return _fn |
| | | |
| | | |
| | | def read_yaml(yaml_path: Union[str, Path]) -> Dict: |
| | | if not Path(yaml_path).exists(): |
| | | raise FileExistsError(f'The {yaml_path} does not exist.') |
| | | raise FileExistsError(f"The {yaml_path} does not exist.") |
| | | |
| | | with open(str(yaml_path), 'rb') as f: |
| | | with open(str(yaml_path), "rb") as f: |
| | | data = yaml.load(f, Loader=yaml.Loader) |
| | | return data |
| | | |
| | | |
| | | @functools.lru_cache() |
| | | def get_logger(name='funasr_onnx'): |
| | | def get_logger(name="funasr_onnx"): |
| | | """Initialize and get a logger by name. |
| | | If the logger has not been initialized, this method will initialize the |
| | | logger by adding one or two handlers, otherwise the initialized logger will |
| | |
| | | return logger |
| | | |
| | | formatter = logging.Formatter( |
| | | '[%(asctime)s] %(name)s %(levelname)s: %(message)s', |
| | | datefmt="%Y/%m/%d %H:%M:%S") |
| | | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S" |
| | | ) |
| | | |
| | | sh = logging.StreamHandler() |
| | | sh.setFormatter(formatter) |