From 487189b949a2edad084e8275b35b72f324ba5218 Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: Mon, 3 Jun 2024 15:52:20 +0800
Subject: [PATCH] bug fix

---
 runtime/python/libtorch/funasr_torch/paraformer_bin.py |    9 ++++++---
 runtime/python/libtorch/demo_seaco_paraformer.py       |    2 +-
 runtime/python/libtorch/demo_contextual_paraformer.py  |    2 +-
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/runtime/python/libtorch/demo_contextual_paraformer.py b/runtime/python/libtorch/demo_contextual_paraformer.py
index 06c0f76..306981c 100644
--- a/runtime/python/libtorch/demo_contextual_paraformer.py
+++ b/runtime/python/libtorch/demo_contextual_paraformer.py
@@ -7,7 +7,7 @@
 model = ContextualParaformer(model_dir, batch_size=1, device_id=device_id)  # gpu
 
 wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
-hotwords = "浣犵殑鐑瘝 榄斿搾"
+hotwords = "浣犵殑鐑瘝 榄旀惌"
 
 result = model(wav_path, hotwords)
 print(result)
diff --git a/runtime/python/libtorch/demo_seaco_paraformer.py b/runtime/python/libtorch/demo_seaco_paraformer.py
index d54dce5..ad28bfe 100644
--- a/runtime/python/libtorch/demo_seaco_paraformer.py
+++ b/runtime/python/libtorch/demo_seaco_paraformer.py
@@ -7,7 +7,7 @@
 model = SeacoParaformer(model_dir, batch_size=1, device_id=device_id)  # gpu
 
 wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
-hotwords = "浣犵殑鐑瘝 榄斿搾"
+hotwords = "浣犵殑鐑瘝 榄旀惌"
 
 result = model(wav_path, hotwords)
 print(result)
diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index 9f35db7..755237e 100644
--- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -316,7 +316,10 @@
     ) -> List:
         # make hotword list
         hotwords, hotwords_length = self.proc_hotword(hotwords)
-        bias_embed = self.eb_infer(torch.Tensor(hotwords))
+        if int(self.device_id) != -1:
+            bias_embed = self.eb_infer(hotwords.cuda())
+        else:
+            bias_embed = self.eb_infer(hotwords)
         # index from bias_embed
         bias_embed = torch.transpose(bias_embed, 0, 1)
         _ind = np.arange(0, len(hotwords)).tolist()
@@ -334,7 +337,7 @@
                         outputs = self.bb_infer(feats, feats_len, bias_embed)
                         am_scores, valid_token_lens = outputs[0], outputs[1]
                     else:
-                        outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda(), bias_embed)
+                        outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda(), bias_embed.cuda())
                         am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
             except:
                 # logging.warning(traceback.format_exc())
@@ -369,7 +372,7 @@
         hotword_int = [word_map(i) for i in hotwords]
         hotword_int.append(np.array([1]))
         hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
-        return hotwords, hotwords_length
+        return torch.tensor(hotwords), hotwords_length
 
     def bb_infer(
         self, feats, feats_len, bias_embed

--
Gitblit v1.9.1