From 487189b949a2edad084e8275b35b72f324ba5218 Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: Mon, 03 Jun 2024 15:52:20 +0800
Subject: [PATCH] bug fix
---
runtime/python/libtorch/funasr_torch/paraformer_bin.py | 9 ++++++---
1 file changed, 6 insertions(+), 3 deletions(-)
diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index 9f35db7..755237e 100644
--- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -316,7 +316,10 @@
) -> List:
# make hotword list
hotwords, hotwords_length = self.proc_hotword(hotwords)
- bias_embed = self.eb_infer(torch.Tensor(hotwords))
+ if int(self.device_id) != -1:
+ bias_embed = self.eb_infer(hotwords.cuda())
+ else:
+ bias_embed = self.eb_infer(hotwords)
# index from bias_embed
bias_embed = torch.transpose(bias_embed, 0, 1)
_ind = np.arange(0, len(hotwords)).tolist()
@@ -334,7 +337,7 @@
outputs = self.bb_infer(feats, feats_len, bias_embed)
am_scores, valid_token_lens = outputs[0], outputs[1]
else:
- outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda(), bias_embed)
+ outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda(), bias_embed.cuda())
am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
except:
# logging.warning(traceback.format_exc())
@@ -369,7 +372,7 @@
hotword_int = [word_map(i) for i in hotwords]
hotword_int.append(np.array([1]))
hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
- return hotwords, hotwords_length
+ return torch.tensor(hotwords), hotwords_length
def bb_infer(
self, feats, feats_len, bias_embed
--
Gitblit v1.9.1