From 786ed534670efa0652e0c46abd9b9d3d5bec8635 Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: Wed, 21 Jun 2023 17:52:01 +0800
Subject: [PATCH] Merge pull request #660 from alibaba-damo-academy/dev_lhn
---
funasr/bin/asr_infer.py | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++---
1 file changed, 51 insertions(+), 3 deletions(-)
diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index c722ebc..e12dbb5 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -377,6 +377,7 @@
self.asr_train_args = asr_train_args
self.converter = converter
self.tokenizer = tokenizer
+ self.cmvn_file = cmvn_file
# 6. [Optional] Build hotword list from str, local file or url
self.hotword_list = None
@@ -519,6 +520,44 @@
return results
def generate_hotwords_list(self, hotword_list_or_file):
+ def load_seg_dict(seg_dict_file):
+ seg_dict = {}
+ assert isinstance(seg_dict_file, str)
+ with open(seg_dict_file, "r", encoding="utf8") as f:
+ lines = f.readlines()
+ for line in lines:
+ s = line.strip().split()
+ key = s[0]
+ value = s[1:]
+ seg_dict[key] = " ".join(value)
+ return seg_dict
+
+ def seg_tokenize(txt, seg_dict):
+ pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
+ out_txt = ""
+ for word in txt:
+ word = word.lower()
+ if word in seg_dict:
+ out_txt += seg_dict[word] + " "
+ else:
+ if pattern.match(word):
+ for char in word:
+ if char in seg_dict:
+ out_txt += seg_dict[char] + " "
+ else:
+ out_txt += "<unk>" + " "
+ else:
+ out_txt += "<unk>" + " "
+ return out_txt.strip().split()
+
+ seg_dict = None
+ if self.cmvn_file is not None:
+ model_dir = os.path.dirname(self.cmvn_file)
+ seg_dict_file = os.path.join(model_dir, 'seg_dict')
+ if os.path.exists(seg_dict_file):
+ seg_dict = load_seg_dict(seg_dict_file)
+ else:
+ seg_dict = None
# for None
if hotword_list_or_file is None:
hotword_list = None
@@ -530,8 +569,11 @@
with codecs.open(hotword_list_or_file, 'r') as fin:
for line in fin.readlines():
hw = line.strip()
+ hw_list = hw.split()
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append(self.converter.tokens2ids(hw_list))
hotword_list.append([self.asr_model.sos])
hotword_str_list.append('<s>')
logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -551,8 +593,11 @@
with codecs.open(hotword_list_or_file, 'r') as fin:
for line in fin.readlines():
hw = line.strip()
+ hw_list = hw.split()
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hotword_list.append(self.converter.tokens2ids(hw_list))
hotword_list.append([self.asr_model.sos])
hotword_str_list.append('<s>')
logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -564,7 +609,10 @@
hotword_str_list = []
for hw in hotword_list_or_file.strip().split():
hotword_str_list.append(hw)
- hotword_list.append(self.converter.tokens2ids([i for i in hw]))
+ hw_list = hw
+ if seg_dict is not None:
+ hw_list = seg_tokenize(hw_list, seg_dict)
+ hotword_list.append(self.converter.tokens2ids(hw_list))
hotword_list.append([self.asr_model.sos])
hotword_str_list.append('<s>')
logging.info("Hotword list: {}.".format(hotword_str_list))
--
Gitblit v1.9.1