From b9bcf1f093c3053fdc4e2cf4a1d38e27bbf429fb Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 19 十月 2023 14:03:48 +0800
Subject: [PATCH] docs
---
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 34 ++++++++++++++++++++++++++++------
1 files changed, 28 insertions(+), 6 deletions(-)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 8890714..6e289f6 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -10,7 +10,7 @@
from .utils.utils import (ONNXRuntimeError,
OrtInferSession, get_logger,
read_yaml)
-from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words)
+from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words,code_mix_split_words_jieba)
logging = get_logger()
@@ -29,7 +29,13 @@
):
if not Path(model_dir).exists():
- from modelscope.hub.snapshot_download import snapshot_download
+ try:
+ from modelscope.hub.snapshot_download import snapshot_download
+ except:
+ raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
+ "\npip3 install -U modelscope\n" \
+ "For the users in China, you could install with the command:\n" \
+ "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
@@ -41,7 +47,13 @@
model_file = os.path.join(model_dir, 'model_quant.onnx')
if not os.path.exists(model_file):
print(".onnx is not exist, begin to export onnx")
- from funasr.export.export_model import ModelExport
+ try:
+ from funasr.export.export_model import ModelExport
+ except:
+ raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
+ "\npip3 install -U funasr\n" \
+ "For the users in China, you could install with the command:\n" \
+ "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
export_model = ModelExport(
cache_dir=cache_dir,
onnx=True,
@@ -65,9 +77,18 @@
self.punc_list[i] = "锛�"
elif self.punc_list[i] == "銆�":
self.period = i
+ if "seg_jieba" in config:
+ self.seg_jieba = True
+ self.jieba_usr_dict_path = os.path.join(model_dir, 'jieba_usr_dict')
+ self.code_mix_split_words_jieba = code_mix_split_words_jieba(self.jieba_usr_dict_path)
+ else:
+ self.seg_jieba = False
def __call__(self, text: Union[list, str], split_size=20):
- split_text = code_mix_split_words(text)
+ if self.seg_jieba:
+ split_text = self.code_mix_split_words_jieba(text)
+ else:
+ split_text = code_mix_split_words(text)
split_text_id = self.converter.tokens2ids(split_text)
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
@@ -186,11 +207,12 @@
mini_sentence = cache_sent + mini_sentence
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
text_length = len(mini_sentence_id)
+ vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
data = {
"input": mini_sentence_id[None,:],
"text_lengths": np.array([text_length], dtype='int32'),
- "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
- "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
+ "vad_mask": vad_mask,
+ "sub_masks": vad_mask
}
try:
outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
--
Gitblit v1.9.1