From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/bin/punc_infer.py |   31 +++++++++++++++++++++++++++++--
 1 files changed, 29 insertions(+), 2 deletions(-)

diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
index ac96811..9efeb5b 100644
--- a/funasr/bin/punc_infer.py
+++ b/funasr/bin/punc_infer.py
@@ -8,6 +8,7 @@
 
 import numpy as np
 import torch
+import os
 
 from funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
@@ -41,6 +42,11 @@
                 self.punc_list[i] = "锛�"
             elif self.punc_list[i] == "銆�":
                 self.period = i
+        self.seg_dict_file = None
+        self.seg_jieba = False
+        if "seg_jieba" in train_args:
+            self.seg_jieba = train_args.seg_jieba
+            self.seg_dict_file = os.path.dirname(model_file)+"/"+ "jieba_usr_dict"
         self.preprocessor = CodeMixTokenizerCommonPreprocessor(
             train=False,
             token_type=train_args.token_type,
@@ -50,6 +56,8 @@
             g2p_type=train_args.g2p,
             text_name="text",
             non_linguistic_symbols=train_args.non_linguistic_symbols,
+            seg_jieba=self.seg_jieba,
+            seg_dict_file=self.seg_dict_file
         )
 
     @torch.no_grad()
@@ -109,12 +117,25 @@
             new_mini_sentence_punc += [int(x) for x in punctuations_np]
             words_with_punc = []
             for i in range(len(mini_sentence)):
+                if (i==0 or self.punc_list[punctuations[i-1]] == "銆�" or self.punc_list[punctuations[i-1]] == "锛�") and len(mini_sentence[i][0].encode()) == 1:
+                    mini_sentence[i] = mini_sentence[i].capitalize()
+                if i == 0:
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        mini_sentence[i] = " " + mini_sentence[i]
                 if i > 0:
                     if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
                         mini_sentence[i] = " " + mini_sentence[i]
                 words_with_punc.append(mini_sentence[i])
                 if self.punc_list[punctuations[i]] != "_":
-                    words_with_punc.append(self.punc_list[punctuations[i]])
+                    punc_res = self.punc_list[punctuations[i]]
+                    if len(mini_sentence[i][0].encode()) == 1:
+                        if punc_res == "锛�":
+                            punc_res = ","
+                        elif punc_res == "銆�":
+                            punc_res = "."
+                        elif punc_res == "锛�":
+                            punc_res = "?"
+                    words_with_punc.append(punc_res)
             new_mini_sentence += "".join(words_with_punc)
             # Add Period for the end of the sentence
             new_mini_sentence_out = new_mini_sentence
@@ -123,9 +144,15 @@
                 if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
                     new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
-                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�":
+                elif new_mini_sentence[-1] == ",":
+                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())==0:
                     new_mini_sentence_out = new_mini_sentence + "銆�"
                     new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
+                    new_mini_sentence_out = new_mini_sentence + "."
+                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
         return new_mini_sentence_out, new_mini_sentence_punc_out
 
 

--
Gitblit v1.9.1