From 5448e926a215066193f8c5a12e0c7dfe55c29579 Mon Sep 17 00:00:00 2001
From: 北念 <lzr265946@alibaba-inc.com>
Date: 星期三, 10 七月 2024 11:27:35 +0800
Subject: [PATCH] add postprocess for sensevoice

---
 funasr/utils/postprocess_utils.py |  110 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 110 insertions(+), 0 deletions(-)

diff --git a/funasr/utils/postprocess_utils.py b/funasr/utils/postprocess_utils.py
index 34f933c..c4a0789 100644
--- a/funasr/utils/postprocess_utils.py
+++ b/funasr/utils/postprocess_utils.py
@@ -299,3 +299,113 @@
             real_word_lists.append(ch)
     sentence = "".join(word_lists)
     return sentence, real_word_lists
+
+emo_dict = {
+        "<|HAPPY|>": "馃槉",
+        "<|SAD|>": "馃様",
+        "<|ANGRY|>": "馃槨",
+        "<|NEUTRAL|>": "",
+        "<|FEARFUL|>": "馃槹",
+        "<|DISGUSTED|>": "馃あ",
+        "<|SURPRISED|>": "馃槷",
+}
+
+event_dict = {
+        "<|BGM|>": "馃幖",
+        "<|Speech|>": "",
+        "<|Applause|>": "馃憦",
+        "<|Laughter|>": "馃榾",
+        "<|Cry|>": "馃槶",
+        "<|Sneeze|>": "馃ぇ",
+        "<|Breath|>": "",
+        "<|Cough|>": "馃ぇ",
+}
+
+lang_dict =  {
+    "<|zh|>": "<|lang|>",
+    "<|en|>": "<|lang|>",
+    "<|yue|>": "<|lang|>",
+    "<|ja|>": "<|lang|>",
+    "<|ko|>": "<|lang|>",
+    "<|nospeech|>": "<|lang|>",
+}
+
+emoji_dict = {
+        "<|nospeech|><|Event_UNK|>": "鉂�",
+        "<|zh|>": "",
+        "<|en|>": "",
+        "<|yue|>": "",
+        "<|ja|>": "",
+        "<|ko|>": "",
+        "<|nospeech|>": "",
+        "<|HAPPY|>": "馃槉",
+        "<|SAD|>": "馃様",
+        "<|ANGRY|>": "馃槨",
+        "<|NEUTRAL|>": "",
+        "<|BGM|>": "馃幖",
+        "<|Speech|>": "",
+        "<|Applause|>": "馃憦",
+        "<|Laughter|>": "馃榾",
+        "<|FEARFUL|>": "馃槹",
+        "<|DISGUSTED|>": "馃あ",
+        "<|SURPRISED|>": "馃槷",
+        "<|Cry|>": "馃槶",
+        "<|EMO_UNKNOWN|>": "",
+        "<|Sneeze|>": "馃ぇ",
+        "<|Breath|>": "",
+        "<|Cough|>": "馃樂",
+        "<|Sing|>": "",
+        "<|Speech_Noise|>": "",
+        "<|withitn|>": "",
+        "<|woitn|>": "",
+        "<|GBG|>": "",
+        "<|Event_UNK|>": "",
+}
+
+emo_set = {"馃槉", "馃様", "馃槨", "馃槹", "馃あ", "馃槷"}
+event_set = {"馃幖", "馃憦", "馃榾", "馃槶", "馃ぇ", "馃樂",}
+
+def format_str_v2(s):
+        sptk_dict = {}
+        for sptk in emoji_dict:
+                sptk_dict[sptk] = s.count(sptk)
+                s = s.replace(sptk, "")
+        emo = "<|NEUTRAL|>"
+        for e in emo_dict:
+                if sptk_dict[e] > sptk_dict[emo]:
+                        emo = e
+        for e in event_dict:
+                if sptk_dict[e] > 0:
+                        s = event_dict[e] + s
+        s = s + emo_dict[emo]
+
+        for emoji in emo_set.union(event_set):
+                s = s.replace(" " + emoji, emoji)
+                s = s.replace(emoji + " ", emoji)
+        return s.strip()
+
+def rich_transcription_postprocess(s):
+        def get_emo(s):
+                return s[-1] if s[-1] in emo_set else None
+        def get_event(s):
+                return s[0] if s[0] in event_set else None
+
+        s = s.replace("<|nospeech|><|Event_UNK|>", "鉂�")
+        for lang in lang_dict:
+                s = s.replace(lang, "<|lang|>")
+        s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
+        new_s = " " + s_list[0]
+        cur_ent_event = get_event(new_s)
+        for i in range(1, len(s_list)):
+                if len(s_list[i]) == 0:
+                        continue
+                if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
+                        s_list[i] = s_list[i][1:]
+                #else:
+                cur_ent_event = get_event(s_list[i])
+                if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
+                        new_s = new_s[:-1]
+                new_s += s_list[i].strip().lstrip()
+        new_s = new_s.replace("The.", " ")
+        return new_s.strip()
+    

--
Gitblit v1.9.1