From fd22b6e7f36e963ef29dbd3eafb0e0d6f2e12fa7 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期三, 09 八月 2023 14:27:20 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main
---
funasr/export/export_model.py | 21 ++++++++++++---------
1 files changed, 12 insertions(+), 9 deletions(-)
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index f31f960..8c3108b 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -55,18 +55,21 @@
# export encoder1
self.export_config["model_name"] = "model"
- model = get_model(
+ models = get_model(
model,
self.export_config,
)
- model.eval()
- # self._export_onnx(model, verbose, export_dir)
- if self.onnx:
- self._export_onnx(model, verbose, export_dir)
- else:
- self._export_torchscripts(model, verbose, export_dir)
-
- print("output dir: {}".format(export_dir))
+ if not isinstance(models, tuple):
+ models = (models,)
+
+ for i, model in enumerate(models):
+ model.eval()
+ if self.onnx:
+ self._export_onnx(model, verbose, export_dir)
+ else:
+ self._export_torchscripts(model, verbose, export_dir)
+
+ print("output dir: {}".format(export_dir))
def _torch_quantize(self, model):
--
Gitblit v1.9.1