From 6c467e6f0abfc6d20d0621fbbf67b4dbd81776cc Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: Tue, 18 Jun 2024 10:01:56 +0800
Subject: [PATCH] Merge pull request #1825 from modelscope/dev_libt
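
The ONNX punctuation runtime now reads config.yaml (previously punc.yaml),
takes the punctuation list from config["model_conf"]["punc_list"], and enables
jieba segmentation when a jieba_usr_dict file exists in the model directory
rather than keying off a "seg_jieba" config entry. A minimal usage sketch
against the updated loader; the ModelScope model id below is illustrative and
not part of this patch:

    from funasr_onnx import CT_Transformer

    # Hypothetical ModelScope model id; any punctuation model laid out with
    # config.yaml / tokens.json should load the same way.
    model_dir = "damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
    model = CT_Transformer(model_dir, quantize=False)

    text = "跨境河流是养育沿岸人民的生命之源"
    result = model(text)
    print(result[0])  # the punctuated text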

---
 runtime/python/onnxruntime/funasr_onnx/punc_bin.py |  182 +++++++++++++++++++++++++--------------------
 1 file changed, 101 insertions(+), 81 deletions(-)
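
Note on the vad_mask reformatted below: attention is kept everywhere except
that tokens before the VAD split point are masked from tokens at or after it
(the upper-right block of the matrix is zeroed). A standalone NumPy sketch of
the same computation, with illustrative names:

    import numpy as np

    def vad_attention_mask(size: int, vad_pos: int, dtype=np.bool_):
        """All-ones (size, size) mask, except rows 0..vad_pos-2 are zeroed
        in columns vad_pos.. so earlier tokens cannot see past the split."""
        mask = np.ones((size, size), dtype=dtype)
        if 0 < vad_pos < size:  # same guard as: vad_pos <= 0 or vad_pos >= size
            mask[0 : vad_pos - 1, vad_pos:] = 0
        return mask

    # size=5, vad_pos=3: rows 0-1 carry zeros in columns 3-4
    print(vad_attention_mask(5, 3).astype(int))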

diff --git a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 4e1014f..ba55186 100644
--- a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -7,67 +7,72 @@
 from typing import List, Union, Tuple
 import numpy as np
 import json
-from .utils.utils import (ONNXRuntimeError,
-                          OrtInferSession, get_logger,
-                          read_yaml)
-from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words,code_mix_split_words_jieba)
+from .utils.utils import ONNXRuntimeError, OrtInferSession, get_logger, read_yaml
+from .utils.utils import (
+    TokenIDConverter,
+    split_to_mini_sentence,
+    code_mix_split_words,
+    code_mix_split_words_jieba,
+)
+
 logging = get_logger()
 
 
-class CT_Transformer():
+class CT_Transformer:
     """
     Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
-    def __init__(self, model_dir: Union[str, Path] = None,
-                 batch_size: int = 1,
-                 device_id: Union[str, int] = "-1",
-                 quantize: bool = False,
-                 intra_op_num_threads: int = 4,
-                 cache_dir: str = None,
-                 ):
-    
+
+    def __init__(
+        self,
+        model_dir: Union[str, Path] = None,
+        batch_size: int = 1,
+        device_id: Union[str, int] = "-1",
+        quantize: bool = False,
+        intra_op_num_threads: int = 4,
+        cache_dir: str = None,
+        **kwargs
+    ):
+
         if not Path(model_dir).exists():
             try:
                 from modelscope.hub.snapshot_download import snapshot_download
             except:
-                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
-                      "\npip3 install -U modelscope\n" \
-                      "For the users in China, you could install with the command:\n" \
-                      "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
             try:
                 model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
             except:
                 raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
-                    model_dir)
-    
-        model_file = os.path.join(model_dir, 'model.onnx')
+                    model_dir
+                )
+
+        model_file = os.path.join(model_dir, "model.onnx")
         if quantize:
-            model_file = os.path.join(model_dir, 'model_quant.onnx')
+            model_file = os.path.join(model_dir, "model_quant.onnx")
         if not os.path.exists(model_file):
-            print(".onnx is not exist, begin to export onnx")
+            print(".onnx does not exist, begin to export onnx")
             try:
                 from funasr import AutoModel
             except:
-                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" \
-                      "\npip3 install -U funasr\n" \
-                      "For the users in China, you could install with the command:\n" \
-                      "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
 
-            model = AutoModel(model=cache_dir)
-            model_dir = model.export(type="onnx", quantize=quantize, device="cpu")
-            
-        config_file = os.path.join(model_dir, 'punc.yaml')
+            model = AutoModel(model=model_dir)
+            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
+
+        config_file = os.path.join(model_dir, "config.yaml")
         config = read_yaml(config_file)
-        token_list = os.path.join(model_dir, 'tokens.json')
-        with open(token_list, 'r', encoding='utf-8') as f:
+        token_list = os.path.join(model_dir, "tokens.json")
+        with open(token_list, "r", encoding="utf-8") as f:
             token_list = json.load(f)
 
         self.converter = TokenIDConverter(token_list)
-        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
+        self.ort_infer = OrtInferSession(
+            model_file, device_id, intra_op_num_threads=intra_op_num_threads
+        )
         self.batch_size = 1
-        self.punc_list = config['punc_list']
+        self.punc_list = config["model_conf"]["punc_list"]
         self.period = 0
         for i in range(len(self.punc_list)):
             if self.punc_list[i] == ",":
@@ -76,9 +81,9 @@
                 self.punc_list[i] = "锛�"
             elif self.punc_list[i] == "銆�":
                 self.period = i
-        if "seg_jieba" in config:
+        self.jieba_usr_dict_path = os.path.join(model_dir, "jieba_usr_dict")
+        if os.path.exists(self.jieba_usr_dict_path):
             self.seg_jieba = True
-            self.jieba_usr_dict_path = os.path.join(model_dir, 'jieba_usr_dict')
             self.code_mix_split_words_jieba = code_mix_split_words_jieba(self.jieba_usr_dict_path)
         else:
             self.seg_jieba = False
@@ -101,15 +106,15 @@
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence_id = mini_sentences_id[mini_sentence_i]
             mini_sentence = cache_sent + mini_sentence
-            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int32')
+            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype="int32")
             data = {
-                "text": mini_sentence_id[None,:],
-                "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
+                "text": mini_sentence_id[None, :],
+                "text_lengths": np.array([len(mini_sentence_id)], dtype="int32"),
             }
             try:
-                outputs = self.infer(data['text'], data['text_lengths'])
+                outputs = self.infer(data["text"], data["text_lengths"])
                 y = outputs[0]
-                punctuations = np.argmax(y,axis=-1)[0]
+                punctuations = np.argmax(y, axis=-1)[0]
                 assert punctuations.size == len(mini_sentence)
             except ONNXRuntimeError:
                 logging.warning("error")
@@ -119,26 +124,36 @@
                 sentenceEnd = -1
                 last_comma_index = -1
                 for i in range(len(punctuations) - 2, 1, -1):
-                    if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+                    if (
+                        self.punc_list[punctuations[i]] == "銆�"
+                        or self.punc_list[punctuations[i]] == "锛�"
+                    ):
                         sentenceEnd = i
                         break
                     if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
                         last_comma_index = i
 
-                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                if (
+                    sentenceEnd < 0
+                    and len(mini_sentence) > cache_pop_trigger_limit
+                    and last_comma_index >= 0
+                ):
                     # The sentence is too long, cut it off at a comma.
                     sentenceEnd = last_comma_index
                     punctuations[sentenceEnd] = self.period
-                cache_sent = mini_sentence[sentenceEnd + 1:]
-                cache_sent_id = mini_sentence_id[sentenceEnd + 1:].tolist()
-                mini_sentence = mini_sentence[0:sentenceEnd + 1]
-                punctuations = punctuations[0:sentenceEnd + 1]
+                cache_sent = mini_sentence[sentenceEnd + 1 :]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1 :].tolist()
+                mini_sentence = mini_sentence[0 : sentenceEnd + 1]
+                punctuations = punctuations[0 : sentenceEnd + 1]
 
             new_mini_sentence_punc += [int(x) for x in punctuations]
             words_with_punc = []
             for i in range(len(mini_sentence)):
                 if i > 0:
-                    if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
+                    if (
+                        len(mini_sentence[i][0].encode()) == 1
+                        and len(mini_sentence[i - 1][0].encode()) == 1
+                    ):
                         mini_sentence[i] = " " + mini_sentence[i]
                 words_with_punc.append(mini_sentence[i])
                 if self.punc_list[punctuations[i]] != "_":
@@ -156,8 +171,7 @@
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
         return new_mini_sentence_out, new_mini_sentence_punc_out
 
-    def infer(self, feats: np.ndarray,
-              feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    def infer(self, feats: np.ndarray, feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
         outputs = self.ort_infer([feats, feats_len])
         return outputs
 
@@ -168,14 +182,9 @@
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
-    def __init__(self, model_dir: Union[str, Path] = None,
-                 batch_size: int = 1,
-                 device_id: Union[str, int] = "-1",
-                 quantize: bool = False,
-                 intra_op_num_threads: int = 4,
-                 cache_dir: str = None
-                 ):
-        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads, cache_dir=cache_dir)
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
 
     def __call__(self, text: str, param_dict: map, split_size=20):
         cache_key = "cache"
@@ -195,7 +204,7 @@
         assert len(mini_sentences) == len(mini_sentences_id)
 
         cache_sent = []
-        cache_sent_id = np.array([], dtype='int32')
+        cache_sent_id = np.array([], dtype="int32")
         sentence_punc_list = []
         sentence_words_list = []
         cache_pop_trigger_limit = 200
@@ -204,19 +213,23 @@
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence_id = mini_sentences_id[mini_sentence_i]
             mini_sentence = cache_sent + mini_sentence
-            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
+            mini_sentence_id = np.concatenate(
+                (cache_sent_id, mini_sentence_id), axis=0, dtype="int32"
+            )
             text_length = len(mini_sentence_id)
             vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
             data = {
-                "input": mini_sentence_id[None,:],
-                "text_lengths": np.array([text_length], dtype='int32'),
+                "input": mini_sentence_id[None, :],
+                "text_lengths": np.array([text_length], dtype="int32"),
                 "vad_mask": vad_mask,
-                "sub_masks": vad_mask
+                "sub_masks": vad_mask,
             }
             try:
-                outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
+                outputs = self.infer(
+                    data["input"], data["text_lengths"], data["vad_mask"], data["sub_masks"]
+                )
                 y = outputs[0]
-                punctuations = np.argmax(y,axis=-1)[0]
+                punctuations = np.argmax(y, axis=-1)[0]
                 assert punctuations.size == len(mini_sentence)
             except ONNXRuntimeError:
                 logging.warning("error")
@@ -226,20 +239,27 @@
                 sentenceEnd = -1
                 last_comma_index = -1
                 for i in range(len(punctuations) - 2, 1, -1):
-                    if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+                    if (
+                        self.punc_list[punctuations[i]] == "銆�"
+                        or self.punc_list[punctuations[i]] == "锛�"
+                    ):
                         sentenceEnd = i
                         break
                     if last_comma_index < 0 and self.punc_list[punctuations[i]] == "，":
                         last_comma_index = i
 
-                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                if (
+                    sentenceEnd < 0
+                    and len(mini_sentence) > cache_pop_trigger_limit
+                    and last_comma_index >= 0
+                ):
                     # The sentence is too long, cut it off at a comma.
                     sentenceEnd = last_comma_index
                     punctuations[sentenceEnd] = self.period
-                cache_sent = mini_sentence[sentenceEnd + 1:]
-                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
-                mini_sentence = mini_sentence[0:sentenceEnd + 1]
-                punctuations = punctuations[0:sentenceEnd + 1]
+                cache_sent = mini_sentence[sentenceEnd + 1 :]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1 :]
+                mini_sentence = mini_sentence[0 : sentenceEnd + 1]
+                punctuations = punctuations[0 : sentenceEnd + 1]
 
             punctuations_np = [int(x) for x in punctuations]
             new_mini_sentence_punc += punctuations_np
@@ -251,7 +271,10 @@
         sentence_punc_list_out = []
         for i in range(0, len(sentence_words_list)):
             if i > 0:
-                if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
+                if (
+                    len(sentence_words_list[i][0].encode()) == 1
+                    and len(sentence_words_list[i - 1][-1].encode()) == 1
+                ):
                     sentence_words_list[i] = " " + sentence_words_list[i]
             if skip_num < len(cache):
                 skip_num += 1
@@ -268,7 +291,7 @@
             if sentence_punc_list[i] == "銆�" or sentence_punc_list[i] == "锛�":
                 sentenceEnd = i
                 break
-        cache_out = sentence_words_list[sentenceEnd + 1:]
+        cache_out = sentence_words_list[sentenceEnd + 1 :]
         if sentence_out[-1] in self.punc_list:
             sentence_out = sentence_out[:-1]
             sentence_punc_list_out[-1] = "_"
@@ -286,15 +309,12 @@
         ret = np.ones((size, size), dtype=dtype)
         if vad_pos <= 0 or vad_pos >= size:
             return ret
-        sub_corner = np.zeros(
-            (vad_pos - 1, size - vad_pos), dtype=dtype)
-        ret[0:vad_pos - 1, vad_pos:] = sub_corner
+        sub_corner = np.zeros((vad_pos - 1, size - vad_pos), dtype=dtype)
+        ret[0 : vad_pos - 1, vad_pos:] = sub_corner
         return ret
 
-    def infer(self, feats: np.ndarray,
-              feats_len: np.ndarray,
-              vad_mask: np.ndarray,
-              sub_masks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    def infer(
+        self, feats: np.ndarray, feats_len: np.ndarray, vad_mask: np.ndarray, sub_masks: np.ndarray
+    ) -> Tuple[np.ndarray, np.ndarray]:
         outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
         return outputs
-

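The realtime variant threads its sentence cache through param_dict["cache"],
writing the carry-over back after each call, so streaming callers reuse one
dict across chunks. A minimal streaming sketch under the same assumptions as
the example in the commit message (illustrative model id and input chunks):

    from funasr_onnx import CT_Transformer_VadRealtime

    model_dir = "damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
    model = CT_Transformer_VadRealtime(model_dir)

    param_dict = {"cache": []}  # updated in place on every call
    for chunk in ["跨境河流是养育沿岸", "人民的生命之源"]:
        result = model(chunk, param_dict)
        print(result[0])
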
--
Gitblit v1.9.1