From aa3fe1a353bde71d106755d030d9e5300fbde328 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 22 七月 2024 19:02:15 +0800
Subject: [PATCH] python runtime

---
 runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py |  120 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 120 insertions(+), 0 deletions(-)

diff --git a/runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py b/runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py
index da8065e..d144d31 100644
--- a/runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py
+++ b/runtime/python/onnxruntime/funasr_onnx/utils/postprocess_utils.py
@@ -296,3 +296,123 @@
             real_word_lists.append(ch)
     sentence = "".join(word_lists)
     return sentence, real_word_lists
+
+
+emo_dict = {
+    "<|HAPPY|>": "馃槉",
+    "<|SAD|>": "馃様",
+    "<|ANGRY|>": "馃槨",
+    "<|NEUTRAL|>": "",
+    "<|FEARFUL|>": "馃槹",
+    "<|DISGUSTED|>": "馃あ",
+    "<|SURPRISED|>": "馃槷",
+}
+
+event_dict = {
+    "<|BGM|>": "馃幖",
+    "<|Speech|>": "",
+    "<|Applause|>": "馃憦",
+    "<|Laughter|>": "馃榾",
+    "<|Cry|>": "馃槶",
+    "<|Sneeze|>": "馃ぇ",
+    "<|Breath|>": "",
+    "<|Cough|>": "馃ぇ",
+}
+
+lang_dict = {
+    "<|zh|>": "<|lang|>",
+    "<|en|>": "<|lang|>",
+    "<|yue|>": "<|lang|>",
+    "<|ja|>": "<|lang|>",
+    "<|ko|>": "<|lang|>",
+    "<|nospeech|>": "<|lang|>",
+}
+
+emoji_dict = {
+    "<|nospeech|><|Event_UNK|>": "鉂�",
+    "<|zh|>": "",
+    "<|en|>": "",
+    "<|yue|>": "",
+    "<|ja|>": "",
+    "<|ko|>": "",
+    "<|nospeech|>": "",
+    "<|HAPPY|>": "馃槉",
+    "<|SAD|>": "馃様",
+    "<|ANGRY|>": "馃槨",
+    "<|NEUTRAL|>": "",
+    "<|BGM|>": "馃幖",
+    "<|Speech|>": "",
+    "<|Applause|>": "馃憦",
+    "<|Laughter|>": "馃榾",
+    "<|FEARFUL|>": "馃槹",
+    "<|DISGUSTED|>": "馃あ",
+    "<|SURPRISED|>": "馃槷",
+    "<|Cry|>": "馃槶",
+    "<|EMO_UNKNOWN|>": "",
+    "<|Sneeze|>": "馃ぇ",
+    "<|Breath|>": "",
+    "<|Cough|>": "馃樂",
+    "<|Sing|>": "",
+    "<|Speech_Noise|>": "",
+    "<|withitn|>": "",
+    "<|woitn|>": "",
+    "<|GBG|>": "",
+    "<|Event_UNK|>": "",
+}
+
+emo_set = {"馃槉", "馃様", "馃槨", "馃槹", "馃あ", "馃槷"}
+event_set = {
+    "馃幖",
+    "馃憦",
+    "馃榾",
+    "馃槶",
+    "馃ぇ",
+    "馃樂",
+}
+
+
+def format_str_v2(s):
+    sptk_dict = {}
+    for sptk in emoji_dict:
+        sptk_dict[sptk] = s.count(sptk)
+        s = s.replace(sptk, "")
+    emo = "<|NEUTRAL|>"
+    for e in emo_dict:
+        if sptk_dict[e] > sptk_dict[emo]:
+            emo = e
+    for e in event_dict:
+        if sptk_dict[e] > 0:
+            s = event_dict[e] + s
+    s = s + emo_dict[emo]
+
+    for emoji in emo_set.union(event_set):
+        s = s.replace(" " + emoji, emoji)
+        s = s.replace(emoji + " ", emoji)
+    return s.strip()
+
+
+def rich_transcription_postprocess(s):
+    def get_emo(s):
+        return s[-1] if s[-1] in emo_set else None
+
+    def get_event(s):
+        return s[0] if s[0] in event_set else None
+
+    s = s.replace("<|nospeech|><|Event_UNK|>", "鉂�")
+    for lang in lang_dict:
+        s = s.replace(lang, "<|lang|>")
+    s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
+    new_s = " " + s_list[0]
+    cur_ent_event = get_event(new_s)
+    for i in range(1, len(s_list)):
+        if len(s_list[i]) == 0:
+            continue
+        if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
+            s_list[i] = s_list[i][1:]
+        # else:
+        cur_ent_event = get_event(s_list[i])
+        if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
+            new_s = new_s[:-1]
+        new_s += s_list[i].strip().lstrip()
+    new_s = new_s.replace("The.", " ")
+    return new_s.strip()

--
Gitblit v1.9.1