| | |
| | | |
| | | |
| | | |
| | | def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]: |
| | | def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor, **kwargs): |
| | | """Compute loss value from buffer sequences. |
| | | |
| | | Args: |
| | |
| | | # text = data_in[0] |
| | | # text_lengths = data_lengths[0] if data_lengths is not None else None |
| | | split_size = kwargs.get("split_size", 20) |
| | | |
| | | tokens = split_words(text) |
| | | |
| | | jieba_usr_dict = kwargs.get("jieba_usr_dict", None) |
| | | if jieba_usr_dict and isinstance(jieba_usr_dict, str): |
| | | import jieba |
| | | jieba.load_userdict(jieba_usr_dict) |
| | | jieba_usr_dict = jieba |
| | | kwargs["jieba_usr_dict"] = "jieba_usr_dict" |
| | | tokens = split_words(text, jieba_usr_dict=jieba_usr_dict) |
| | | tokens_int = tokenizer.encode(tokens) |
| | | |
| | | mini_sentences = split_to_mini_sentence(tokens, split_size) |
| | |
| | | cache_pop_trigger_limit = 200 |
| | | results = [] |
| | | meta_data = {} |
| | | punc_array = None |
| | | for mini_sentence_i in range(len(mini_sentences)): |
| | | mini_sentence = mini_sentences[mini_sentence_i] |
| | | mini_sentence_id = mini_sentences_id[mini_sentence_i] |
| | |
| | | elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1: |
| | | new_mini_sentence_out = new_mini_sentence + "." |
| | | new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] |
| | | |
| | | result_i = {"key": key[0], "text": new_mini_sentence_out} |
| | | # keep a punctuations array for punc segment |
| | | if punc_array is None: |
| | | punc_array = punctuations |
| | | else: |
| | | punc_array = torch.cat([punc_array, punctuations], dim=0) |
| | | |
| | | result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array} |
| | | results.append(result_i) |
| | | |
| | | return results, meta_data |