From fd0992af3d1a2d2d098b1fab24f67f8c4cece39d Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期一, 03 六月 2024 15:32:34 +0800
Subject: [PATCH] update libtorch inference

---
 runtime/python/libtorch/funasr_torch/paraformer_bin.py |   13 ++++++-------
 1 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index b0cd871..ca96b47 100644
--- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -275,7 +275,7 @@
             model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
 
         if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
-            print(".onnx is not exist, begin to export onnx")
+            print(".onnx does not exist, begin to export onnx")
             try:
                 from funasr import AutoModel
             except:
@@ -316,8 +316,7 @@
     ) -> List:
         # make hotword list
         hotwords, hotwords_length = self.proc_hotword(hotwords)
-        # import pdb; pdb.set_trace()
-        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
+        [bias_embed] = self.eb_infer(torch.Tensor(hotwords), torch.Tensor(hotwords_length))
         # index from bias_embed
         bias_embed = bias_embed.transpose(1, 0, 2)
         _ind = np.arange(0, len(hotwords)).tolist()
@@ -333,10 +332,10 @@
             try:
                 with torch.no_grad():
                     if int(self.device_id) == -1:
-                        outputs = self.ort_infer(feats, feats_len)
+                        outputs = self.bb_infer(feats, feats_len)
                         am_scores, valid_token_lens = outputs[0], outputs[1]
                     else:
-                        outputs = self.ort_infer(feats.cuda(), feats_len.cuda())
+                        outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda())
                         am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
             except:
                 # logging.warning(traceback.format_exc())
@@ -374,13 +373,13 @@
         return hotwords, hotwords_length
 
     def bb_infer(
-        self, feats: np.ndarray, feats_len: np.ndarray, bias_embed
+        self, feats, feats_len, bias_embed
     ) -> Tuple[np.ndarray, np.ndarray]:
         outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
         return outputs
 
     def eb_infer(self, hotwords, hotwords_length):
-        outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
+        outputs = self.ort_infer_eb([hotwords, hotwords_length])
         return outputs
 
     def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:

--
Gitblit v1.9.1