| | |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words |
| | | |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | | else: |
| | |
| | | ind_append = len_tokens - i - 1 |
| | | for _ in range(num_append): |
| | | new_punc_array.insert(ind_append, 1) |
| | | punc_array = torch.tensor(new_punc_array) |
| | | punc_array = torch.tensor(new_punc_array) |
| | | |
| | | result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array} |
| | | results.append(result_i) |
| | | return results, meta_data |
| | | |
def export(
    self,
    **kwargs,
):
    """Rewire this model in place for export and return it.

    Keyword Args:
        type: Export backend name. The encoder wrapper is constructed with
            ``onnx=True`` when this equals ``"onnx"`` (also the default
            when the key is absent), otherwise with ``onnx=False``.
        encoder: Encoder name. The wrapper class is looked up in
            ``tables.encoder_classes`` under ``encoder + "Export"``.
            Required.

    Returns:
        ``self``, with ``self.encoder`` replaced by the export wrapper and
        ``self.forward`` rebound to ``self.export_forward``.

    Raises:
        KeyError: If ``encoder`` is not supplied in ``kwargs``.
    """
    is_onnx = kwargs.get("type", "onnx") == "onnx"
    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    self.encoder = encoder_class(self.encoder, onnx=is_onnx)

    # The export-time forward pass is defined under the public name
    # ``export_forward``; binding ``self._export_forward`` would fail with
    # AttributeError because no method of that name sits next to the
    # export helpers.
    self.forward = self.export_forward

    return self
| | | |
def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
    """Run the embed -> encoder -> decoder pipeline used at export time.

    Args:
        inputs (torch.Tensor): Fed to ``self.embed``; presumably token ids
            shaped (batch, len) -- confirm against the caller.
        text_lengths (torch.Tensor): Passed to ``self.encoder`` together
            with the embedded inputs.

    Returns:
        Whatever ``self.decoder`` produces for the first of the two values
        returned by ``self.encoder``.
    """
    embedded = self.embed(inputs)
    # The encoder returns a pair; only its first element is decoded.
    encoded, _ = self.encoder(embedded, text_lengths)
    return self.decoder(encoded)
| | | |
def export_dummy_inputs(self):
    """Build a random example input pair for tracing the export graph.

    Returns:
        A ``(text_indexes, text_lengths)`` tuple of int32 tensors:
        ``text_indexes`` is 2 x 120 with values drawn from
        ``[0, self.embed.num_embeddings)``; ``text_lengths`` is
        ``[100, 120]``.
    """
    seq_len = 120
    vocab_size = self.embed.num_embeddings
    token_ids = torch.randint(0, vocab_size, (2, seq_len)).type(torch.int32)
    # The first sample is 20 tokens shorter, so the two lengths differ.
    lengths = torch.tensor([seq_len - 20, seq_len], dtype=torch.int32)
    return (token_ids, lengths)
| | | |
def export_input_names(self):
    """Return the exported graph's input names, in ``export_forward`` argument order."""
    input_names = ["inputs", "text_lengths"]
    return input_names
| | | |
def export_output_names(self):
    """Return the exported graph's output names (a single ``logits`` entry)."""
    output_names = ["logits"]
    return output_names
| | | |
def export_dynamic_axes(self):
    """Return the mapping of tensor name -> {axis index: symbolic axis name}.

    Covers the two inputs (``inputs``, ``text_lengths``) and the single
    output (``logits``); axis 0 is labelled ``batch_size`` for all three.
    """
    batch_axis = "batch_size"
    axes = {}
    axes["inputs"] = {0: batch_axis, 1: "feats_length"}
    axes["text_lengths"] = {0: batch_axis}
    axes["logits"] = {0: batch_axis, 1: "logits_length"}
    return axes
def export_name(self):
    """Return the file name used for the exported model."""
    file_name = "model.onnx"
    return file_name