From fd0992af3d1a2d2d098b1fab24f67f8c4cece39d Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: Mon, 3 Jun 2024 15:32:34 +0800
Subject: [PATCH] update libtorch inference
---
funasr/utils/export_utils.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py
index 7d6606b..8f1aa53 100644
--- a/funasr/utils/export_utils.py
+++ b/funasr/utils/export_utils.py
@@ -20,10 +20,12 @@
export_dir=export_dir,
**kwargs
)
- elif type == 'torchscript':
+ elif type == 'torchscripts':
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
_torchscripts(
m,
path=export_dir,
+ device=device
)
print("output dir: {}".format(export_dir))
@@ -78,13 +80,15 @@
)
-def _torchscripts(model, path, device='cpu'):
+def _torchscripts(model, path, device='cuda'):
dummy_input = model.export_dummy_inputs()
if device == 'cuda':
model = model.cuda()
- dummy_input = tuple([i.cuda() for i in dummy_input])
+ if isinstance(dummy_input, torch.Tensor):
+ dummy_input = dummy_input.cuda()
+ else:
+ dummy_input = tuple([i.cuda() for i in dummy_input])
- # model_script = torch.jit.script(model)
model_script = torch.jit.trace(model, dummy_input)
model_script.save(os.path.join(path, f'{model.export_name}.torchscripts'))
--
Gitblit v1.9.1