| | |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import copy |
| | | import torch |
| | | import numpy as np |
| | | import torch.nn.functional as F |
| | |
| | | from funasr.utils.load_utils import load_audio_text_image_video |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words |
| | | |
| | | |
# jieba is treated as an optional dependency here: failing to import it is
# tolerated so that this module can still be loaded without it.
try:
    import jieba
except ImportError:
    # Catch only the import failure. A bare ``except`` would also swallow
    # KeyboardInterrupt/SystemExit and hide unrelated errors.
    pass
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | | else: |
| | |
| | | self.sos = sos |
| | | self.eos = eos |
| | | self.sentence_end_id = sentence_end_id |
| | | self.jieba_usr_dict = None |
| | | if kwargs.get("jieba_usr_dict", None) is not None: |
| | | jieba.load_userdict(kwargs["jieba_usr_dict"]) |
| | | self.jieba_usr_dict = jieba |
| | | |
| | | |
| | | |
| | |
| | | # text = data_in[0] |
| | | # text_lengths = data_lengths[0] if data_lengths is not None else None |
| | | split_size = kwargs.get("split_size", 20) |
| | | |
| | | jieba_usr_dict = kwargs.get("jieba_usr_dict", None) |
| | | if jieba_usr_dict and isinstance(jieba_usr_dict, str): |
| | | import jieba |
| | | jieba.load_userdict(jieba_usr_dict) |
| | | jieba_usr_dict = jieba |
| | | kwargs["jieba_usr_dict"] = "jieba_usr_dict" |
| | | tokens = split_words(text, jieba_usr_dict=jieba_usr_dict) |
| | | |
| | | tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict) |
| | | tokens_int = tokenizer.encode(tokens) |
| | | |
| | | mini_sentences = split_to_mini_sentence(tokens, split_size) |
| | |
| | | elif new_mini_sentence[-1] == ",": |
| | | new_mini_sentence_out = new_mini_sentence[:-1] + "." |
| | | new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] |
| | | elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==0: |
| | | elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())!=1: |
| | | new_mini_sentence_out = new_mini_sentence + "。" |
| | | new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] |
| | | if len(punctuations): punctuations[-1] = 2 |
| | | elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1: |
| | | new_mini_sentence_out = new_mini_sentence + "." |
| | | new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id] |
| | | # keep a punctuations array for punc segment |
| | | if len(punctuations): punctuations[-1] = 2 |
| | | # keep a punctuations array for punc segment |
| | | if punc_array is None: |
| | | punc_array = punctuations |
| | | else: |
| | | punc_array = torch.cat([punc_array, punctuations], dim=0) |
| | | # post processing when using word level punc model |
| | | if self.jieba_usr_dict is not None: |
| | | len_tokens = len(tokens) |
| | | new_punc_array = copy.copy(punc_array).tolist() |
| | | # for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])): |
| | | for i, token in enumerate(tokens[::-1]): |
| | | if '\u0e00' <= token[0] <= '\u9fa5': # ignore en words |
| | | if len(token) > 1: |
| | | num_append = len(token) - 1 |
| | | ind_append = len_tokens - i - 1 |
| | | for _ in range(num_append): |
| | | new_punc_array.insert(ind_append, 1) |
| | | punc_array = torch.tensor(new_punc_array) |
| | | |
| | | result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array} |
| | | results.append(result_i) |
| | | |
| | | return results, meta_data |
| | | |
def export(self, **kwargs):
    """Delegate export to ``export_rebuild_model``.

    Passes this instance as ``model`` together with every keyword argument
    to ``export_rebuild_model`` from the sibling ``export_meta`` module and
    returns that helper's result unchanged.
    """
    # Imported at call time, inside the method, rather than at module import.
    from .export_meta import export_rebuild_model

    return export_rebuild_model(model=self, **kwargs)
| | | |
| | | |