From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 runtime/python/libtorch/funasr_torch/paraformer_bin.py |   25 ++++++++++++++-----------
 1 files changed, 14 insertions(+), 11 deletions(-)

diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index 9f35db7..16c0406 100644
--- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -46,11 +46,11 @@
                     model_dir
                 )
 
-        model_file = os.path.join(model_dir, "model.torchscripts")
+        model_file = os.path.join(model_dir, "model.torchscript")
         if quantize:
-            model_file = os.path.join(model_dir, "model_quant.torchscripts")
+            model_file = os.path.join(model_dir, "model_quant.torchscript")
         if not os.path.exists(model_file):
-            print(".torchscripts does not exist, begin to export torchscripts")
+            print(".torchscripts does not exist, begin to export torchscript")
             try:
                 from funasr import AutoModel
             except:
@@ -268,11 +268,11 @@
                 )
 
         if quantize:
-            model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscripts")
-            model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscripts")
+            model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscript")
+            model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscript")
         else:
-            model_bb_file = os.path.join(model_dir, "model_bb.torchscripts")
-            model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
+            model_bb_file = os.path.join(model_dir, "model_bb.torchscript")
+            model_eb_file = os.path.join(model_dir, "model_eb.torchscript")
 
         if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
             print(".onnx does not exist, begin to export onnx")
@@ -282,7 +282,7 @@
                 raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
             model = AutoModel(model=model_dir)
-            model_dir = model.export(type="torchscripts", quantize=quantize, **kwargs)
+            model_dir = model.export(type="torchscript", quantize=quantize, **kwargs)
 
         config_file = os.path.join(model_dir, "config.yaml")
         cmvn_file = os.path.join(model_dir, "am.mvn")
@@ -316,7 +316,10 @@
     ) -> List:
         # make hotword list
         hotwords, hotwords_length = self.proc_hotword(hotwords)
-        bias_embed = self.eb_infer(torch.Tensor(hotwords))
+        if int(self.device_id) != -1:
+            bias_embed = self.eb_infer(hotwords.cuda())
+        else:
+            bias_embed = self.eb_infer(hotwords)
         # index from bias_embed
         bias_embed = torch.transpose(bias_embed, 0, 1)
         _ind = np.arange(0, len(hotwords)).tolist()
@@ -334,7 +337,7 @@
                         outputs = self.bb_infer(feats, feats_len, bias_embed)
                         am_scores, valid_token_lens = outputs[0], outputs[1]
                     else:
-                        outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda(), bias_embed)
+                        outputs = self.bb_infer(feats.cuda(), feats_len.cuda(), bias_embed.cuda())
                         am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
             except:
                 # logging.warning(traceback.format_exc())
@@ -369,7 +372,7 @@
         hotword_int = [word_map(i) for i in hotwords]
         hotword_int.append(np.array([1]))
         hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
-        return hotwords, hotwords_length
+        return torch.tensor(hotwords), hotwords_length
 
     def bb_infer(
         self, feats, feats_len, bias_embed

--
Gitblit v1.9.1