From c73d1a8e81582b91a9bdd6e82fce2e84f8d9d94b Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 14 八月 2023 19:31:55 +0800
Subject: [PATCH] update func cif_wo_hidden
---
funasr/export/export_model.py | 21 ++++++++++++---------
1 files changed, 12 insertions(+), 9 deletions(-)
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index f31f960..8c3108b 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -55,18 +55,21 @@
# export encoder1
self.export_config["model_name"] = "model"
- model = get_model(
+ models = get_model(
model,
self.export_config,
)
- model.eval()
- # self._export_onnx(model, verbose, export_dir)
- if self.onnx:
- self._export_onnx(model, verbose, export_dir)
- else:
- self._export_torchscripts(model, verbose, export_dir)
-
- print("output dir: {}".format(export_dir))
+ if not isinstance(models, tuple):
+ models = (models,)
+
+ for i, model in enumerate(models):
+ model.eval()
+ if self.onnx:
+ self._export_onnx(model, verbose, export_dir)
+ else:
+ self._export_torchscripts(model, verbose, export_dir)
+
+ print("output dir: {}".format(export_dir))
def _torch_quantize(self, model):
--
Gitblit v1.9.1