From 57ccdf04e0dc17af86ae9b2f2d6155440989f450 Mon Sep 17 00:00:00 2001
From: Xian Shi <40013335+R1ckShi@users.noreply.github.com>
Date: 星期二, 12 九月 2023 19:56:29 +0800
Subject: [PATCH] Merge pull request #939 from alibaba-damo-academy/dev_sxfix

---
 funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py |    9 ++++++++-
 funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py |    2 +-
 funasr/bin/asr_inference_launch.py                              |    8 ++++----
 funasr/export/export_model.py                                   |    2 +-
 4 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 2def98b..cdaaefc 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -415,7 +415,7 @@
                         ibest_writer["rtf"][key] = rtf_cur
 
                     if text is not None:
-                        if use_timestamp and timestamp is not None:
+                        if use_timestamp and timestamp is not None and len(timestamp):
                             postprocessed_result = postprocess_utils.sentence_postprocess(token, timestamp)
                         else:
                             postprocessed_result = postprocess_utils.sentence_postprocess(token)
@@ -427,7 +427,7 @@
                         else:
                             text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
                         item = {'key': key, 'value': text_postprocessed}
-                        if timestamp_postprocessed != "":
+                        if timestamp_postprocessed != "" or len(timestamp) == 0:
                             item['timestamp'] = timestamp_postprocessed
                         asr_result_list.append(item)
                         finish_count += 1
@@ -692,7 +692,7 @@
             text, token, token_int = result[0], result[1], result[2]
             time_stamp = result[4] if len(result[4]) > 0 else None
 
-            if use_timestamp and time_stamp is not None:
+            if use_timestamp and time_stamp is not None and len(time_stamp):
                 postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
             else:
                 postprocessed_result = postprocess_utils.sentence_postprocess(token)
@@ -717,7 +717,7 @@
             item = {'key': key, 'value': text_postprocessed_punc}
             if text_postprocessed != "":
                 item['text_postprocessed'] = text_postprocessed
-            if time_stamp_postprocessed != "":
+            if time_stamp_postprocessed != "" or len(time_stamp) == 0:
                 item['time_stamp'] = time_stamp_postprocessed
 
             item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index e0a9313..6ab9408 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -254,7 +254,7 @@
             if not os.path.exists(quant_model_path):
                 onnx_model = onnx.load(model_path)
                 nodes = [n.name for n in onnx_model.graph.node]
-                nodes_to_exclude = [m for m in nodes if 'output' in m]
+                nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m  or 'bias_decoder' in m]
                 quantize_dynamic(
                     model_input=model_path,
                     model_output=quant_model_path,
diff --git a/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py b/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
index 984c0d6..9da3817 100644
--- a/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
+++ b/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
@@ -5,7 +5,7 @@
 model = ContextualParaformer(model_dir, batch_size=1)
 
 wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
-hotwords = '闅忔満鐑瘝 鍚勭鐑瘝 榄旀惌 闃块噷宸村反'
+hotwords = '闅忔満鐑瘝 鍚勭鐑瘝 榄旀惌 闃块噷宸村反 浠�'
 
 result = model(wav_path, hotwords)
 print(result)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index c994036..4caa5c1 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -314,7 +314,14 @@
         hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
         # hotwords.append('<s>')
         def word_map(word):
-            return torch.tensor([self.vocab[i] for i in word])
+            hotwords = []
+            for c in word:
+                if c not in self.vocab.keys():
+                    hotwords.append(8403)
+                    logging.warning("oov character {} found in hotword {}, replaced by <unk>".format(c, word))
+                else:
+                    hotwords.append(self.vocab[c])
+            return torch.tensor(hotwords)
         hotword_int = [word_map(i) for i in hotwords]
         # import pdb; pdb.set_trace()
         hotword_int.append(torch.tensor([1]))

--
Gitblit v1.9.1