游雁
2023-02-13 865ae89f0a713f70dda16859638b25e7350275ec
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
File was renamed from funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py
@@ -1,6 +1,7 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import os.path
import traceback
from pathlib import Path
from typing import List, Union, Tuple
@@ -11,25 +12,33 @@
from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                    OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
                    read_yaml)
from .postprocess_utils import sentence_postprocess
logging = get_logger()
class RapidParaformer():
    """Object constructed from the path of a configuration file.

    The constructor only verifies that the given path exists.
    """

    def __init__(self, config_path: Union[str, Path]) -> None:
        """Check that ``config_path`` refers to an existing filesystem entry.

        Args:
            config_path: Location of the configuration file.

        Raises:
            FileNotFoundError: If nothing exists at ``config_path``.
        """
        config_exists = Path(config_path).exists()
        if config_exists:
            return
        raise FileNotFoundError(f'{config_path} does not exist.')
class Paraformer():
    def __init__(self, model_dir: Union[str, Path]=None,
                 batch_size: int = 1,
                 device_id: Union[str, int]="-1",
                 ):
        if not Path(model_dir).exists():
            raise FileNotFoundError(f'{model_dir} does not exist.')
        config = read_yaml(config_path)
        model_file = os.path.join(model_dir, 'model.onnx')
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
        self.converter = TokenIDConverter(**config['TokenIDConverter'])
        self.tokenizer = CharTokenizer(**config['CharTokenizer'])
        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=config['WavFrontend']['cmvn_file'],
            **config['WavFrontend']['frontend_conf']
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer = OrtInferSession(config['Model'])
        self.batch_size = config['Model']['batch_size']
        self.ort_infer = OrtInferSession(model_file, device_id)
        self.batch_size = batch_size
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
        waveform_list = self.load_data(wav_content)
@@ -124,16 +133,19 @@
        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        text = self.tokenizer.tokens2text(token)
        token = token[:valid_token_num-1]
        texts = sentence_postprocess(token)
        text = texts[0]
        # text = self.tokenizer.tokens2text(token)
        return text[:valid_token_num-1]
if __name__ == '__main__':
    # Manual smoke test: load a model directory and transcribe the example
    # recording stored inside it.
    # NOTE(review): model_dir is a developer-local absolute path; point it
    # at your own exported model directory before running.
    model_dir = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
    model = Paraformer(model_dir)

    wav_file = os.path.join(model_dir, 'example/asr_example.wav')
    result = model(wav_file)
    print(result)