| | |
| | | from .utils.utils import (ONNXRuntimeError, |
| | | OrtInferSession, get_logger, |
| | | read_yaml) |
| | | from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words) |
| | | from .utils.utils import (TokenIDConverter, split_to_mini_sentence,code_mix_split_words,code_mix_split_words_jieba) |
| | | logging = get_logger() |
| | | |
| | | |
| | |
| | | self.punc_list[i] = "?" |
| | | elif self.punc_list[i] == "。": |
| | | self.period = i |
| | | if "seg_jieba" in config: |
| | | self.seg_jieba = True |
| | | self.jieba_usr_dict_path = os.path.join(model_dir, 'jieba_usr_dict') |
| | | self.code_mix_split_words_jieba = code_mix_split_words_jieba(self.jieba_usr_dict_path) |
| | | else: |
| | | self.seg_jieba = False |
| | | |
| | | def __call__(self, text: Union[list, str], split_size=20): |
| | | split_text = code_mix_split_words(text) |
| | | if self.seg_jieba: |
| | | split_text = self.code_mix_split_words_jieba(text) |
| | | else: |
| | | split_text = code_mix_split_words(text) |
| | | split_text_id = self.converter.tokens2ids(split_text) |
| | | mini_sentences = split_to_mini_sentence(split_text, split_size) |
| | | mini_sentences_id = split_to_mini_sentence(split_text_id, split_size) |
| | |
| | | mini_sentence = cache_sent + mini_sentence |
| | | mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32') |
| | | text_length = len(mini_sentence_id) |
| | | vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32) |
| | | data = { |
| | | "input": mini_sentence_id[None,:], |
| | | "text_lengths": np.array([text_length], dtype='int32'), |
| | | "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32), |
| | | "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) |
| | | "vad_mask": vad_mask, |
| | | "sub_masks": vad_mask |
| | | } |
| | | try: |
| | | outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"]) |