| | |
| | | elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())!=1: |
| | | new_mini_sentence_out = new_mini_sentence + "。" |
| | | new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] |
| | | if len(punctuations): punctuations[-1] = 2 |
| | | elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1: |
| | | new_mini_sentence_out = new_mini_sentence + "." |
| | | new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] |
| | | |
| | | if len(punctuations): punctuations[-1] = 2 |
| | | # accumulate the punctuations into a running punc_array, kept for punctuation-based segmentation
| | | if punc_array is None: |
| | | punc_array = punctuations |
| | |
| | | punc_array = torch.cat([punc_array, punctuations], dim=0) |
| | | result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array} |
| | | results.append(result_i) |
| | | |
| | | return results, meta_data |
| | | |