| File was renamed from funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py |
| | |
| | | from typeguard import check_argument_types |
| | | |
| | | from .kaldifeat import compute_fbank_feats |
| | | import warnings |
| | | |
| | | root_dir = Path(__file__).resolve().parent |
| | | |
| | |
| | | |
| | | |
| | | class TokenIDConverter(): |
| | | def __init__(self, token_path: Union[Path, str], |
| | | def __init__(self, token_list: Union[Path, str], |
| | | unk_symbol: str = "<unk>",): |
| | | check_argument_types() |
| | | |
| | | self.token_list = self.load_token(token_path) |
| | | self.unk_symbol = unk_symbol |
| | | # self.token_list = self.load_token(token_path) |
| | | self.token_list = token_list |
| | | self.unk_symbol = token_list[-1] |
| | | |
| | | @staticmethod |
| | | def load_token(file_path: Union[Path, str]) -> List: |
| | | if not Path(file_path).exists(): |
| | | raise TokenIDConverterError(f'The {file_path} does not exist.') |
| | | |
| | | with open(str(file_path), 'rb') as f: |
| | | token_list = pickle.load(f) |
| | | |
| | | if len(token_list) != len(set(token_list)): |
| | | raise TokenIDConverterError('The Token exists duplicated symbol.') |
| | | return token_list |
| | | # @staticmethod |
| | | # def load_token(file_path: Union[Path, str]) -> List: |
| | | # if not Path(file_path).exists(): |
| | | # raise TokenIDConverterError(f'The {file_path} does not exist.') |
| | | # |
| | | # with open(str(file_path), 'rb') as f: |
| | | # token_list = pickle.load(f) |
| | | # |
| | | # if len(token_list) != len(set(token_list)): |
| | | # raise TokenIDConverterError('The Token exists duplicated symbol.') |
| | | # return token_list |
| | | |
| | | def get_num_vocabulary_size(self) -> int: |
| | | return len(self.token_list) |
| | |
| | | |
| | | |
| | | class OrtInferSession(): |
| | | def __init__(self, config): |
| | | def __init__(self, model_file, device_id=-1): |
| | | sess_opt = SessionOptions() |
| | | sess_opt.log_severity_level = 4 |
| | | sess_opt.enable_cpu_mem_arena = False |
| | | sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL |
| | | |
| | | cuda_ep = 'CUDAExecutionProvider' |
| | | cuda_provider_options = { |
| | | "device_id": device_id, |
| | | "arena_extend_strategy": "kNextPowerOfTwo", |
| | | "cudnn_conv_algo_search": "EXHAUSTIVE", |
| | | "do_copy_in_default_stream": "true", |
| | | } |
| | | cpu_ep = 'CPUExecutionProvider' |
| | | cpu_provider_options = { |
| | | "arena_extend_strategy": "kSameAsRequested", |
| | | } |
| | | |
| | | EP_list = [] |
| | | if config['use_cuda'] and get_device() == 'GPU' \ |
| | | if device_id != -1 and get_device() == 'GPU' \ |
| | | and cuda_ep in get_available_providers(): |
| | | EP_list = [(cuda_ep, config[cuda_ep])] |
| | | EP_list = [(cuda_ep, cuda_provider_options)] |
| | | EP_list.append((cpu_ep, cpu_provider_options)) |
| | | |
| | | config['model_path'] = config['model_path'] |
| | | self._verify_model(config['model_path']) |
| | | self.session = InferenceSession(config['model_path'], |
| | | self._verify_model(model_file) |
| | | self.session = InferenceSession(model_file, |
| | | sess_options=sess_opt, |
| | | providers=EP_list) |
| | | |
| | | if config['use_cuda'] and cuda_ep not in self.session.get_providers(): |
| | | if device_id != -1 and cuda_ep not in self.session.get_providers(): |
| | | warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n' |
| | | 'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, ' |
| | | 'you can check their relations from the offical web site: ' |