fix
九耳
2023-03-30 0fd9640ced9c8ae9af43e5300068a8837d8ce26e
2 files changed, 44 lines modified
Changed files:
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py     25 ●●●●
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py  19 ●●●●●
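This commit drops the heavyweight CodeMixTokenizerCommonPreprocessor from the ONNX punctuation runtime: word splitting is now done by a small code_mix_split_words helper, and token-to-id mapping by a TokenIDConverter built from punc.yaml's token_list. The caller-facing API is unchanged. A minimal usage sketch, assuming the class defined in punc_bin.py is CT_Transformer (the name funasr_onnx exports for this model) and that model_dir is a placeholder for a downloaded punctuation model directory:

# Hedged usage sketch; assumes punc_bin.py defines CT_Transformer and that
# model_dir contains the exported model plus punc.yaml.
from funasr_onnx import CT_Transformer

model_dir = "/path/to/punc_model_dir"   # placeholder, not from this commit
model = CT_Transformer(model_dir)

# Mixed Chinese/English input, exactly the case code_mix_split_words targets.
result = model("我们是funasr的runtime团队")
print(result)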
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -8,8 +8,7 @@
 from .utils.utils import (ONNXRuntimeError,
                           OrtInferSession, get_logger,
                           read_yaml)
-from .utils.preprocessor import CodeMixTokenizerCommonPreprocessor
-from .utils.utils import split_to_mini_sentence
+from .utils.utils import (TokenIDConverter, split_to_mini_sentence, code_mix_split_words)
 
 logging = get_logger()
@@ -30,6 +29,7 @@
         config_file = os.path.join(model_dir, 'punc.yaml')
         config = read_yaml(config_file)
+        self.converter = TokenIDConverter(config['token_list'])
         self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
         self.batch_size = 1
         self.punc_list = config['punc_list']
@@ -41,23 +41,12 @@
                self.punc_list[i] = "?"
            elif self.punc_list[i] == "。":
                self.period = i
        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
            train=False,
            token_type=config['token_type'],
            token_list=config['token_list'],
            bpemodel=config['bpemodel'],
            text_cleaner=config['cleaner'],
            g2p_type=config['g2p'],
            text_name="text",
            non_linguistic_symbols=config['non_linguistic_symbols'],
        )
    def __call__(self, text: Union[list, str], split_size=20):
        data = {"text": text}
        result = self.preprocessor(data=data, uid="12938712838719")
        split_text = self.preprocessor.pop_split_text_data(result)
        split_text = code_mix_split_words(text)
        split_text_id = self.converter.tokens2ids(split_text)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
        assert len(mini_sentences) == len(mini_sentences_id)
        cache_sent = []
        cache_sent_id = []
@@ -68,9 +57,9 @@
             mini_sentence = mini_sentences[mini_sentence_i]
             mini_sentence_id = mini_sentences_id[mini_sentence_i]
             mini_sentence = cache_sent + mini_sentence
-            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
+            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
             data = {
-                "text": mini_sentence_id[None,:].astype(np.int64),
+                "text": mini_sentence_id[None,:],
                 "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
             }
             try:
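After this change the data preparation inside __call__ is self-contained: split the mixed-language text into tokens, convert tokens to ids, chunk both into mini-sentences of split_size, then feed each chunk to the ONNX session as a (1, T) int64 id matrix plus an int32 length vector. A runnable sketch of that flow, with a hypothetical StubConverter standing in for TokenIDConverter (which the real class builds from config['token_list']); the import path assumes the installed funasr_onnx package layout:

import numpy as np
from funasr_onnx.utils.utils import code_mix_split_words, split_to_mini_sentence

class StubConverter:
    """Hypothetical stand-in for TokenIDConverter, only for this sketch."""
    def __init__(self, token_list):
        self.t2i = {t: i for i, t in enumerate(token_list)}

    def tokens2ids(self, tokens):
        # Unknown tokens map to 0 here; the real converter has its own <unk> id.
        return [self.t2i.get(t, 0) for t in tokens]

text = "今天weather不错so we went hiking"
split_text = code_mix_split_words(text)            # mixed-language tokens
converter = StubConverter(sorted(set(split_text)))
split_text_id = converter.tokens2ids(split_text)   # parallel list of ids

split_size = 20
mini_sentences = split_to_mini_sentence(split_text, split_size)
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
assert len(mini_sentences) == len(mini_sentences_id)

cache_sent_id = []   # ids carried over when a chunk ends mid-sentence
for sent_id in mini_sentences_id:
    ids = np.array(cache_sent_id + sent_id, dtype='int64')
    feed = {
        "text": ids[None, :],                                 # (1, T) int64
        "text_lengths": np.array([len(ids)], dtype='int32'),  # (1,) int32
    }
    # In the real class this dict is passed to self.ort_infer(...).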
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -228,6 +228,25 @@
         sentences.append(words[sentence_len * word_limit:])
     return sentences
 
+def code_mix_split_words(text: str):
+    words = []
+    segs = text.split()
+    for seg in segs:
+        # There is no space in seg.
+        current_word = ""
+        for c in seg:
+            if len(c.encode()) == 1:
+                # This is an ASCII char.
+                current_word += c
+            else:
+                # This is a Chinese char.
+                if len(current_word) > 0:
+                    words.append(current_word)
+                    current_word = ""
+                words.append(c)
+        if len(current_word) > 0:
+            words.append(current_word)
+    return words
 
 def read_yaml(yaml_path: Union[str, Path]) -> Dict:
     if not Path(yaml_path).exists():
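For reference, code_mix_split_words first splits on whitespace, then keeps runs of single-byte (ASCII) characters together as one word and emits every multi-byte character as its own token; note it treats any non-ASCII character this way, not only Chinese. A quick behavioral check of the function as added above (import path assumes the installed funasr_onnx package):

from funasr_onnx.utils.utils import code_mix_split_words

print(code_mix_split_words("hello世界"))
# ['hello', '世', '界']
print(code_mix_split_words("我们的 test case 好"))
# ['我', '们', '的', 'test', 'case', '好']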