import os
from typing import Union

import numpy as np

from .utils.preprocessor import CodeMixTokenizerCommonPreprocessor
from .utils.utils import (ONNXRuntimeError, OrtInferSession, TokenIDConverter,
                          code_mix_split_words, get_logger, read_yaml,
                          split_to_mini_sentence)

logging = get_logger()


# The class name and __init__ signature below are reconstructed from the
# attributes used in the body; the original header is not in this excerpt.
class CT_Transformer:
    def __init__(self, model_dir, model_file, device_id="-1",
                 intra_op_num_threads=4):
        config_file = os.path.join(model_dir, 'punc.yaml')
        config = read_yaml(config_file)

        # Token/id mapping plus the ONNX Runtime session for the model file.
        self.converter = TokenIDConverter(config['token_list'])
        self.ort_infer = OrtInferSession(model_file, device_id,
                                         intra_op_num_threads=intra_op_num_threads)
        self.batch_size = 1
        self.punc_list = config['punc_list']
        # Remember the index of the sentence-final period "。"; the loop
        # header and the first condition are reconstructed for this excerpt.
        self.period = 0
        for i in range(len(self.punc_list)):
            if self.punc_list[i] == "？":
                self.punc_list[i] = "?"
            elif self.punc_list[i] == "。":
                self.period = i
        self.preprocessor = CodeMixTokenizerCommonPreprocessor(
            train=False,
            token_type=config['token_type'],
            token_list=config['token_list'],
            bpemodel=config['bpemodel'],
            text_cleaner=config['cleaner'],
            g2p_type=config['g2p'],
            text_name="text",
            non_linguistic_symbols=config['non_linguistic_symbols'],
        )

    def __call__(self, text: Union[list, str], split_size=20):
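        """Punctuate `text` chunk by chunk.

        The input is split into word-level tokens, the tokens are mapped
        to ids, both are chunked into mini-sentences of at most
        `split_size` tokens, and each chunk is run through the ONNX model.
        """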
        split_text = code_mix_split_words(text)
        split_text_id = self.converter.tokens2ids(split_text)
        mini_sentences = split_to_mini_sentence(split_text, split_size)
        mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
        assert len(mini_sentences) == len(mini_sentences_id)
        cache_sent = []
        cache_sent_id = []
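        # `cache_sent` / `cache_sent_id` carry the tail of a chunk that
        # ends without punctuation over into the next chunk.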
        for mini_sentence_i in range(len(mini_sentences)):
            mini_sentence = mini_sentences[mini_sentence_i]
            mini_sentence_id = mini_sentences_id[mini_sentence_i]
            mini_sentence = cache_sent + mini_sentence
            mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64')
            data = {
                "text": mini_sentence_id[None, :],
                "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'),
            }
            try: