haoneng.lhn
2023-06-21 6a3d31dc552d0e476af15595745b06b9928bbbff
Fix English hotwords bug
1个文件已修改
54 ■■■■■ 已修改文件
funasr/bin/asr_infer.py 54 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_infer.py
@@ -377,6 +377,7 @@
        self.asr_train_args = asr_train_args
        self.converter = converter
        self.tokenizer = tokenizer
        self.cmvn_file = cmvn_file
        # 6. [Optional] Build hotword list from str, local file or url
        self.hotword_list = None
@@ -519,6 +520,44 @@
        return results
    def generate_hotwords_list(self, hotword_list_or_file):
        def load_seg_dict(seg_dict_file):
            seg_dict = {}
            assert isinstance(seg_dict_file, str)
            with open(seg_dict_file, "r", encoding="utf8") as f:
                lines = f.readlines()
                for line in lines:
                    s = line.strip().split()
                    key = s[0]
                    value = s[1:]
                    seg_dict[key] = " ".join(value)
            return seg_dict
        def seg_tokenize(txt, seg_dict):
            pattern = re.compile(r'^[\u4E00-\u9FA50-9]+$')
            out_txt = ""
            for word in txt:
                word = word.lower()
                if word in seg_dict:
                    out_txt += seg_dict[word] + " "
                else:
                    if pattern.match(word):
                        for char in word:
                            if char in seg_dict:
                                out_txt += seg_dict[char] + " "
                            else:
                                out_txt += "<unk>" + " "
                    else:
                        out_txt += "<unk>" + " "
            return out_txt.strip().split()
        seg_dict = None
        if self.cmvn_file is not None:
            model_dir = os.path.dirname(self.cmvn_file)
            seg_dict_file = os.path.join(model_dir, 'seg_dict')
            if os.path.exists(seg_dict_file):
                seg_dict = load_seg_dict(seg_dict_file)
            else:
                seg_dict = None
        # for None
        if hotword_list_or_file is None:
            hotword_list = None
@@ -530,8 +569,11 @@
            with codecs.open(hotword_list_or_file, 'r') as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hw_list = hw.split()
                    if seg_dict is not None:
                        hw_list = seg_tokenize(hw_list, seg_dict)
                    hotword_str_list.append(hw)
                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                    hotword_list.append(self.converter.tokens2ids(hw_list))
                hotword_list.append([self.asr_model.sos])
                hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -551,8 +593,11 @@
            with codecs.open(hotword_list_or_file, 'r') as fin:
                for line in fin.readlines():
                    hw = line.strip()
                    hw_list = hw.split()
                    if seg_dict is not None:
                        hw_list = seg_tokenize(hw_list, seg_dict)
                    hotword_str_list.append(hw)
                    hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                    hotword_list.append(self.converter.tokens2ids(hw_list))
                hotword_list.append([self.asr_model.sos])
                hotword_str_list.append('<s>')
            logging.info("Initialized hotword list from file: {}, hotword list: {}."
@@ -564,7 +609,10 @@
            hotword_str_list = []
            for hw in hotword_list_or_file.strip().split():
                hotword_str_list.append(hw)
                hotword_list.append(self.converter.tokens2ids([i for i in hw]))
                hw_list = hw
                if seg_dict is not None:
                    hw_list = seg_tokenize(hw_list, seg_dict)
                hotword_list.append(self.converter.tokens2ids(hw_list))
            hotword_list.append([self.asr_model.sos])
            hotword_str_list.append('<s>')
            logging.info("Hotword list: {}.".format(hotword_str_list))