From c2dee5e3c29eba79e591d9e9caebaef15ea4e56b Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 29 六月 2023 11:09:28 +0800
Subject: [PATCH] Merge pull request #687 from alibaba-damo-academy/dev_lhn
---
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 27 +++++++++++++++++++++++++--
1 files changed, 25 insertions(+), 2 deletions(-)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index e3ef8c7..f3e0f3d 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -1,4 +1,6 @@
# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import os.path
from pathlib import Path
@@ -19,20 +21,41 @@
class Paraformer():
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
def __init__(self, model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
quantize: bool = False,
intra_op_num_threads: int = 4,
+ cache_dir: str = None
):
if not Path(model_dir).exists():
- raise FileNotFoundError(f'{model_dir} does not exist.')
-
+ from modelscope.hub.snapshot_download import snapshot_download
+ try:
+ model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+ except:
+ raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
+
model_file = os.path.join(model_dir, 'model.onnx')
if quantize:
model_file = os.path.join(model_dir, 'model_quant.onnx')
+ if not os.path.exists(model_file):
+ print(".onnx is not exist, begin to export onnx")
+ from funasr.export.export_model import ModelExport
+ export_model = ModelExport(
+ cache_dir=cache_dir,
+ onnx=True,
+ device="cpu",
+ quant=quantize,
+ )
+ export_model.export(model_dir)
+
config_file = os.path.join(model_dir, 'config.yaml')
cmvn_file = os.path.join(model_dir, 'am.mvn')
config = read_yaml(config_file)
--
Gitblit v1.9.1