| | |
| | | from .utils.utils import (ONNXRuntimeError, |
| | | OrtInferSession, get_logger, |
| | | read_yaml) |
| | | from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words) |
| | | from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words,code_mix_split_words_jieba) |
| | | logging = get_logger() |
| | | |
| | | |
| | |
| | | self.punc_list[i] = "?" |
| | | elif self.punc_list[i] == "。": |
| | | self.period = i |
| | | if "seg_jieba" in config: |
| | | self.seg_jieba = True |
| | | self.jieba_usr_dict_path = os.path.join(model_dir, 'jieba_usr_dict') |
| | | self.code_mix_split_words_jieba = code_mix_split_words_jieba(self.jieba_usr_dict_path) |
| | | else: |
| | | self.seg_jieba = False |
| | | |
| | | def __call__(self, text: Union[list, str], split_size=20): |
| | | split_text = code_mix_split_words(text) |
| | | if self.seg_jieba: |
| | | split_text = self.code_mix_split_words_jieba(text) |
| | | else: |
| | | split_text = code_mix_split_words(text) |
| | | split_text_id = self.converter.tokens2ids(split_text) |
| | | mini_sentences = split_to_mini_sentence(split_text, split_size) |
| | | mini_sentences_id = split_to_mini_sentence(split_text_id, split_size) |