From 79007d36f1636eb51e0cead6bc0e6b18ff1f8253 Mon Sep 17 00:00:00 2001
From: 九耳 <mengzhe.cmz@alibaba-inc.com>
Date: 星期五, 07 四月 2023 14:43:32 +0800
Subject: [PATCH] vad_realtime onnx runnable

---
 funasr/runtime/python/onnxruntime/demo_punc_online.py     |   15 +++++
 funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py |    2 
 funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py |  130 +++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 146 insertions(+), 1 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/demo_punc_online.py b/funasr/runtime/python/onnxruntime/demo_punc_online.py
new file mode 100644
index 0000000..853db50
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_punc_online.py
@@ -0,0 +1,15 @@
+from funasr_onnx import CT_Transformer_VadRealtime
+
+model_dir = "/disk1/mengzhe.cmz/workspace/FunASR/funasr/export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+model = CT_Transformer_VadRealtime(model_dir)
+
+text_in  = "璺ㄥ娌虫祦鏄吇鑲叉部宀竱浜烘皯鐨勭敓鍛戒箣婧愰暱鏈熶互鏉ヤ负甯姪涓嬫父鍦板尯闃茬伨鍑忕伨涓柟鎶�鏈汉鍛榺鍦ㄤ笂娓稿湴鍖烘瀬涓烘伓鍔g殑鑷劧鏉′欢涓嬪厠鏈嶅法澶у洶闅剧敋鑷冲啋鐫�鐢熷懡鍗遍櫓|鍚戝嵃鏂规彁渚涙睕鏈熸按鏂囪祫鏂欏鐞嗙揣鎬ヤ簨浠朵腑鏂归噸瑙嗗嵃鏂瑰湪璺ㄥ娌虫祦>闂涓婄殑鍏冲垏|鎰挎剰杩涗竴姝ュ畬鍠勫弻鏂硅仈鍚堝伐浣滄満鍒秥鍑℃槸|涓柟鑳藉仛鐨勬垜浠瑋閮戒細鍘诲仛鑰屼笖浼氬仛寰楁洿濂芥垜璇峰嵃搴︽湅鍙嬩滑鏀惧績涓浗鍦ㄤ笂娓哥殑|浠讳綍寮�鍙戝埄鐢ㄩ兘浼氱粡杩囩瀛瑙勫垝鍜岃璇佸吋椤句笂涓嬫父鐨勫埄鐩�"
+
+vads = text_in.split("|")
+rec_result_all="outputs:"
+param_dict = {"cache": []}
+for vad in vads:
+    result = model(vad, param_dict=param_dict)
+    rec_result_all += result[0]
+
+print(rec_result_all)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
index 825f2ca..86f0e8e 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -2,4 +2,4 @@
 from .paraformer_bin import Paraformer
 from .vad_bin import Fsmn_vad
 from .punc_bin import CT_Transformer
-#from .punc_bin import VadRealtimeCT_Transformer
+from .punc_bin import CT_Transformer_VadRealtime
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 949172e..0dc728a 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -117,3 +117,133 @@
         outputs = self.ort_infer([feats, feats_len])
         return outputs
 
