From 2ccba92cd82dc81ef887f477480011c087e38182 Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期五, 10 三月 2023 17:40:56 +0800
Subject: [PATCH] update unittest
---
funasr/export/export_model.py | 11 +++++++----
1 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index 3c73152..3cbf6d2 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -7,10 +7,12 @@
import logging
import torch
-from funasr.bin.asr_inference_paraformer import Speech2Text
from funasr.export.models import get_model
import numpy as np
import random
+
+# torch_version = float(".".join(torch.__version__.split(".")[:2]))
+# assert torch_version > 1.9
class ASRModelExportParaformer:
def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True):
@@ -30,7 +32,7 @@
def _export(
self,
- model: Speech2Text,
+ model,
tag_name: str = None,
verbose: bool = False,
):
@@ -58,7 +60,7 @@
if enc_size:
dummy_input = model.get_dummy_inputs(enc_size)
else:
- dummy_input = model.get_dummy_inputs_txt()
+ dummy_input = model.get_dummy_inputs()
# model_script = torch.jit.script(model)
model_script = torch.jit.trace(model, dummy_input)
@@ -111,12 +113,13 @@
dummy_input,
os.path.join(path, f'{model.model_name}.onnx'),
verbose=verbose,
- opset_version=12,
+ opset_version=14,
input_names=model.get_input_names(),
output_names=model.get_output_names(),
dynamic_axes=model.get_dynamic_axes()
)
+
if __name__ == '__main__':
import sys
--
Gitblit v1.9.1