From fd22b6e7f36e963ef29dbd3eafb0e0d6f2e12fa7 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期三, 09 八月 2023 14:27:20 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main
---
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 18 ++++++++++++++----
1 files changed, 14 insertions(+), 4 deletions(-)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 8890714..cc5daa8 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -10,7 +10,7 @@
from .utils.utils import (ONNXRuntimeError,
OrtInferSession, get_logger,
read_yaml)
-from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
+from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words,code_mix_split_words_jieba)
logging = get_logger()
@@ -65,9 +65,18 @@
self.punc_list[i] = "锛�"
elif self.punc_list[i] == "銆�":
self.period = i
+ if "seg_jieba" in config:
+ self.seg_jieba = True
+ self.jieba_usr_dict_path = os.path.join(model_dir, 'jieba_usr_dict')
+ self.code_mix_split_words_jieba = code_mix_split_words_jieba(self.jieba_usr_dict_path)
+ else:
+ self.seg_jieba = False
def __call__(self, text: Union[list, str], split_size=20):
- split_text = code_mix_split_words(text)
+ if self.seg_jieba:
+ split_text = self.code_mix_split_words_jieba(text)
+ else:
+ split_text = code_mix_split_words(text)
split_text_id = self.converter.tokens2ids(split_text)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
@@ -186,11 +195,12 @@
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
text_length = len(mini_sentence_id)
+ vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
data = {
"input": mini_sentence_id[None,:],
"text_lengths": np.array([text_length], dtype='int32'),
- "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
- "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
+ "vad_mask": vad_mask,
+ "sub_masks": vad_mask
}
try:
outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
--
Gitblit v1.9.1