From 865ae89f0a713f70dda16859638b25e7350275ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 13 Feb 2023 17:43:01 +0800
Subject: [PATCH] export model
---
funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py | 48 ++++++++++++++++++++++++++++++------------------
1 file changed, 30 insertions(+), 18 deletions(-)
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
similarity index 75%
rename from funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py
rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
index 10bfa8a..1fc3582 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py
+++ b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
@@ -1,6 +1,7 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
+import os.path
import traceback
from pathlib import Path
from typing import List, Union, Tuple
@@ -11,25 +12,33 @@
from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
read_yaml)
+from .postprocess_utils import sentence_postprocess
logging = get_logger()
-class RapidParaformer():
- def __init__(self, config_path: Union[str, Path]) -> None:
- if not Path(config_path).exists():
- raise FileNotFoundError(f'{config_path} does not exist.')
+class Paraformer():
+ def __init__(self, model_dir: Union[str, Path]=None,
+ batch_size: int = 1,
+ device_id: Union[str, int]="-1",
+ ):
+
+ if not Path(model_dir).exists():
+ raise FileNotFoundError(f'{model_dir} does not exist.')
- config = read_yaml(config_path)
+ model_file = os.path.join(model_dir, 'model.onnx')
+ config_file = os.path.join(model_dir, 'config.yaml')
+ cmvn_file = os.path.join(model_dir, 'am.mvn')
+ config = read_yaml(config_file)
- self.converter = TokenIDConverter(**config['TokenIDConverter'])
- self.tokenizer = CharTokenizer(**config['CharTokenizer'])
+ self.converter = TokenIDConverter(config['token_list'])
+ self.tokenizer = CharTokenizer()
self.frontend = WavFrontend(
- cmvn_file=config['WavFrontend']['cmvn_file'],
- **config['WavFrontend']['frontend_conf']
+ cmvn_file=cmvn_file,
+ **config['frontend_conf']
)
- self.ort_infer = OrtInferSession(config['Model'])
- self.batch_size = config['Model']['batch_size']
+ self.ort_infer = OrtInferSession(model_file, device_id)
+ self.batch_size = batch_size
def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
waveform_list = self.load_data(wav_content)
@@ -124,16 +133,19 @@
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
- text = self.tokenizer.tokens2text(token)
+ token = token[:valid_token_num-1]
+ texts = sentence_postprocess(token)
+ text = texts[0]
+ # text = self.tokenizer.tokens2text(token)
return text[:valid_token_num-1]
if __name__ == '__main__':
project_dir = Path(__file__).resolve().parent.parent
- cfg_path = project_dir / 'resources' / 'config.yaml'
- paraformer = RapidParaformer(cfg_path)
+ model_dir = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+ model = Paraformer(model_dir)
- wav_file = '0478_00017.wav'
- for i in range(1000):
- result = paraformer(wav_file)
- print(result)
+ wav_file = os.path.join(model_dir, 'example/asr_example.wav')
+ result = model(wav_file)
+ print(result)
+
--
Gitblit v1.9.1