From d2e9bf01426af61dce7fdc4449906755524eb253 Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期一, 03 六月 2024 15:47:00 +0800
Subject: [PATCH] update model class

---
 runtime/python/libtorch/funasr_torch/paraformer_bin.py |   21 ++++++++++-----------
 runtime/python/libtorch/demo_contextual_paraformer.py  |    2 +-
 2 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/runtime/python/libtorch/demo_contextual_paraformer.py b/runtime/python/libtorch/demo_contextual_paraformer.py
index fb2337d..06c0f76 100644
--- a/runtime/python/libtorch/demo_contextual_paraformer.py
+++ b/runtime/python/libtorch/demo_contextual_paraformer.py
@@ -7,7 +7,7 @@
 model = ContextualParaformer(model_dir, batch_size=1, device_id=device_id)  # gpu
 
 wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
-hotwords = "浣犵殑鐑瘝 榄旀惌 杈炬懇鑻�"
+hotwords = "浣犵殑鐑瘝 榄斿搾"
 
 result = model(wav_path, hotwords)
 print(result)
diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index ca96b47..9f35db7 100644
--- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -282,7 +282,7 @@
                 raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
             model = AutoModel(model=model_dir)
-            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
+            model_dir = model.export(type="torchscripts", quantize=quantize, **kwargs)
 
         config_file = os.path.join(model_dir, "config.yaml")
         cmvn_file = os.path.join(model_dir, "am.mvn")
@@ -316,9 +316,9 @@
     ) -> List:
         # make hotword list
         hotwords, hotwords_length = self.proc_hotword(hotwords)
-        [bias_embed] = self.eb_infer(torch.Tensor(hotwords), torch.Tensor(hotwords_length))
+        bias_embed = self.eb_infer(torch.Tensor(hotwords))
         # index from bias_embed
-        bias_embed = bias_embed.transpose(1, 0, 2)
+        bias_embed = torch.transpose(bias_embed, 0, 1)
         _ind = np.arange(0, len(hotwords)).tolist()
         bias_embed = bias_embed[_ind, hotwords_length.tolist()]
         waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@@ -327,15 +327,14 @@
         for beg_idx in range(0, waveform_nums, self.batch_size):
             end_idx = min(waveform_nums, beg_idx + self.batch_size)
             feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
-            bias_embed = np.expand_dims(bias_embed, axis=0)
-            bias_embed = np.repeat(bias_embed, feats.shape[0], axis=0)
+            bias_embed = torch.unsqueeze(bias_embed, 0).repeat(feats.shape[0], 1, 1)
             try:
                 with torch.no_grad():
                     if int(self.device_id) == -1:
-                        outputs = self.bb_infer(feats, feats_len)
+                        outputs = self.bb_infer(feats, feats_len, bias_embed)
                         am_scores, valid_token_lens = outputs[0], outputs[1]
                     else:
-                        outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda())
+                        outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda(), bias_embed)
                         am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
             except:
                 # logging.warning(traceback.format_exc())
@@ -374,12 +373,12 @@
 
     def bb_infer(
         self, feats, feats_len, bias_embed
-    ) -> Tuple[np.ndarray, np.ndarray]:
-        outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
+    ):
+        outputs = self.ort_infer_bb(feats, feats_len, bias_embed)
         return outputs
 
-    def eb_infer(self, hotwords, hotwords_length):
-        outputs = self.ort_infer_eb([hotwords, hotwords_length])
+    def eb_infer(self, hotwords):
+        outputs = self.ort_infer_eb(hotwords.long())
         return outputs
 
     def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:

--
Gitblit v1.9.1