| | |
| | | from funasr.utils.load_utils import load_audio_text_image_video |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words |
| | | |
| | | try: |
| | | import jieba |
| | | except: |
| | | pass |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | | else: |
| | |
| | | self.sos = sos |
| | | self.eos = eos |
| | | self.sentence_end_id = sentence_end_id |
| | | self.jieba_usr_dict = None |
| | | if kwargs.get("jieba_usr_dict", None) is not None: |
| | | jieba.load_userdict(kwargs["jieba_usr_dict"]) |
| | | self.jieba_usr_dict = jieba |
| | | |
| | | |
| | | |
| | |
| | | # text_lengths = data_lengths[0] if data_lengths is not None else None |
| | | split_size = kwargs.get("split_size", 20) |
| | | |
| | | jieba_usr_dict = kwargs.get("jieba_usr_dict", None) |
| | | if jieba_usr_dict and isinstance(jieba_usr_dict, str): |
| | | import jieba |
| | | jieba.load_userdict(jieba_usr_dict) |
| | | jieba_usr_dict = jieba |
| | | kwargs["jieba_usr_dict"] = "jieba_usr_dict" |
| | | tokens = split_words(text, jieba_usr_dict=jieba_usr_dict) |
| | | tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict) |
| | | tokens_int = tokenizer.encode(tokens) |
| | | |
| | | mini_sentences = split_to_mini_sentence(tokens, split_size) |
| | |
| | | else: |
| | | punc_array = torch.cat([punc_array, punctuations], dim=0) |
| | | # post processing when using word level punc model |
| | | if jieba_usr_dict: |
| | | if self.jieba_usr_dict is not None: |
| | | len_tokens = len(tokens) |
| | | new_punc_array = copy.copy(punc_array).tolist() |
| | | # for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])): |
| | |
| | | results.append(result_i) |
| | | return results, meta_data |
| | | |
| | | def export( |
| | | self, |
| | | **kwargs, |
| | | ): |
| | | def export(self, **kwargs): |
| | | |
| | | is_onnx = kwargs.get("type", "onnx") == "onnx" |
| | | encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export") |
| | | self.encoder = encoder_class(self.encoder, onnx=is_onnx) |
| | | from .export_meta import export_rebuild_model |
| | | models = export_rebuild_model(model=self, **kwargs) |
| | | return models |
| | | |
| | | self.forward = self.export_forward |
| | | |
| | | return self |
| | | |
| | | def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor): |
| | | """Compute loss value from buffer sequences. |
| | | |
| | | Args: |
| | | input (torch.Tensor): Input ids. (batch, len) |
| | | hidden (torch.Tensor): Target ids. (batch, len) |
| | | |
| | | """ |
| | | x = self.embed(inputs) |
| | | h, _ = self.encoder(x, text_lengths) |
| | | y = self.decoder(h) |
| | | return y |
| | | |
| | | def export_dummy_inputs(self): |
| | | length = 120 |
| | | text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32) |
| | | text_lengths = torch.tensor([length-20, length], dtype=torch.int32) |
| | | return (text_indexes, text_lengths) |
| | | |
| | | def export_input_names(self): |
| | | return ['inputs', 'text_lengths'] |
| | | |
| | | def export_output_names(self): |
| | | return ['logits'] |
| | | |
| | | def export_dynamic_axes(self): |
| | | return { |
| | | 'inputs': { |
| | | 0: 'batch_size', |
| | | 1: 'feats_length' |
| | | }, |
| | | 'text_lengths': { |
| | | 0: 'batch_size', |
| | | }, |
| | | 'logits': { |
| | | 0: 'batch_size', |
| | | 1: 'logits_length' |
| | | }, |
| | | } |
| | | def export_name(self): |
| | | return "model.onnx" |