From 865ae89f0a713f70dda16859638b25e7350275ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 13 Feb 2023 17:43:01 +0800
Subject: [PATCH] export model

---
 funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py |   48 ++++++++++++++++++++++++++++++------------------
 1 file changed, 30 insertions(+), 18 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
similarity index 75%
rename from funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py
rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
index 10bfa8a..1fc3582 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py
+++ b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
@@ -1,6 +1,7 @@
 # -*- encoding: utf-8 -*-
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
+import os.path
 import traceback
 from pathlib import Path
 from typing import List, Union, Tuple
@@ -11,25 +12,33 @@
 from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
                     OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
                     read_yaml)
+from .postprocess_utils import sentence_postprocess
 
 logging = get_logger()
 
 
-class RapidParaformer():
-    def __init__(self, config_path: Union[str, Path]) -> None:
-        if not Path(config_path).exists():
-            raise FileNotFoundError(f'{config_path} does not exist.')
+class Paraformer():
+    def __init__(self, model_dir: Union[str, Path]=None,
+                 batch_size: int = 1,
+                 device_id: Union[str, int]="-1",
+                 ):
+        
+        if not Path(model_dir).exists():
+            raise FileNotFoundError(f'{model_dir} does not exist.')
 
-        config = read_yaml(config_path)
+        model_file = os.path.join(model_dir, 'model.onnx')
+        config_file = os.path.join(model_dir, 'config.yaml')
+        cmvn_file = os.path.join(model_dir, 'am.mvn')
+        config = read_yaml(config_file)
 
-        self.converter = TokenIDConverter(**config['TokenIDConverter'])
-        self.tokenizer = CharTokenizer(**config['CharTokenizer'])
+        self.converter = TokenIDConverter(config['token_list'])
+        self.tokenizer = CharTokenizer()
         self.frontend = WavFrontend(
-            cmvn_file=config['WavFrontend']['cmvn_file'],
-            **config['WavFrontend']['frontend_conf']
+            cmvn_file=cmvn_file,
+            **config['frontend_conf']
         )
-        self.ort_infer = OrtInferSession(config['Model'])
-        self.batch_size = config['Model']['batch_size']
+        self.ort_infer = OrtInferSession(model_file, device_id)
+        self.batch_size = batch_size
 
     def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
         waveform_list = self.load_data(wav_content)
@@ -124,16 +133,19 @@
 
         # Change integer-ids to tokens
         token = self.converter.ids2tokens(token_int)
-        text = self.tokenizer.tokens2text(token)
+        token = token[:valid_token_num-1]
+        texts = sentence_postprocess(token)
+        text = texts[0]
+        # text = self.tokenizer.tokens2text(token)
         return text[:valid_token_num-1]
 
 
 if __name__ == '__main__':
     project_dir = Path(__file__).resolve().parent.parent
-    cfg_path = project_dir / 'resources' / 'config.yaml'
-    paraformer = RapidParaformer(cfg_path)
+    model_dir = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+    model = Paraformer(model_dir)
 
-    wav_file = '0478_00017.wav'
-    for i in range(1000):
-        result = paraformer(wav_file)
-        print(result)
+    wav_file = os.path.join(model_dir, 'example/asr_example.wav')
+    result = model(wav_file)
+    print(result)
+

--
Gitblit v1.9.1