| | |
| | | try: |
| | | outputs = self.infer(data['text'], data['text_lengths']) |
| | | y = outputs[0] |
| | | _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1) |
| | | punctuations = indices |
| | | assert punctuations.size()[0] == len(mini_sentence) |
| | | punctuations = np.argmax(y,axis=-1)[0] |
| | | assert punctuations.size == len(mini_sentence) |
| | | except ONNXRuntimeError: |
| | | logging.warning("error") |
| | | |
| | |
| | | mini_sentence = mini_sentence[0:sentenceEnd + 1] |
| | | punctuations = punctuations[0:sentenceEnd + 1] |
| | | |
| | | punctuations_np = punctuations.cpu().numpy() |
| | | new_mini_sentence_punc += [int(x) for x in punctuations_np] |
| | | new_mini_sentence_punc += [int(x) for x in punctuations] |
| | | words_with_punc = [] |
| | | for i in range(len(mini_sentence)): |
| | | if i > 0: |