From 2ae59b6ce06305724e2eaf30b9f9e93447a7832e Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: Mon, 22 Jul 2024 16:58:27 +0800
Subject: [PATCH] ONNX and torchscript export for sensevoice
---
funasr/utils/export_utils.py | 44 +++++++++++++++++++++++++-------------------
1 file changed, 25 insertions(+), 19 deletions(-)
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index a6d0798..af9f37b 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -54,7 +54,10 @@
verbose = kwargs.get("verbose", False)
- export_name = model.export_name + ".onnx"
+ if isinstance(model.export_name, str):
+ export_name = model.export_name + ".onnx"
+ else:
+ export_name = model.export_name()
model_path = os.path.join(export_dir, export_name)
torch.onnx.export(
model,
@@ -72,35 +75,38 @@
import onnx
quant_model_path = model_path.replace(".onnx", "_quant.onnx")
- if not os.path.exists(quant_model_path):
- onnx_model = onnx.load(model_path)
- nodes = [n.name for n in onnx_model.graph.node]
- nodes_to_exclude = [
- m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
- ]
- quantize_dynamic(
- model_input=model_path,
- model_output=quant_model_path,
- op_types_to_quantize=["MatMul"],
- per_channel=True,
- reduce_range=False,
- weight_type=QuantType.QUInt8,
- nodes_to_exclude=nodes_to_exclude,
- )
+ onnx_model = onnx.load(model_path)
+ nodes = [n.name for n in onnx_model.graph.node]
+ nodes_to_exclude = [
+ m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
+ ]
+ print("Quantizing model from {} to {}".format(model_path, quant_model_path))
+ quantize_dynamic(
+ model_input=model_path,
+ model_output=quant_model_path,
+ op_types_to_quantize=["MatMul"],
+ per_channel=True,
+ reduce_range=False,
+ weight_type=QuantType.QUInt8,
+ nodes_to_exclude=nodes_to_exclude,
+ )
def _torchscripts(model, path, device="cuda"):
dummy_input = model.export_dummy_inputs()
-
+
if device == "cuda":
model = model.cuda()
if isinstance(dummy_input, torch.Tensor):
dummy_input = dummy_input.cuda()
else:
dummy_input = tuple([i.cuda() for i in dummy_input])
-
+
model_script = torch.jit.trace(model, dummy_input)
- model_script.save(os.path.join(path, f"{model.export_name}.torchscript"))
+ if isinstance(model.export_name, str):
+ model_script.save(os.path.join(path, f"{model.export_name}".replace("onnx", "torchscript")))
+ else:
+ model_script.save(os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript")))
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
--
Gitblit v1.9.1