From cfc4d402093060fe087424b0a6be4e2b2546eae8 Mon Sep 17 00:00:00 2001
From: wanchen.swc <wanchen.swc@alibaba-inc.com>
Date: 星期四, 30 三月 2023 18:15:15 +0800
Subject: [PATCH] [Export] support gpu inference

---
 funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py |    3 +++
 1 files changed, 3 insertions(+), 0 deletions(-)

diff --git a/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index 3c0606d..f9232af 100644
--- a/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/funasr/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -58,6 +58,9 @@
             end_idx = min(waveform_nums, beg_idx + self.batch_size)
             feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
             try:
+                if int(device_id) != -1:
+                    feats = feats.cuda()
+                    feats_len = feats_len.cuda()
                 outputs = self.ort_infer(feats, feats_len)
                 am_scores, valid_token_lens = outputs[0], outputs[1]
                 if len(outputs) == 4:

--
Gitblit v1.9.1