From fd0992af3d1a2d2d098b1fab24f67f8c4cece39d Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期一, 03 六月 2024 15:32:34 +0800
Subject: [PATCH] update libtorch inference
---
runtime/python/onnxruntime/funasr_onnx/vad_bin.py | 4 ++--
runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 2 +-
runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py | 2 +-
runtime/python/libtorch/funasr_torch/paraformer_bin.py | 13 ++++++-------
runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 2 +-
5 files changed, 11 insertions(+), 12 deletions(-)
diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index b0cd871..ca96b47 100644
--- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -275,7 +275,7 @@
model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
@@ -316,8 +316,7 @@
) -> List:
# make hotword list
hotwords, hotwords_length = self.proc_hotword(hotwords)
- # import pdb; pdb.set_trace()
- [bias_embed] = self.eb_infer(hotwords, hotwords_length)
+ [bias_embed] = self.eb_infer(torch.Tensor(hotwords), torch.Tensor(hotwords_length))
# index from bias_embed
bias_embed = bias_embed.transpose(1, 0, 2)
_ind = np.arange(0, len(hotwords)).tolist()
@@ -333,10 +332,10 @@
try:
with torch.no_grad():
if int(self.device_id) == -1:
- outputs = self.ort_infer(feats, feats_len)
+ outputs = self.bb_infer(feats, feats_len)
am_scores, valid_token_lens = outputs[0], outputs[1]
else:
- outputs = self.ort_infer(feats.cuda(), feats_len.cuda())
+ outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda())
am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
except:
# logging.warning(traceback.format_exc())
@@ -374,13 +373,13 @@
return hotwords, hotwords_length
def bb_infer(
- self, feats: np.ndarray, feats_len: np.ndarray, bias_embed
+ self, feats, feats_len, bias_embed
) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
return outputs
def eb_infer(self, hotwords, hotwords_length):
- outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
+ outputs = self.ort_infer_eb([hotwords, hotwords_length])
return outputs
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index 8194283..871674e 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -285,7 +285,7 @@
model_eb_file = os.path.join(model_dir, "model_eb.onnx")
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
index 9b68b2f..ddd857d 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
@@ -54,7 +54,7 @@
encoder_model_file = os.path.join(model_dir, "model_quant.onnx")
decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx")
if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
diff --git a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 6208c09..ba55186 100644
--- a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -52,7 +52,7 @@
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
diff --git a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
index c195bb3..92928a8 100644
--- a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -52,7 +52,7 @@
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
@@ -221,7 +221,7 @@
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
--
Gitblit v1.9.1