From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/bin/punc_infer.py | 31 +++++++++++++++++++++++++++++--
1 files changed, 29 insertions(+), 2 deletions(-)
diff --git a/funasr/bin/punc_infer.py b/funasr/bin/punc_infer.py
index ac96811..9efeb5b 100644
--- a/funasr/bin/punc_infer.py
+++ b/funasr/bin/punc_infer.py
@@ -8,6 +8,7 @@
import numpy as np
import torch
+import os
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.datasets.preprocessor import CodeMixTokenizerCommonPreprocessor
@@ -41,6 +42,11 @@
self.punc_list[i] = "锛�"
elif self.punc_list[i] == "銆�":
self.period = i
+ self.seg_dict_file = None
+ self.seg_jieba = False
+ if "seg_jieba" in train_args:
+ self.seg_jieba = train_args.seg_jieba
+ self.seg_dict_file = os.path.dirname(model_file)+"/"+ "jieba_usr_dict"
self.preprocessor = CodeMixTokenizerCommonPreprocessor(
train=False,
token_type=train_args.token_type,
@@ -50,6 +56,8 @@
g2p_type=train_args.g2p,
text_name="text",
non_linguistic_symbols=train_args.non_linguistic_symbols,
+ seg_jieba=self.seg_jieba,
+ seg_dict_file=self.seg_dict_file
)
@torch.no_grad()
@@ -109,12 +117,25 @@
new_mini_sentence_punc += [int(x) for x in punctuations_np]
words_with_punc = []
for i in range(len(mini_sentence)):
+ if (i==0 or self.punc_list[punctuations[i-1]] == "銆�" or self.punc_list[punctuations[i-1]] == "锛�") and len(mini_sentence[i][0].encode()) == 1:
+ mini_sentence[i] = mini_sentence[i].capitalize()
+ if i == 0:
+ if len(mini_sentence[i][0].encode()) == 1:
+ mini_sentence[i] = " " + mini_sentence[i]
if i > 0:
if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
mini_sentence[i] = " " + mini_sentence[i]
words_with_punc.append(mini_sentence[i])
if self.punc_list[punctuations[i]] != "_":
- words_with_punc.append(self.punc_list[punctuations[i]])
+ punc_res = self.punc_list[punctuations[i]]
+ if len(mini_sentence[i][0].encode()) == 1:
+ if punc_res == "锛�":
+ punc_res = ","
+ elif punc_res == "銆�":
+ punc_res = "."
+ elif punc_res == "锛�":
+ punc_res = "?"
+ words_with_punc.append(punc_res)
new_mini_sentence += "".join(words_with_punc)
# Add Period for the end of the sentence
new_mini_sentence_out = new_mini_sentence
@@ -123,9 +144,15 @@
if new_mini_sentence[-1] == "锛�" or new_mini_sentence[-1] == "銆�":
new_mini_sentence_out = new_mini_sentence[:-1] + "銆�"
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
- elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�":
+ elif new_mini_sentence[-1] == ",":
+ new_mini_sentence_out = new_mini_sentence[:-1] + "."
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+ elif new_mini_sentence[-1] != "銆�" and new_mini_sentence[-1] != "锛�" and len(new_mini_sentence[-1].encode())==0:
new_mini_sentence_out = new_mini_sentence + "銆�"
new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+ elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
+ new_mini_sentence_out = new_mini_sentence + "."
+ new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
return new_mini_sentence_out, new_mini_sentence_punc_out
--
Gitblit v1.9.1