From 3e44172c8b927ffc69b585d4fd80b458cb18ba97 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期三, 25 九月 2024 23:43:30 +0800
Subject: [PATCH] update wbsocket for sensevoice & onnx models

---
 runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py |   13 ++++++++-----
 1 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index 2cd43a8..4f35fcc 100644
--- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -62,7 +62,7 @@
         if quantize:
             model_file = os.path.join(model_dir, "model_quant.onnx")
         if not os.path.exists(model_file):
-            print(".onnx is not exist, begin to export onnx")
+            print(".onnx does not exist, begin to export onnx")
             try:
                 from funasr import AutoModel
             except:
@@ -285,7 +285,7 @@
             model_eb_file = os.path.join(model_dir, "model_eb.onnx")
 
         if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
-            print(".onnx is not exist, begin to export onnx")
+            print(".onnx does not exist, begin to export onnx")
             try:
                 from funasr import AutoModel
             except:
@@ -322,6 +322,10 @@
             self.pred_bias = config["model_conf"]["predictor_bias"]
         else:
             self.pred_bias = 0
+        if "lang" in config:
+            self.language = config["lang"]
+        else:
+            self.language = None
 
     def __call__(
         self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
@@ -331,7 +335,6 @@
     # ) -> List:
         # make hotword list
         hotwords, hotwords_length = self.proc_hotword(hotwords)
-        # import pdb; pdb.set_trace()
         [bias_embed] = self.eb_infer(hotwords, hotwords_length)
         # index from bias_embed
         bias_embed = bias_embed.transpose(1, 0, 2)
@@ -411,10 +414,10 @@
             return np.array(hotwords)
 
         hotword_int = [word_map(i) for i in hotwords]
-        # import pdb; pdb.set_trace()
+
         hotword_int.append(np.array([1]))
         hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
-        # import pdb; pdb.set_trace()
+        
         return hotwords, hotwords_length
 
     def bb_infer(

--
Gitblit v1.9.1