From fd0992af3d1a2d2d098b1fab24f67f8c4cece39d Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期一, 03 六月 2024 15:32:34 +0800
Subject: [PATCH] update libtorch inference

---
 funasr/utils/export_utils.py |   12 ++++++++----
 1 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index 7d6606b..8f1aa53 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -20,10 +20,12 @@
                 export_dir=export_dir,
                 **kwargs
             )
-        elif type == 'torchscript':
+        elif type == 'torchscripts':
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
             _torchscripts(
                 m,
                 path=export_dir,
+                device=device
             )
         print("output dir: {}".format(export_dir))
 
@@ -78,13 +80,15 @@
             )
 
 
-def _torchscripts(model, path, device='cpu'):
+def _torchscripts(model, path, device='cuda'):
     dummy_input = model.export_dummy_inputs()
 
     if device == 'cuda':
         model = model.cuda()
-        dummy_input = tuple([i.cuda() for i in dummy_input])
+        if isinstance(dummy_input, torch.Tensor):
+            dummy_input = dummy_input.cuda()
+        else:
+            dummy_input = tuple([i.cuda() for i in dummy_input])
 
-    # model_script = torch.jit.script(model)
     model_script = torch.jit.trace(model, dummy_input)
     model_script.save(os.path.join(path, f'{model.export_name}.torchscripts'))

--
Gitblit v1.9.1