From 60d78d9d849e737e307636a5ae3c96c157fa267c Mon Sep 17 00:00:00 2001
From: chenmengzheAAA <123789350+chenmengzheAAA@users.noreply.github.com>
Date: 星期二, 12 九月 2023 22:22:06 +0800
Subject: [PATCH] Merge pull request #941 from alibaba-damo-academy/dev_cmz
---
funasr/bin/punc_infer.py | 23 +++++++++++++++++++++--
1 files changed, 21 insertions(+), 2 deletions(-)
diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
index 7b61717..9efeb5b 100644
--- a/funasr/bin/punc_infer.py
+++ b/funasr/bin/punc_infer.py
@@ -117,12 +117,25 @@
new_mini_sentence_punc += [int(x) for x in punctuations_np]
words_with_punc = []
for i in range(len(mini_sentence)):
+ if (i==0 or self.punc_list[punctuations[i-1]] == "銆�" or self.punc_list[punctuations[i-1]] == "锛�") and len(mini_sentence[i][0].encode()) == 1:
+ mini_sentence[i] = mini_sentence[i].capitalize()
+ if i == 0:
+ if len(mini_sentence[i][0].encode()) == 1:
+ mini_sentence[i] = " " + mini_sentence[i]
if i > 0:
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if self.punc_list[punctuations[i]] != "_":
- words_with_punc.append(self.punc_list[punctuations[i]])
+ punc_res = self.punc_list[punctuations[i]]
+ if len(mini_sentence[i][0].encode()) == 1:
+ if punc_res == "锛�":
+ punc_res = ","
+ elif punc_res == "銆�":
+ punc_res = "."
+ elif punc_res == "锛�":
+ punc_res = "?"
+ words_with_punc.append(punc_res)
new_mini_sentence += "".join(words_with_punc)
# Add Period for the end of the sentence
new_mini_sentence_out = new_mini_sentence
@@ -131,9 +144,15 @@
if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
- elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�":
+ elif new_mini_sentence[-1] == ",":
+ new_mini_sentence_out = new_mini_sentence[:-1] + "."
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+ elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())==0:
new_mini_sentence_out = new_mini_sentence + "銆�"
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+ elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
+ new_mini_sentence_out = new_mini_sentence + "."
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
return new_mini_sentence_out, new_mini_sentence_punc_out
--
Gitblit v1.9.1