From 6c467e6f0abfc6d20d0621fbbf67b4dbd81776cc Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: 星期二, 18 六月 2024 10:01:56 +0800
Subject: [PATCH] Merge pull request #1825 from modelscope/dev_libt
---
runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 180 +++++++++++++++++++++++++++++++++--------------------------
1 files changed, 100 insertions(+), 80 deletions(-)
diff --git a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index b1aca6e..ba55186 100644
--- a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -7,67 +7,72 @@
from typing import List, Union, Tuple
import numpy as np
import json
-from .utils.utils import (ONNXRuntimeError,
- OrtInferSession, get_logger,
- read_yaml)
-from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words,code_mix_split_words_jieba)
+from .utils.utils import ONNXRuntimeError, OrtInferSession, get_logger, read_yaml
+from .utils.utils import (
+ TokenIDConverter,
+ split_to_mini_sentence,
+ code_mix_split_words,
+ code_mix_split_words_jieba,
+)
+
logging = get_logger()
-class CT_Transformer():
+class CT_Transformer:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
- def __init__(self, model_dir: Union[str, Path] = None,
- batch_size: int = 1,
- device_id: Union[str, int] = "-1",
- quantize: bool = False,
- intra_op_num_threads: int = 4,
- cache_dir: str = None,
- ):
-
+
+ def __init__(
+ self,
+ model_dir: Union[str, Path] = None,
+ batch_size: int = 1,
+ device_id: Union[str, int] = "-1",
+ quantize: bool = False,
+ intra_op_num_threads: int = 4,
+ cache_dir: str = None,
+ **kwargs
+ ):
+
if not Path(model_dir).exists():
try:
from modelscope.hub.snapshot_download import snapshot_download
except:
- raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
- "\npip3 install -U modelscope\n" \
- "For the users in China, you could install with the command:\n" \
- "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+ raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
try:
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
except:
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
- model_dir)
-
- model_file = os.path.join(model_dir, 'model.onnx')
+ model_dir
+ )
+
+ model_file = os.path.join(model_dir, "model.onnx")
if quantize:
- model_file = os.path.join(model_dir, 'model_quant.onnx')
+ model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file):
- print(".onnx is not exist, begin to export onnx")
+ print(".onnx does not exist, begin to export onnx")
try:
from funasr import AutoModel
except:
- raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
- "\npip3 install -U funasr\n" \
- "For the users in China, you could install with the command:\n" \
- "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+ raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
model = AutoModel(model=model_dir)
- model_dir = model.export(type="onnx", quantize=quantize)
-
- config_file = os.path.join(model_dir, 'punc.yaml')
+ model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
+
+ config_file = os.path.join(model_dir, "config.yaml")
config = read_yaml(config_file)
- token_list = os.path.join(model_dir, 'tokens.json')
- with open(token_list, 'r', encoding='utf-8') as f:
+ token_list = os.path.join(model_dir, "tokens.json")
+ with open(token_list, "r", encoding="utf-8") as f:
token_list = json.load(f)
self.converter = TokenIDConverter(token_list)
- self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
+ self.ort_infer = OrtInferSession(
+ model_file, device_id, intra_op_num_threads=intra_op_num_threads
+ )
self.batch_size = 1
- self.punc_list = config['punc_list']
+ self.punc_list = config["model_conf"]["punc_list"]
self.period = 0
for i in range(len(self.punc_list)):
if self.punc_list[i] == ",":
@@ -76,9 +81,9 @@
self.punc_list[i] = "锛�"
elif self.punc_list[i] == "銆�":
self.period = i
- if "seg_jieba" in config:
+ self.jieba_usr_dict_path = os.path.join(model_dir, "jieba_usr_dict")
+ if os.path.exists(self.jieba_usr_dict_path):
self.seg_jieba = True
- self.jieba_usr_dict_path = os.path.join(model_dir, 'jieba_usr_dict')
self.code_mix_split_words_jieba = code_mix_split_words_jieba(self.jieba_usr_dict_path)
else:
self.seg_jieba = False
@@ -101,15 +106,15 @@
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
- mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int32')
+ mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype="int32")
data = {
- "text": mini_sentence_id[None,:],
- "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
+ "text": mini_sentence_id[None, :],
+ "text_lengths": np.array([len(mini_sentence_id)], dtype="int32"),
}
try:
- outputs = self.infer(data['text'], data['text_lengths'])
+ outputs = self.infer(data["text"], data["text_lengths"])
y = outputs[0]
- punctuations = np.argmax(y,axis=-1)[0]
+ punctuations = np.argmax(y, axis=-1)[0]
assert punctuations.size == len(mini_sentence)
except ONNXRuntimeError:
logging.warning("error")
@@ -119,26 +124,36 @@
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
- if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+ if (
+ self.punc_list[punctuations[i]] == "銆�"
+ or self.punc_list[punctuations[i]] == "锛�"
+ ):
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
last_comma_index = i
- if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+ if (
+ sentenceEnd < 0
+ and len(mini_sentence) > cache_pop_trigger_limit
+ and last_comma_index >= 0
+ ):
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
- cache_sent = mini_sentence[sentenceEnd + 1:]
- cache_sent_id = mini_sentence_id[sentenceEnd + 1:].tolist()
- mini_sentence = mini_sentence[0:sentenceEnd + 1]
- punctuations = punctuations[0:sentenceEnd + 1]
+ cache_sent = mini_sentence[sentenceEnd + 1 :]
+ cache_sent_id = mini_sentence_id[sentenceEnd + 1 :].tolist()
+ mini_sentence = mini_sentence[0 : sentenceEnd + 1]
+ punctuations = punctuations[0 : sentenceEnd + 1]
new_mini_sentence_punc += [int(x) for x in punctuations]
words_with_punc = []
for i in range(len(mini_sentence)):
if i > 0:
- if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+ if (
+ len(mini_sentence[i][0].encode()) == 1
+ and len(mini_sentence[i - 1][0].encode()) == 1
+ ):
mini_sentence[i] = " " + mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if self.punc_list[punctuations[i]] != "_":
@@ -156,8 +171,7 @@
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
return new_mini_sentence_out, new_mini_sentence_punc_out
- def infer(self, feats: np.ndarray,
- feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ def infer(self, feats: np.ndarray, feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len])
return outputs
@@ -168,14 +182,9 @@
CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
https://arxiv.org/pdf/2003.01309.pdf
"""
- def __init__(self, model_dir: Union[str, Path] = None,
- batch_size: int = 1,
- device_id: Union[str, int] = "-1",
- quantize: bool = False,
- intra_op_num_threads: int = 4,
- cache_dir: str = None
- ):
- super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads, cache_dir=cache_dir)
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
def __call__(self, text: str, param_dict: map, split_size=20):
cache_key = "cache"
@@ -195,7 +204,7 @@
assert len(mini_sentences) == len(mini_sentences_id)
cache_sent = []
- cache_sent_id = np.array([], dtype='int32')
+ cache_sent_id = np.array([], dtype="int32")
sentence_punc_list = []
sentence_words_list = []
cache_pop_trigger_limit = 200
@@ -204,19 +213,23 @@
mini_sentence = mini_sentences[mini_sentence_i]
mini_sentence_id = mini_sentences_id[mini_sentence_i]
mini_sentence = cache_sent + mini_sentence
- mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
+ mini_sentence_id = np.concatenate(
+ (cache_sent_id, mini_sentence_id), axis=0, dtype="int32"
+ )
text_length = len(mini_sentence_id)
vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
data = {
- "input": mini_sentence_id[None,:],
- "text_lengths": np.array([text_length], dtype='int32'),
+ "input": mini_sentence_id[None, :],
+ "text_lengths": np.array([text_length], dtype="int32"),
"vad_mask": vad_mask,
- "sub_masks": vad_mask
+ "sub_masks": vad_mask,
}
try:
- outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
+ outputs = self.infer(
+ data["input"], data["text_lengths"], data["vad_mask"], data["sub_masks"]
+ )
y = outputs[0]
- punctuations = np.argmax(y,axis=-1)[0]
+ punctuations = np.argmax(y, axis=-1)[0]
assert punctuations.size == len(mini_sentence)
except ONNXRuntimeError:
logging.warning("error")
@@ -226,20 +239,27 @@
sentenceEnd = -1
last_comma_index = -1
for i in range(len(punctuations) - 2, 1, -1):
- if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+ if (
+ self.punc_list[punctuations[i]] == "銆�"
+ or self.punc_list[punctuations[i]] == "锛�"
+ ):
sentenceEnd = i
break
if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
last_comma_index = i
- if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+ if (
+ sentenceEnd < 0
+ and len(mini_sentence) > cache_pop_trigger_limit
+ and last_comma_index >= 0
+ ):
# The sentence it too long, cut off at a comma.
sentenceEnd = last_comma_index
punctuations[sentenceEnd] = self.period
- cache_sent = mini_sentence[sentenceEnd + 1:]
- cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
- mini_sentence = mini_sentence[0:sentenceEnd + 1]
- punctuations = punctuations[0:sentenceEnd + 1]
+ cache_sent = mini_sentence[sentenceEnd + 1 :]
+ cache_sent_id = mini_sentence_id[sentenceEnd + 1 :]
+ mini_sentence = mini_sentence[0 : sentenceEnd + 1]
+ punctuations = punctuations[0 : sentenceEnd + 1]
punctuations_np = [int(x) for x in punctuations]
new_mini_sentence_punc += punctuations_np
@@ -251,7 +271,10 @@
sentence_punc_list_out = []
for i in range(0, len(sentence_words_list)):
if i > 0:
- if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
+ if (
+ len(sentence_words_list[i][0].encode()) == 1
+ and len(sentence_words_list[i - 1][-1].encode()) == 1
+ ):
sentence_words_list[i] = " " + sentence_words_list[i]
if skip_num < len(cache):
skip_num += 1
@@ -268,7 +291,7 @@
if sentence_punc_list[i] == "銆�" or sentence_punc_list[i] == "锛�":
sentenceEnd = i
break
- cache_out = sentence_words_list[sentenceEnd + 1:]
+ cache_out = sentence_words_list[sentenceEnd + 1 :]
if sentence_out[-1] in self.punc_list:
sentence_out = sentence_out[:-1]
sentence_punc_list_out[-1] = "_"
@@ -286,15 +309,12 @@
ret = np.ones((size, size), dtype=dtype)
if vad_pos <= 0 or vad_pos >= size:
return ret
- sub_corner = np.zeros(
- (vad_pos - 1, size - vad_pos), dtype=dtype)
- ret[0:vad_pos - 1, vad_pos:] = sub_corner
+ sub_corner = np.zeros((vad_pos - 1, size - vad_pos), dtype=dtype)
+ ret[0 : vad_pos - 1, vad_pos:] = sub_corner
return ret
- def infer(self, feats: np.ndarray,
- feats_len: np.ndarray,
- vad_mask: np.ndarray,
- sub_masks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ def infer(
+ self, feats: np.ndarray, feats_len: np.ndarray, vad_mask: np.ndarray, sub_masks: np.ndarray
+ ) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
return outputs
-
--
Gitblit v1.9.1