+
+class CT_Transformer_VadRealtime(CT_Transformer):
+    def __init__(self, model_dir: Union[str, Path] = None,
+                 batch_size: int = 1,
+                 device_id: Union[str, int] = "-1",
+                 quantize: bool = False,
+                 intra_op_num_threads: int = 4
+                 ):
+        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)
+
+    def __call__(self, text: str, param_dict: map, split_size=20):
+        cache_key = "cache"
+        assert cache_key in param_dict
+        cache = param_dict[cache_key]
+        if cache is not None and len(cache) > 0:
+            precache = "".join(cache)
+        else:
+            precache = ""
+            cache = []
+        full_text = precache + text
+        split_text = code_mix_split_words(full_text)
+        split_text_id = self.converter.tokens2ids(split_text)
+        mini_sentences = split_to_mini_sentence(split_text, split_size)
+        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
+        new_mini_sentence_punc = []
+        assert len(mini_sentences) == len(mini_sentences_id)
+
+        cache_sent = []
+        cache_sent_id = np.array([], dtype='int32')
+        sentence_punc_list = []
+        sentence_words_list = []
+        cache_pop_trigger_limit = 200
+        skip_num = 0
+        for mini_sentence_i in range(len(mini_sentences)):
+            mini_sentence = mini_sentences[mini_sentence_i]
+            mini_sentence_id = mini_sentences_id[mini_sentence_i]
+            mini_sentence = cache_sent + mini_sentence
+            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            text_length = len(mini_sentence_id)
+            data = {
+                "input": mini_sentence_id[None,:],
+                "text_lengths": np.array([text_length], dtype='int32'),
+                "vad_mask": self.vad_mask(text_length, len(cache) - 1)[None, None, :, :].astype(np.float32),
+                "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
+            }
+            try:
+                outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
+                y = outputs[0]
+                punctuations = np.argmax(y,axis=-1)[0]
+                assert punctuations.size == len(mini_sentence)
+            except ONNXRuntimeError:
+                logging.warning("error")
+
+            # Search for the last Period/QuestionMark as cache
+            if mini_sentence_i < len(mini_sentences) - 1:
+                sentenceEnd = -1
+                last_comma_index = -1
+                for i in range(len(punctuations) - 2, 1, -1):
+                    if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+                        sentenceEnd = i
+                        break
+                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
+                        last_comma_index = i
+
+                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence it too long, cut off at a comma.
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = self.period
+                cache_sent = mini_sentence[sentenceEnd + 1:]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+                mini_sentence = mini_sentence[0:sentenceEnd + 1]
+                punctuations = punctuations[0:sentenceEnd + 1]
+
+            punctuations_np = [int(x) for x in punctuations]
+            new_mini_sentence_punc += punctuations_np
+            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
+            sentence_words_list += mini_sentence
+
+        assert len(sentence_punc_list) == len(sentence_words_list)
+        words_with_punc = []
+        sentence_punc_list_out = []
+        for i in range(0, len(sentence_words_list)):
+            if i > 0:
+                if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
+                    sentence_words_list[i] = " " + sentence_words_list[i]
+            if skip_num < len(cache):
+                skip_num += 1
+            else:
+                words_with_punc.append(sentence_words_list[i])
+            if skip_num >= len(cache):
+                sentence_punc_list_out.append(sentence_punc_list[i])
+                if sentence_punc_list[i] != "_":
+                    words_with_punc.append(sentence_punc_list[i])
+        sentence_out = "".join(words_with_punc)
+
+        sentenceEnd = -1
+        for i in range(len(sentence_punc_list) - 2, 1, -1):
+            if sentence_punc_list[i] == "銆�" or sentence_punc_list[i] == "锛�":
+                sentenceEnd = i
+                break
+        cache_out = sentence_words_list[sentenceEnd + 1:]
+        if sentence_out[-1] in self.punc_list:
+            sentence_out = sentence_out[:-1]
+            sentence_punc_list_out[-1] = "_"
+        param_dict[cache_key] = cache_out
+        return sentence_out, sentence_punc_list_out, cache_out
+
+    def vad_mask(self, size, vad_pos, dtype=np.bool):
+        """Create mask for decoder self-attention.
+
+        :param int size: size of mask
+        :param int vad_pos: index of vad index
+        :param torch.dtype dtype: result dtype
+        :rtype: torch.Tensor (B, Lmax, Lmax)
+        """
+        ret = np.ones((size, size), dtype=dtype)
+        if vad_pos <= 0 or vad_pos >= size:
+            return ret
+        sub_corner = np.zeros(
+            (vad_pos - 1, size - vad_pos), dtype=dtype)
+        ret[0:vad_pos - 1, vad_pos:] = sub_corner
+        return ret
+
+    def infer(self, feats: np.ndarray,
+              feats_len: np.ndarray,
+              vad_mask: np.ndarray,
+              sub_masks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
+        return outputs
+

--
Gitblit v1.9.1