From 79007d36f1636eb51e0cead6bc0e6b18ff1f8253 Mon Sep 17 00:00:00 2001
From: 九耳 <mengzhe.cmz@alibaba-inc.com>
Date: 星期五, 07 四月 2023 14:43:32 +0800
Subject: [PATCH] vad_realtime onnx runnable

---
 funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py |  130 +++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 130 insertions(+), 0 deletions(-)

diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 949172e..0dc728a 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -117,3 +117,133 @@
         outputs = self.ort_infer([feats, feats_len])
         return outputs
 
+
+class CT_Transformer_VadRealtime(CT_Transformer):
+    def __init__(self, model_dir: Union[str, Path] = None,
+                 batch_size: int = 1,
+                 device_id: Union[str, int] = "-1",
+                 quantize: bool = False,
+                 intra_op_num_threads: int = 4
+                 ):
+        super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)
+
+    def __call__(self, text: str, param_dict: map, split_size=20):
+        cache_key = "cache"
+        assert cache_key in param_dict
+        cache = param_dict[cache_key]
+        if cache is not None and len(cache) > 0:
+            precache = "".join(cache)
+        else:
+            precache = ""
+            cache = []
+        full_text = precache + text
+        split_text = code_mix_split_words(full_text)
+        split_text_id = self.converter.tokens2ids(split_text)
+        mini_sentences = split_to_mini_sentence(split_text, split_size)
+        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
+        new_mini_sentence_punc = []
+        assert len(mini_sentences) == len(mini_sentences_id)
+
+        cache_sent = []
+        cache_sent_id = np.array([], dtype='int32')
+        sentence_punc_list = []
+        sentence_words_list = []
+        cache_pop_trigger_limit = 200
+        skip_num = 0
+        for mini_sentence_i in range(len(mini_sentences)):
+            mini_sentence = mini_sentences[mini_sentence_i]
+            mini_sentence_id = mini_sentences_id[mini_sentence_i]
+            mini_sentence = cache_sent + mini_sentence
+            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            text_length = len(mini_sentence_id)
+            data = {
+                "input": mini_sentence_id[None,:],
+                "text_lengths": np.array([text_length], dtype='int32'),
+                "vad_mask": self.vad_mask(text_length, len(cache) - 1)[None, None, :, :].astype(np.float32),
+                "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
+            }
+            try:
+                outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
+                y = outputs[0]
+                punctuations = np.argmax(y,axis=-1)[0]
+                assert punctuations.size == len(mini_sentence)
+            except ONNXRuntimeError:
+                logging.warning("error")
+
+            # Search for the last Period/QuestionMark as cache
+            if mini_sentence_i < len(mini_sentences) - 1:
+                sentenceEnd = -1
+                last_comma_index = -1
+                for i in range(len(punctuations) - 2, 1, -1):
+                    if self.punc_list[punctuations[i]] == "銆�" or self.punc_list[punctuations[i]] == "锛�":
+                        sentenceEnd = i
+                        break
+                    if last_comma_index < 0 and self.punc_list[punctuations[i]] == "锛�":
+                        last_comma_index = i
+
+                if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
+                    # The sentence it too long, cut off at a comma.
+                    sentenceEnd = last_comma_index
+                    punctuations[sentenceEnd] = self.period
+                cache_sent = mini_sentence[sentenceEnd + 1:]
+                cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
+                mini_sentence = mini_sentence[0:sentenceEnd + 1]
+                punctuations = punctuations[0:sentenceEnd + 1]
+
+            punctuations_np = [int(x) for x in punctuations]
+            new_mini_sentence_punc += punctuations_np
+            sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
+            sentence_words_list += mini_sentence
+
+        assert len(sentence_punc_list) == len(sentence_words_list)
+        words_with_punc = []
+        sentence_punc_list_out = []
+        for i in range(0, len(sentence_words_list)):
+            if i > 0:
+                if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
+                    sentence_words_list[i] = " " + sentence_words_list[i]
+            if skip_num < len(cache):
+                skip_num += 1
+            else:
+                words_with_punc.append(sentence_words_list[i])
+            if skip_num >= len(cache):
+                sentence_punc_list_out.append(sentence_punc_list[i])
+                if sentence_punc_list[i] != "_":
+                    words_with_punc.append(sentence_punc_list[i])
+        sentence_out = "".join(words_with_punc)
+
+        sentenceEnd = -1
+        for i in range(len(sentence_punc_list) - 2, 1, -1):
+            if sentence_punc_list[i] == "銆�" or sentence_punc_list[i] == "锛�":
+                sentenceEnd = i
+                break
+        cache_out = sentence_words_list[sentenceEnd + 1:]
+        if sentence_out[-1] in self.punc_list:
+            sentence_out = sentence_out[:-1]
+            sentence_punc_list_out[-1] = "_"
+        param_dict[cache_key] = cache_out
+        return sentence_out, sentence_punc_list_out, cache_out
+
+    def vad_mask(self, size, vad_pos, dtype=np.bool):
+        """Create mask for decoder self-attention.
+
+        :param int size: size of mask
+        :param int vad_pos: index of vad index
+        :param torch.dtype dtype: result dtype
+        :rtype: torch.Tensor (B, Lmax, Lmax)
+        """
+        ret = np.ones((size, size), dtype=dtype)
+        if vad_pos <= 0 or vad_pos >= size:
+            return ret
+        sub_corner = np.zeros(
+            (vad_pos - 1, size - vad_pos), dtype=dtype)
+        ret[0:vad_pos - 1, vad_pos:] = sub_corner
+        return ret
+
+    def infer(self, feats: np.ndarray,
+              feats_len: np.ndarray,
+              vad_mask: np.ndarray,
+              sub_masks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+        outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
+        return outputs
+

--
Gitblit v1.9.1