From f57b68121a526baea43b2e93f4540d8a2995f633 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 15:15:24 +0800
Subject: [PATCH] batch

---
 funasr/bin/export.py |   28 ++++++++++++++++------------
 1 files changed, 16 insertions(+), 12 deletions(-)

diff --git a/funasr/bin/export.py b/funasr/bin/export.py
index 7d47664..6c9b49f 100644
--- a/funasr/bin/export.py
+++ b/funasr/bin/export.py
@@ -15,27 +15,31 @@
             return {k: to_plain_list(v) for k, v in cfg_item.items()}
         else:
             return cfg_item
-    
+
     kwargs = to_plain_list(cfg)
     log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
 
     logging.basicConfig(level=log_level)
 
     if kwargs.get("debug", False):
-        import pdb; pdb.set_trace()
+        import pdb
 
+        pdb.set_trace()
 
+    if "device" not in kwargs:
+        kwargs["device"] = "cpu"
     model = AutoModel(**kwargs)
-    
-    res = model.export(input=kwargs.get("input", None),
-                       type=kwargs.get("type", "onnx"),
-                       quantize=kwargs.get("quantize", False),
-                       fallback_num=kwargs.get("fallback-num", 5),
-                       calib_num=kwargs.get("calib_num", 100),
-                       opset_version=kwargs.get("opset_version", 14),
-                       )
+
+    res = model.export(
+        input=kwargs.get("input", None),
+        type=kwargs.get("type", "onnx"),
+        quantize=kwargs.get("quantize", False),
+        fallback_num=kwargs.get("fallback-num", 5),
+        calib_num=kwargs.get("calib_num", 100),
+        opset_version=kwargs.get("opset_version", 14),
+    )
     print(res)
 
 
-if __name__ == '__main__':
-    main_hydra()
\ No newline at end of file
+if __name__ == "__main__":
+    main_hydra()

--
Gitblit v1.9.1