| | |
| | | |
| | | self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads) |
| | | self.batch_size = 1 |
| | | self.encoder_conf = config["encoder_conf"] |
| | | self.punc_list = config.punc_list |
| | | self.punc_list = config['punc_list'] |
| | | self.period = 0 |
| | | for i in range(len(self.punc_list)): |
| | | if self.punc_list[i] == ",": |
| | |
| | | self.period = i |
| | | self.preprocessor = CodeMixTokenizerCommonPreprocessor( |
| | | train=False, |
| | | token_type=config.token_type, |
| | | token_list=config.token_list, |
| | | bpemodel=config.bpemodel, |
| | | text_cleaner=config.cleaner, |
| | | g2p_type=config.g2p, |
| | | token_type=config['token_type'], |
| | | token_list=config['token_list'], |
| | | bpemodel=config['bpemodel'], |
| | | text_cleaner=config['cleaner'], |
| | | g2p_type=config['g2p'], |
| | | text_name="text", |
| | | non_linguistic_symbols=config.non_linguistic_symbols, |
| | | non_linguistic_symbols=config['non_linguistic_symbols'], |
| | | ) |
| | | |
| | | def __call__(self, text: Union[list, str], split_size=20): |
| | |
| | | mini_sentence = cache_sent + mini_sentence |
| | | mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0) |
| | | data = { |
| | | "text": mini_sentence_id, |
| | | "text_lengths": len(mini_sentence_id), |
| | | "text": mini_sentence_id[None,:].astype(np.int64), |
| | | "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'), |
| | | } |
| | | try: |
| | | outputs = self.infer(data['text'], data['text_lengths']) |
| | | y = outputs[0] |
| | | _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1) |
| | | punctuations = indices |
| | | assert punctuations.size()[0] == len(mini_sentence) |
| | | punctuations = np.argmax(y,axis=-1)[0] |
| | | assert punctuations.size == len(mini_sentence) |
| | | except ONNXRuntimeError: |
| | | logging.warning("error") |
| | | |
| | |
| | | mini_sentence = mini_sentence[0:sentenceEnd + 1] |
| | | punctuations = punctuations[0:sentenceEnd + 1] |
| | | |
| | | punctuations_np = punctuations.cpu().numpy() |
| | | new_mini_sentence_punc += [int(x) for x in punctuations_np] |
| | | new_mini_sentence_punc += [int(x) for x in punctuations] |
| | | words_with_punc = [] |
| | | for i in range(len(mini_sentence)): |
| | | if i > 0: |
| | |
| | | new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period] |
| | | return new_mini_sentence_out, new_mini_sentence_punc_out |
| | | |
| | | def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]: |
| | | |
| | | outputs = self.ort_infer(feats) |
| | | def infer(self, feats: np.ndarray, |
| | | feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: |
| | | outputs = self.ort_infer([feats, feats_len]) |
| | | return outputs |
| | | |