From 7bf2eec71e0c65f15628a105d11406a8a14ae178 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期三, 15 三月 2023 20:17:01 +0800
Subject: [PATCH] update paraformer_onnx

---
 funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py |   13 +++++++------
 1 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py b/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py
index ca90558..e6b33d4 100644
--- a/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py
+++ b/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py
@@ -12,7 +12,7 @@
                           read_yaml)
 from .utils.postprocess_utils import sentence_postprocess
 from .utils.frontend import WavFrontend
-from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
+from .utils.timestamp_utils import time_stamp_lfr6_onnx
 logging = get_logger()
 
 import torch
@@ -27,7 +27,7 @@
         if not Path(model_dir).exists():
             raise FileNotFoundError(f'{model_dir} does not exist.')
 
-        model_file = os.path.join(model_dir, 'model.onnx')
+        model_file = os.path.join(model_dir, 'model.torchscripts')
         config_file = os.path.join(model_dir, 'config.yaml')
         cmvn_file = os.path.join(model_dir, 'am.mvn')
         config = read_yaml(config_file)
@@ -52,9 +52,8 @@
             feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
 
             try:
-                outputs = self.infer(feats, feats_len)
-                outs = outputs[0], outputs[1]
-                am_scores, valid_token_lens = outs[0], outs[1]
+                outputs = self.ort_infer(feats, feats_len)
+                am_scores, valid_token_lens = outputs[0], outputs[1]
                 if len(outputs) == 4:
                     # for BiCifParaformer Inference
                     us_alphas, us_cif_peak = outputs[2], outputs[3]
@@ -65,7 +64,7 @@
                 logging.warning("input wav is silence or noise")
                 preds = ['']
             else:
-                am_scores, valid_token_lens = am_scores.cpu().numpy(), valid_token_lens.cpu().numpy()
+                am_scores, valid_token_lens = am_scores.detach().cpu().numpy(), valid_token_lens.detach().cpu().numpy()
                 preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
                 res['preds'] = preds
                 if us_cif_peak is not None:
@@ -105,6 +104,8 @@
 
         feats = self.pad_feats(feats, np.max(feats_len))
         feats_len = np.array(feats_len).astype(np.int32)
+        feats = torch.from_numpy(feats).type(torch.float32)
+        feats_len = torch.from_numpy(feats_len).type(torch.int32)
         return feats, feats_len
 
     @staticmethod

--
Gitblit v1.9.1