From 60d78d9d849e737e307636a5ae3c96c157fa267c Mon Sep 17 00:00:00 2001
From: chenmengzheAAA <123789350+chenmengzheAAA@users.noreply.github.com>
Date: 星期二, 12 九月 2023 22:22:06 +0800
Subject: [PATCH] Merge pull request #941 from alibaba-damo-academy/dev_cmz

---
 funasr/bin/punc_infer.py |   23 +++++++++++++++++++++--
 1 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
index 7b61717..9efeb5b 100644
--- a/funasr/bin/punc_infer.py
+++ b/funasr/bin/punc_infer.py
@@ -117,12 +117,25 @@
             new_mini_sentence_punc += [int(x) for x in punctuations_np]
             words_with_punc = []
             for i in range(len(mini_sentence)):
+                if (i==0 or self.punc_list[punctuations[i-1]] == "銆�" or self.punc_list[punctuations[i-1]] == "锛�") and len(mini_sentence[i][0].encode()) == 1:
+                    mini_sentence[i] = mini_sentence[i].capitalize()
+                if i == 0:
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
                 if i > 0:
                     if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
                         mini_sentence[i] = " " + mini_sentence[i]
                 words_with_punc.append(mini_sentence[i])
                 if self.punc_list[punctuations[i]] != "_":
-                    words_with_punc.append(self.punc_list[punctuations[i]])
+                    punc_res = self.punc_list[punctuations[i]]
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        if punc_res == "锛�":
+                            punc_res = ","
+                        elif punc_res == "銆�":
+                            punc_res = "."
+                        elif punc_res == "锛�":
+                            punc_res = "?"
+                    words_with_punc.append(punc_res)
             new_mini_sentence += "".join(words_with_punc)
             # Add Period for the end of the sentence
             new_mini_sentence_out = new_mini_sentence
@@ -131,9 +144,15 @@
                 if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
                     new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
-                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�":
+                elif new_mini_sentence[-1] == ",":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())==0:
                     new_mini_sentence_out = new_mini_sentence + "銆�"
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
+                    new_mini_sentence_out = new_mini_sentence + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
         return new_mini_sentence_out, new_mini_sentence_punc_out
 
 

--
Gitblit v1.9.1