| File was renamed from funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py |
| | |
| | | # -*- encoding: utf-8 -*- |
| | | # @Author: SWHL |
| | | # @Contact: liekkaskono@163.com |
| | | import os.path |
| | | import traceback |
| | | from pathlib import Path |
| | | from typing import List, Union, Tuple |
| | |
| | | from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError, |
| | | OrtInferSession, TokenIDConverter, WavFrontend, get_logger, |
| | | read_yaml) |
| | | from .postprocess_utils import sentence_postprocess |
| | | |
| | | logging = get_logger() |
| | | |
| | | |
| | | class RapidParaformer(): |
| | | def __init__(self, config_path: Union[str, Path]) -> None: |
| | | if not Path(config_path).exists(): |
| | | raise FileNotFoundError(f'{config_path} does not exist.') |
| | | class Paraformer(): |
| | | def __init__(self, model_dir: Union[str, Path]=None, |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int]="-1", |
| | | ): |
| | | |
| | | if not Path(model_dir).exists(): |
| | | raise FileNotFoundError(f'{model_dir} does not exist.') |
| | | |
| | | config = read_yaml(config_path) |
| | | model_file = os.path.join(model_dir, 'model.onnx') |
| | | config_file = os.path.join(model_dir, 'config.yaml') |
| | | cmvn_file = os.path.join(model_dir, 'am.mvn') |
| | | config = read_yaml(config_file) |
| | | |
| | | self.converter = TokenIDConverter(**config['TokenIDConverter']) |
| | | self.tokenizer = CharTokenizer(**config['CharTokenizer']) |
| | | self.converter = TokenIDConverter(config['token_list']) |
| | | self.tokenizer = CharTokenizer() |
| | | self.frontend = WavFrontend( |
| | | cmvn_file=config['WavFrontend']['cmvn_file'], |
| | | **config['WavFrontend']['frontend_conf'] |
| | | cmvn_file=cmvn_file, |
| | | **config['frontend_conf'] |
| | | ) |
| | | self.ort_infer = OrtInferSession(config['Model']) |
| | | self.batch_size = config['Model']['batch_size'] |
| | | self.ort_infer = OrtInferSession(model_file, device_id) |
| | | self.batch_size = batch_size |
| | | |
| | | def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List: |
| | | waveform_list = self.load_data(wav_content) |
| | |
| | | |
| | | # Change integer-ids to tokens |
| | | token = self.converter.ids2tokens(token_int) |
| | | text = self.tokenizer.tokens2text(token) |
| | | token = token[:valid_token_num-1] |
| | | texts = sentence_postprocess(token) |
| | | text = texts[0] |
| | | # text = self.tokenizer.tokens2text(token) |
| | | return text[:valid_token_num-1] |
| | | |
| | | |
| | | if __name__ == '__main__': |
| | | project_dir = Path(__file__).resolve().parent.parent |
| | | cfg_path = project_dir / 'resources' / 'config.yaml' |
| | | paraformer = RapidParaformer(cfg_path) |
| | | model_dir = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" |
| | | model = Paraformer(model_dir) |
| | | |
| | | wav_file = '0478_00017.wav' |
| | | for i in range(1000): |
| | | result = paraformer(wav_file) |
| | | print(result) |
| | | wav_file = os.path.join(model_dir, 'example/asr_example.wav') |
| | | result = model(wav_file) |
| | | print(result) |
| | | |