From 9fcb3cc06b4e324f0913d2f61b89becc2baeef1b Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期一, 11 九月 2023 17:40:03 +0800
Subject: [PATCH] Merge pull request #932 from alibaba-damo-academy/dev_lhn

---
 funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py |   15 ++++++++++++---
 1 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 035dd00..cc5daa8 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -10,7 +10,7 @@
 from .utils.utils import (ONNXRuntimeError,
                           OrtInferSession, get_logger,
                           read_yaml)
-from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
+from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words,code_mix_split_words_jieba)
 logging = get_logger()
 
 
@@ -65,9 +65,18 @@
                 self.punc_list[i] = "锛�"
             elif self.punc_list[i] == "銆�":
                 self.period = i
+        if "seg_jieba" in config:
+            self.seg_jieba = True
+            self.jieba_usr_dict_path = os.path.join(model_dir, 'jieba_usr_dict')
+            self.code_mix_split_words_jieba = code_mix_split_words_jieba(self.jieba_usr_dict_path)
+        else:
+            self.seg_jieba = False
 
     def __call__(self, text: Union[list, str], split_size=20):
-        split_text = code_mix_split_words(text)
+        if self.seg_jieba:
+            split_text = self.code_mix_split_words_jieba(text)
+        else:
+            split_text = code_mix_split_words(text)
         split_text_id = self.converter.tokens2ids(split_text)
         mini_sentences = split_to_mini_sentence(split_text, split_size)
         mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
@@ -190,7 +199,7 @@
             data = {
                 "input": mini_sentence_id[None,:],
                 "text_lengths": np.array([text_length], dtype='int32'),
-                "vad_mask": vad_mask
+                "vad_mask": vad_mask,
                 "sub_masks": vad_mask
             }
             try:

--
Gitblit v1.9.1